'''JIT cache directory management + batch file writer.

Port of `src/srdatalog/codegen/target_jit/jit_file.nim`:

- `getJitCacheDir` / `ensureJitCacheDir`
- `JitBatchManager` with `addKernel` / `writeBatchFiles` /
  `writeSchemaHeader` / `writeKernelDeclHeader`

Python uses `~/.cache/srdatalog/jit/<project>_<hash>/` (vs Nim's
`~/.cache/nim/jit/...`) so the two toolchains don't clobber each
other's caches. Callers that need Nim-compatible output can pass an
explicit cache dir via the `base` / `cache_base` args.

The one-shot entry point `write_jit_project()` glues everything
together: given the string outputs from `main_file.py` plus the
per-rule complete-runner emissions, it lays out the full .cpp tree
on disk.

Set `SRDATALOG_SKIP_JIT_REGEN=1` to reuse existing files (debugging
mode; matches Nim's behavior).
'''
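
# Resulting on-disk layout, for orientation. Batch file names come from
# `get_batch_file_name`; the schema / kernel-declaration header names are
# produced by the header writers and are not fixed here.
#
#   ~/.cache/srdatalog/jit/<project>_<hash4>/
#       main.cpp              (or whatever `main_file_name` is set to)
#       jit_batch_0.cpp
#       jit_batch_1.cpp
#       ...
#       plus optional schema / kernel-declaration headers
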
from __future__ import annotations
import hashlib
import os
from typing import TypedDict


class JitProjectLayout(TypedDict):
    '''Absolute paths returned by `write_jit_project`.'''

    dir: str
    main: str
    batches: list[str]
    schema_header: str
    kernel_header: str


# Match Nim's RULES_PER_BATCH default (jit_file.nim:30). Override via
# env `SRDATALOG_RULES_PER_BATCH=N`.
_DEFAULT_RULES_PER_BATCH = int(os.environ.get("SRDATALOG_RULES_PER_BATCH", "8"))
MAX_BATCH_FILES = 16  # Maximum number of batch shard files.
# -----------------------------------------------------------------------------
# Cache directory
# -----------------------------------------------------------------------------
def _project_hash(project_name: str) -> str:
    '''4-hex-digit project hash, same shape as Nim's
    `(hash(name) and 0xFFFF).toHex(4)`. We use sha256 (first 4 hex
    digits) so the mapping is reproducible across Python runs and
    versions (the builtin `hash()` is randomized per process).'''
    h = hashlib.sha256(project_name.encode()).hexdigest()
    return h[:4].upper()


def get_jit_cache_dir(project_name: str, base: str | None = None) -> str:
    '''Return `~/.cache/srdatalog/jit/<project>_<hash4>/`. `base`
    overrides `~/.cache/srdatalog`; e.g. tests pass a tmpdir.'''
    if base is None:
        home = os.environ.get("HOME", "/tmp")
        base = os.path.join(home, ".cache", "srdatalog")
    return os.path.join(base, "jit", f"{project_name}_{_project_hash(project_name)}")


def ensure_jit_cache_dir(project_name: str, base: str | None = None) -> str:
    '''Create the cache dir if needed; return its path.'''
    d = get_jit_cache_dir(project_name, base)
    os.makedirs(d, exist_ok=True)
    return d


def get_batch_file_name(batch_index: int) -> str:
    '''Shard file name for a batch index, e.g. "jit_batch_0.cpp".'''
    return f"jit_batch_{batch_index}.cpp"

# -----------------------------------------------------------------------------
# JIT batch common header / footer (matches jit_file.nim constants)
# -----------------------------------------------------------------------------
JIT_COMMON_INCLUDES = """\
// JIT-Generated Rule Kernel Batch
// This file is auto-generated - do not edit
#define SRDATALOG_JIT_BATCH // Guard: exclude host-side helpers from JIT compilation
// Main project header - includes all necessary boost/hana, etc.
#include "srdatalog.h"
#include <cstdint>
#include <cooperative_groups.h>
// JIT-specific headers (relative to generalized_datalog/)
#include "gpu/device_sorted_array_index.h"
#include "gpu/runtime/output_context.h"
#include "gpu/runtime/jit/intersect_handles.h"
#include "gpu/runtime/jit/jit_executor.h"
#include "gpu/runtime/jit/materialized_join.h"
#include "gpu/runtime/jit/ws_infrastructure.h" // WCOJTask, WCOJTaskQueue, ChunkedOutputContext
#include "gpu/runtime/query.h" // For DeviceRelationType
namespace cg = cooperative_groups;
// Make JIT helpers visible without full namespace qualification
using SRDatalog::GPU::JIT::intersect_handles;
"""
JIT_FILE_FOOTER = """
// End of JIT batch file
"""
# -----------------------------------------------------------------------------
# JitBatchManager
# -----------------------------------------------------------------------------
class JitBatchManager:
    '''Shards per-rule runner code across fixed-size batch files, then
    writes them plus the schema/kernel headers to the cache dir. For
    example, with `rules_per_batch=8`, 20 rules land in batches 0..2
    (8 + 8 + 4 rules).

    Mirrors Nim's `JitBatchManager` in `jit_file.nim:100-270`.
    '''

    def __init__(
        self,
        project_name: str,
        rules_per_batch: int = _DEFAULT_RULES_PER_BATCH,
        cache_base: str | None = None,
    ):
        self.project_name = project_name
        self.rules_per_batch = rules_per_batch
        self.cache_base = cache_base
        self.batches: dict[int, list[str]] = {}
        self.rule_names: list[str] = []
        self.rule_count = 0
        self.schema_definitions = ""
        self.db_type_alias = ""
        self.kernel_declarations = ""

    # --- public accumulators ---
    def set_schema_definitions(self, schema_defs: str) -> None:
        self.schema_definitions = schema_defs

    def set_db_type_alias(self, db_type: str) -> None:
        self.db_type_alias = db_type

    def add_kernel_declaration(self, decl_code: str) -> None:
        self.kernel_declarations += decl_code

    def add_kernel(self, kernel_code: str, rule_name: str | None = None) -> None:
        '''Add one `JitRunner_<rule>` struct (complete with its
        `__global__` kernels) to the next batch slot.'''
        batch_idx = self.rule_count // self.rules_per_batch
        self.batches.setdefault(batch_idx, []).append(kernel_code)
        if rule_name is not None:
            self.rule_names.append(rule_name)
        self.rule_count += 1

    def batch_count(self) -> int:
        return len(self.batches)

    # --- content generation (no I/O) ---
    def generate_batch_file(
        self,
        batch_idx: int,
        extra_headers: list[str] | None = None,
    ) -> str:
        '''Render the complete .cpp content for one batch (no I/O):
        common includes, optional plugin headers, inlined schema and
        DB alias, then the accumulated rule kernels.'''
        if batch_idx not in self.batches:
            return ""
        extra = extra_headers or []
        code = JIT_COMMON_INCLUDES
        if extra:
            code += "// Extra index headers (from registered plugins)\n"
            for h in extra:
                code += f'#include "{h}"\n'
            code += "\n"
        if self.schema_definitions:
            code += "// Project-specific schema definitions (inlined)\n"
            code += self.schema_definitions
            code += "\n\n"
        if self.db_type_alias:
            code += "// DB type alias for JitRunner type derivation\n"
            code += self.db_type_alias
            code += "\n\n"
        code += f"// Batch {batch_idx} - {len(self.batches[batch_idx])} rules\n\n"
        for kernel_code in self.batches[batch_idx]:
            code += kernel_code
            code += "\n"
        code += JIT_FILE_FOOTER
        return code

    # --- I/O ---

    def _skip_regen(self) -> bool:
        return os.environ.get("SRDATALOG_SKIP_JIT_REGEN", "") == "1"

    def _write_if_changed(self, path: str, content: str) -> bool:
        '''Write `content` to `path` only if it differs from what's
        already on disk; return True if the file was written. Leaving
        the mtime alone when nothing changed means timestamp-based
        build systems (ninja, make) won't rebuild downstream targets.'''
        try:
            with open(path) as f:
                if f.read() == content:
                    return False
        except FileNotFoundError:
            pass
        with open(path, "w") as f:
            f.write(content)
        return True
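
    # Minimal sketches of the schema/kernel header writers called below and
    # in `write_jit_project`, assuming they mirror Nim's `writeSchemaHeader`
    # / `writeKernelDeclHeader`. The header file names `jit_schema.h` and
    # `jit_kernel_decls.h` are assumptions, not confirmed names.

    def write_schema_header(self) -> str:
        '''Sketch: write the accumulated schema definitions to a header
        in the cache dir; return its path, or "" if there is nothing
        to write.'''
        if not self.schema_definitions:
            return ""
        cache_dir = ensure_jit_cache_dir(self.project_name, self.cache_base)
        path = os.path.join(cache_dir, "jit_schema.h")  # name is an assumption
        self._write_if_changed(path, self.schema_definitions)
        return path

    def write_kernel_decl_header(self) -> str:
        '''Sketch: write the accumulated kernel declarations to a header
        in the cache dir; return its path, or "" if there is nothing
        to write.'''
        if not self.kernel_declarations:
            return ""
        cache_dir = ensure_jit_cache_dir(self.project_name, self.cache_base)
        path = os.path.join(cache_dir, "jit_kernel_decls.h")  # name is an assumption
        self._write_if_changed(path, self.kernel_declarations)
        return path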

    def write_batch_files(
        self,
        extra_headers: list[str] | None = None,
    ) -> list[str]:
        '''Write all shards plus the headers to the cache dir. Returns
        the list of batch file paths; the headers are written but not
        returned, since they aren't compiled directly.'''
        cache_dir = ensure_jit_cache_dir(self.project_name, self.cache_base)
        skip = self._skip_regen()
        self.write_schema_header()
        self.write_kernel_decl_header()
        written: list[str] = []
        for batch_idx in sorted(self.batches.keys()):
            path = os.path.join(cache_dir, get_batch_file_name(batch_idx))
            if skip and os.path.exists(path):
                written.append(path)
                continue
            self._write_if_changed(path, self.generate_batch_file(batch_idx, extra_headers))
            written.append(path)
        return written

# -----------------------------------------------------------------------------
# One-shot project writer
# -----------------------------------------------------------------------------
def write_jit_project(
    project_name: str,
    main_file_content: str,
    per_rule_runners: list[tuple[str, str]],
    *,
    schema_definitions: str = "",
    db_type_alias: str = "",
    extra_headers: list[str] | None = None,
    cache_base: str | None = None,
    main_file_name: str = "main.cpp",
) -> JitProjectLayout:
    '''Lay out the full .cpp tree for a project.

    Args:
        project_name: cache dir name (e.g. "TrianglePlan_DB").
        main_file_content: output of `main_file.gen_main_file_content`.
        per_rule_runners: list of `(rule_name, full_runner_cpp)` tuples,
            typically the `full` string returned by
            `complete_runner.gen_complete_runner` for each
            non-materialized `ExecutePipeline`. Sharded across the
            jit_batch_N.cpp files.
        schema_definitions: optional project schema header content.
        db_type_alias: optional DB type alias string (inlined into each
            batch file for template derivation).
        extra_headers: per-rule plugin headers (e.g.
            "gpu/device_2level_index.h") #include'd into every batch
            file.
        cache_base: override for `~/.cache/srdatalog`; tests pass a
            tmpdir.
        main_file_name: output name for the top-level main file.

    Returns:
        Dict with keys `dir`, `main`, `batches` (list[str]),
        `schema_header`, and `kernel_header` (the header paths may be
        ""); every path is absolute.
    '''
    mgr = JitBatchManager(project_name, cache_base=cache_base)
    if schema_definitions:
        mgr.set_schema_definitions(schema_definitions)
    if db_type_alias:
        mgr.set_db_type_alias(db_type_alias)
    for rule_name, runner_code in per_rule_runners:
        mgr.add_kernel(runner_code, rule_name)
    batch_paths = mgr.write_batch_files(extra_headers)
    schema_path = mgr.write_schema_header()
    kernel_path = mgr.write_kernel_decl_header()
    cache_dir = ensure_jit_cache_dir(project_name, cache_base)
    main_path = os.path.join(cache_dir, main_file_name)
    mgr._write_if_changed(main_path, main_file_content)
    return {
        "dir": cache_dir,
        "main": main_path,
        "batches": batch_paths,
        "schema_header": schema_path,
        "kernel_header": kernel_path,
    }
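

# Usage sketch. The main/runner strings below are placeholders, not real
# emitter output; real callers pass the strings produced by `main_file.py`
# and `complete_runner.gen_complete_runner`.
if __name__ == "__main__":
    demo_main = "// main.cpp placeholder\nint main() { return 0; }\n"
    demo_runner = "// struct JitRunner_demo { ... };  (placeholder)\n"
    layout = write_jit_project(
        "DemoPlan_DB",
        main_file_content=demo_main,
        per_rule_runners=[("demo_rule", demo_runner)],
        cache_base="/tmp/srdatalog_demo",  # keep the demo out of ~/.cache
    )
    print(layout["dir"])
    print(layout["main"], *layout["batches"], sep="\n")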