'''C++ compiler wrapper — turns the on-disk `.cpp` tree from Phase 7
into a loadable `.so`.
Nim offloads compilation to its C++ backend via `{.compile:}` pragmas;
that backend shells out to the system compiler. Python has no such
backend, so we invoke a compiler (clang++/g++/nvcc) directly via
subprocess.
Design notes:
- `CompilerConfig` carries the full toolchain + flag set. All command
assembly flows through `_build_compile_cmd` / `_build_link_cmd` —
unit tests can verify argv without invoking a real compiler.
- Per-batch `.cpp` compilations run in parallel (ThreadPoolExecutor —
subprocess bottleneck is I/O, threads are fine).
- Hash-based build cache: we sha256 the source + the full compiler argv
and write a `.stamp` sidecar file alongside each object file. If the
stamp matches (and the object file still exists), compile is a no-op.
Only the `.cpp` itself is hashed, not the headers it includes. This is
independent of the source mtime guard in `cache.py` (content didn't
change → no rewrite → no mtime bump), and it still holds when the
source file is merely touched.
- `SRDATALOG_JIT_COMPILE_JOBS=N` overrides parallelism (0 = cpu_count).
- `SRDATALOG_JIT_SKIP_COMPILE=1` skips compile+link entirely — objects
and libraries already on disk are trusted even when their stamps do not
match. Mirrors `SRDATALOG_SKIP_JIT_REGEN` on the cache side.
Default include paths / link flags for `generalized_datalog` are
supplied by `srdatalog.runtime.runtime_include_paths()` etc. — this
module just consumes a `CompilerConfig`.
'''
from __future__ import annotations
import concurrent.futures
import hashlib
import os
import shutil
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from srdatalog.ir.codegen.cuda.build.cache import JitProjectLayout
# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------
[docs]
@dataclass
class CompilerConfig:
    '''Compile + link configuration.

    Defaults are minimal — `cxx_std=c++23` is what the runtime headers
    require. Callers add include/link/libs via the list fields.
    `extra_sources` is for object files / shared libs to feed into the
    final link (e.g., pre-built runtime artifacts).
    '''

    cxx: str = ""  # empty → auto-detect
    cxx_std: str = "c++23"  # rendered as -std=<cxx_std>
    include_paths: list[str] = field(default_factory=list)  # each → -I<path>
    defines: list[str] = field(default_factory=list)  # each → -D<define>
    cxx_flags: list[str] = field(default_factory=list)  # appended verbatim to compile argv
    link_flags: list[str] = field(default_factory=list)  # appended verbatim to link argv
    libs: list[str] = field(default_factory=list)  # each → -l<lib>
    extra_sources: list[str] = field(default_factory=list)
    output_dir: str = ""  # empty → use cache dir
    jobs: int = 0  # 0 → env or cpu_count
    shared: bool = True  # shared .so vs static

    def resolved_cxx(self) -> str:
        '''Return the compiler driver: `cxx` when set, else auto-detected.'''
        return self.cxx or _detect_cxx()

    def resolved_jobs(self) -> int:
        '''Return the number of parallel compile workers (always >= 1).

        Precedence: a positive `jobs` field, then a positive
        `SRDATALOG_JIT_COMPILE_JOBS`, then `os.cpu_count()`. An env value
        of 0 (or any non-positive value) means "use cpu_count", matching
        the documented `SRDATALOG_JIT_COMPILE_JOBS=N` semantics.

        Raises:
            ValueError: if `SRDATALOG_JIT_COMPILE_JOBS` is set to a
                non-integer.
        '''
        if self.jobs > 0:
            return self.jobs
        env = os.environ.get("SRDATALOG_JIT_COMPILE_JOBS", "").strip()
        if env:
            requested = int(env)
            # Previously max(1, int(env)) turned "0" into a single worker,
            # contradicting the "0 = cpu_count" contract.
            if requested > 0:
                return requested
        return os.cpu_count() or 1
def _detect_cxx() -> str:
'''First hit wins: $CXX → clang++ → g++. nvcc only if requested
explicitly via config (it's a wrapper, not a drop-in C++ compiler).'''
cxx = os.environ.get("CXX", "")
if cxx:
return cxx
for candidate in ("clang++", "g++"):
if shutil.which(candidate):
return candidate
raise RuntimeError("no C++ compiler found — set $CXX or CompilerConfig.cxx")
# -----------------------------------------------------------------------------
# Result types
# -----------------------------------------------------------------------------
[docs]
@dataclass
class CompileResult:
    '''Outcome of a single compiler or linker invocation.

    `returncode`, `stdout` and `stderr` mirror the finished subprocess;
    they keep their defaults when the invocation was skipped (`cached`).
    '''

    # Full argv that was (or, when skipped, would have been) executed.
    command: list[str]
    # Path of the file this invocation produces (object file or library).
    output: str
    returncode: int = 0
    stdout: str = ""
    stderr: str = ""
    # True = skipped via stamp (or the skip-compile env switch).
    cached: bool = False
    # Wall-clock seconds spent in the subprocess; 0.0 when skipped.
    elapsed_sec: float = 0.0
[docs]
@dataclass
class BuildResult:
    '''Aggregate outcome of `compile_jit_project`.'''

    # Path to the produced library; the pool path of compile_jit_project
    # leaves this empty when compilation or linking failed.
    artifact: str
    compile_results: list[CompileResult] = field(default_factory=list)
    # None when no link step was attempted.
    link_result: CompileResult | None = None
    elapsed_sec: float = 0.0

    def ok(self) -> bool:
        '''True iff no compile step failed and the link (when run) succeeded.'''
        link = self.link_result
        if link is not None and link.returncode != 0:
            return False
        return not any(res.returncode != 0 for res in self.compile_results)
# -----------------------------------------------------------------------------
# Command assembly (pure — no subprocess)
# -----------------------------------------------------------------------------
def _base_cxx_flags(config: CompilerConfig) -> list[str]:
out = [f"-std={config.cxx_std}"]
if config.shared:
out.append("-fPIC")
for d in config.defines:
out.append(f"-D{d}")
for i in config.include_paths:
out.append(f"-I{i}")
out.extend(config.cxx_flags)
return out
def _build_compile_cmd(
    source: str,
    output: str,
    config: CompilerConfig,
) -> list[str]:
    '''Return the argv that compiles `source` into the object `output`.

    Shape: `<cxx> <base flags> -c <source> -o <output>`. All flags sit
    ahead of the input on purpose: language switches like `-x cuda` only
    affect inputs that follow them (clang warns `after last input file
    has no effect` otherwise).
    '''
    driver = config.resolved_cxx()
    flags = _base_cxx_flags(config)
    return [driver, *flags, "-c", source, "-o", output]
def _build_link_cmd(
objects: list[str],
output: str,
config: CompilerConfig,
) -> list[str]:
cxx = config.resolved_cxx()
cmd = [cxx]
if config.shared:
cmd.append("-shared")
cmd += objects + config.extra_sources
cmd += ["-o", output]
cmd += config.link_flags
for lib in config.libs:
cmd.append(f"-l{lib}")
return cmd
# -----------------------------------------------------------------------------
# Stamp-based cache
# -----------------------------------------------------------------------------
def _stamp_digest(source_path: str, argv: list[str]) -> str:
'''Hash the source contents + the exact argv to invoke the compiler.
If either changed (rename of flag, new -D, different include order),
the stamp misses and we recompile.'''
h = hashlib.sha256()
try:
with open(source_path, "rb") as f:
h.update(f.read())
except FileNotFoundError:
h.update(b"<missing>")
# argv serialized as null-separated bytes — avoids ambiguity on
# flags that contain literal spaces.
for a in argv:
h.update(a.encode())
h.update(b"\x00")
return h.hexdigest()
def _stamp_path(object_path: str) -> str:
return object_path + ".stamp"
def _check_stamp(object_path: str, digest: str) -> bool:
    '''Report whether `object_path` is up to date for `digest`.

    Requires both the object file and its stamp sidecar to exist, and the
    stamp's content (whitespace-stripped) to equal `digest`. A stamp whose
    object was deleted is never trusted.
    '''
    if not os.path.exists(object_path):
        return False
    try:
        with open(_stamp_path(object_path)) as stamp_file:
            recorded = stamp_file.read()
    except FileNotFoundError:
        return False
    return recorded.strip() == digest
def _write_stamp(object_path: str, digest: str) -> None:
    '''Record `digest` in the stamp sidecar of `object_path`, truncating any previous content.'''
    stamp = _stamp_path(object_path)
    with open(stamp, "w") as handle:
        handle.write(digest)
# -----------------------------------------------------------------------------
# Single-file compile
# -----------------------------------------------------------------------------
[docs]
def compile_cpp(
    source: str,
    output: str,
    config: CompilerConfig,
) -> CompileResult:
    '''Compile one `.cpp` → `.o`.

    Short-circuits with `cached=True` when the stamp for `output` matches
    the digest of source + argv, or unconditionally when
    `SRDATALOG_JIT_SKIP_COMPILE=1`. Never raises on a compile error —
    returns a `CompileResult` with `returncode != 0` so the caller can
    aggregate. Compiler auto-detection (RuntimeError) and failures to
    launch the compiler at all (OSError from `subprocess.run`) still
    propagate.
    '''
    os.makedirs(os.path.dirname(output) or ".", exist_ok=True)
    cmd = _build_compile_cmd(source, output, config)
    digest = _stamp_digest(source, cmd)
    if _check_stamp(output, digest):
        return CompileResult(command=cmd, output=output, cached=True)
    if os.environ.get("SRDATALOG_JIT_SKIP_COMPILE", "") == "1":
        return CompileResult(command=cmd, output=output, cached=True)
    # Drop the previous stamp before the compiler touches `output`: a
    # failed or interrupted compile may clobber the object, and a leftover
    # stamp would then make a later revert of the source a false cache hit.
    Path(_stamp_path(output)).unlink(missing_ok=True)
    start = time.perf_counter()
    # errors="replace": diagnostics may contain bytes (e.g. quoted source
    # lines) the locale codec can't decode; that must not become an
    # exception instead of a reported compile failure.
    proc = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        errors="replace",
        check=False,
    )
    elapsed = time.perf_counter() - start
    if proc.returncode == 0:
        _write_stamp(output, digest)
    return CompileResult(
        command=cmd,
        output=output,
        returncode=proc.returncode,
        stdout=proc.stdout,
        stderr=proc.stderr,
        elapsed_sec=elapsed,
    )
[docs]
def link_shared(
    objects: list[str],
    output: str,
    config: CompilerConfig,
) -> CompileResult:
    '''Link objects + `config.extra_sources` into a shared library.

    There is no stamp cache for the link: it always runs unless
    `SRDATALOG_JIT_SKIP_COMPILE=1`, in which case a `cached=True` result
    is returned without invoking the linker. A failing link is reported
    via `returncode`, not raised.
    '''
    os.makedirs(os.path.dirname(output) or ".", exist_ok=True)
    cmd = _build_link_cmd(objects, output, config)
    if os.environ.get("SRDATALOG_JIT_SKIP_COMPILE", "") == "1":
        return CompileResult(command=cmd, output=output, cached=True)
    start = time.perf_counter()
    # errors="replace": linker diagnostics may contain undecodable bytes
    # (paths, symbol names); keep reporting them instead of raising.
    proc = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        errors="replace",
        check=False,
    )
    elapsed = time.perf_counter() - start
    return CompileResult(
        command=cmd,
        output=output,
        returncode=proc.returncode,
        stdout=proc.stdout,
        stderr=proc.stderr,
        elapsed_sec=elapsed,
    )
# -----------------------------------------------------------------------------
# Top-level: compile a Phase-7 project tree
# -----------------------------------------------------------------------------
def _artifact_name(project_dir: str, shared: bool) -> str:
stem = os.path.basename(project_dir.rstrip("/"))
ext = ".so" if shared else ".a"
return os.path.join(project_dir, f"lib{stem}{ext}")
[docs]
def compile_jit_project(
    project_result: JitProjectLayout,
    config: CompilerConfig | None = None,
    *,
    use_ninja: bool | None = None,
) -> BuildResult:
    '''Compile the `.cpp` tree written by `cache.write_jit_project` into
    a shared library.

    `project_result` is the layout returned by that function; the pool
    path reads its `dir`, `main` and `batches` entries.

    Returns a `BuildResult`. The caller inspects `.ok()` and
    `.compile_results`/`.link_result` for errors — this function never
    raises on a compile/link error.

    `use_ninja` selects the backend:
    - True / None (default): emit a build.ninja + PCH rule and invoke
      the ninja binary (from the `ninja` PyPI wheel). Best wall time
      on multi-TU projects because `srdatalog.h` is precompiled once
      and reused across every shard.
    - False: use the ThreadPoolExecutor path below (one subprocess per
      TU, no PCH). Useful on hosts without ninja or for debugging a
      single TU's compile command.
    `SRDATALOG_JIT_NO_NINJA=1` forces use_ninja=False regardless of the
    argument. If the ninja backend cannot be imported or raises
    RuntimeError, a one-line notice is printed and the pool path runs.
    '''
    if config is None:
        config = CompilerConfig()

    if _resolve_use_ninja(use_ninja):
        try:
            from srdatalog.ir.codegen.cuda.build.compiler_ninja import (
                compile_jit_project_ninja,
            )

            return compile_jit_project_ninja(project_result, config)
        except (ImportError, RuntimeError) as e:
            # ninja (binary or backend module) unavailable; fall through
            # with a one-line notice.
            print(f"[compile_jit_project] ninja unavailable ({e}); falling back to pool")

    project_dir = str(project_result["dir"])
    main_cpp = str(project_result["main"])
    # str() like dir/main above: downstream argv hashing needs plain strings.
    batches = [str(b) for b in project_result["batches"]]
    output_dir = config.output_dir or project_dir
    os.makedirs(output_dir, exist_ok=True)

    sources: list[str] = [main_cpp, *batches]
    # One object per TU, named after the source stem.
    # NOTE(review): two sources sharing a stem would map to the same .o.
    objects: list[str] = [os.path.join(output_dir, Path(src).stem + ".o") for src in sources]

    start = time.perf_counter()
    # Parallel compile — ThreadPoolExecutor is enough since subprocess
    # doesn't hold the GIL during `run`.
    with concurrent.futures.ThreadPoolExecutor(max_workers=config.resolved_jobs()) as pool:
        futures = [pool.submit(compile_cpp, src, obj, config) for src, obj in zip(sources, objects)]
        # Collect in submission order so reports read sensibly (and no
        # result can be dropped the way a dict keyed by output path could).
        results = [fut.result() for fut in futures]

    artifact = ""
    link_result: CompileResult | None = None
    if all(r.returncode == 0 for r in results):
        candidate = _artifact_name(output_dir, config.shared)
        link_result = link_shared(objects, candidate, config)
        if link_result.returncode == 0:
            artifact = candidate

    return BuildResult(
        artifact=artifact,
        compile_results=results,
        link_result=link_result,
        elapsed_sec=time.perf_counter() - start,
    )


def _resolve_use_ninja(use_ninja: bool | None) -> bool:
    '''Apply the `SRDATALOG_JIT_NO_NINJA=1` override first, then default None → True.'''
    if os.environ.get("SRDATALOG_JIT_NO_NINJA", "") == "1":
        return False
    return True if use_ninja is None else bool(use_ninja)