# Source code for srdatalog.ir.codegen.cuda.build.loader

'''Python ↔ C++ bridge: dlopen a compiled `.so` and call into it.

Phase 9 — closes the end-to-end loop started by Phases 6-8:
  - main_file.py emits the C++ source
  - cache.py writes the tree to disk
  - compiler.py turns it into a .so
  - THIS module turns that .so into callable Python.

The C++ runner methods (`<Ruleset>_Runner::run`, `load_data`, etc.)
are templates and can't be called directly from ctypes. Instead, the
user generates a small `extern "C"` shim that wraps the templated
calls into C-ABI entry points; this module handles the Python side
of that contract.

Public API:
  - `EntryPoint` — argtypes/restype spec for one C symbol.
  - `gen_runtime_shim_template(...)` — produces a starter shim.cpp
    the user fills in. Returned as a string so the caller can write
    it into the cache dir alongside main.cpp / jit_batch_N.cpp.
  - `JitRuntime` — ctypes.CDLL wrapper. Resolves symbols, applies
    the argspec, exposes typed `.call(name, *args)` / attribute
    shortcuts.
  - `build_and_load(...)` — one-shot: takes the project_result dict
    from `cache.write_jit_project` + a CompilerConfig, runs the
    full build, returns a ready `JitRuntime` on success.

Thread-safety: a single `JitRuntime` wraps one dlopen'd handle.
`ctypes.CDLL` releases the GIL for the duration of each foreign call,
so calls from multiple Python threads can run concurrently and are
safe iff the underlying C function is. We don't protect against
that — it's the user's shim.
'''

from __future__ import annotations

import ctypes
import os
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any

from srdatalog.ir.codegen.cuda.build.cache import JitProjectLayout

# -----------------------------------------------------------------------------
# EntryPoint spec
# -----------------------------------------------------------------------------


@dataclass
class EntryPoint:
    '''Describes how one `extern "C"` symbol of the loaded .so is bound.

    Fields:
      name     — exported symbol name exactly as it appears in the
                 library (the plain `extern "C"` name, not a C++-mangled
                 one).
      argtypes — ctypes parameter types, in positional order.
      restype  — ctypes return type. Defaults to `c_int`, matching the
                 "return 0 on success" convention.
      errcheck — post-call hook for result mapping / error translation.
                 Uses the same protocol as the `errcheck` attribute of a
                 ctypes foreign function:
                 `(result, func, arguments) -> final_result_or_raise`.
                 `JitRuntime.bind` interprets the value as follows:
                 a callable is installed as-is; `None` (the default)
                 installs the raise-on-nonzero check when `restype` is
                 `c_int`; `False` opts out of any check.
    '''

    name: str
    argtypes: list[Any] = field(default_factory=list)
    restype: Any = ctypes.c_int
    errcheck: Any = None
def _apply_errcheck_default(result, func, arguments):
    '''Fallback errcheck for int-returning entry points.

    Implements the "0 means OK" convention: a zero status is handed back
    unchanged, any other status is turned into a `RuntimeError` naming
    the function and the status value. `arguments` is part of the ctypes
    errcheck signature and is not inspected.
    '''
    if result == 0:
        return result
    raise RuntimeError(f"{func.__name__} returned nonzero status {result}")


# -----------------------------------------------------------------------------
# JitRuntime
# -----------------------------------------------------------------------------
class JitRuntime:
    '''Wrapper around `ctypes.CDLL` with a declared entry-point map.

    Usage:
        rt = JitRuntime("libmyproj.so", entry_points=[
            EntryPoint("srdatalog_init", restype=ctypes.c_int),
            EntryPoint("srdatalog_run", argtypes=[ctypes.c_char_p]),
            EntryPoint("srdatalog_get_size", argtypes=[ctypes.c_char_p],
                       restype=ctypes.c_uint64, errcheck=None),
        ])
        rt.srdatalog_init()
        rt.srdatalog_run(b"/data")
        n = rt.srdatalog_get_size(b"Path")
    '''

    def __init__(
        self,
        lib_path: str,
        entry_points: Sequence[EntryPoint] = (),
        *,
        mode: int = ctypes.RTLD_GLOBAL,
    ):
        '''Load `lib_path` and bind every spec in `entry_points`.

        Args:
            lib_path: path to the shared library (relative or absolute).
            entry_points: specs to bind immediately, in order.
            mode: dlopen mode flags forwarded to `ctypes.CDLL`.

        Raises:
            FileNotFoundError: `lib_path` does not exist.
            OSError: the loader could not open the library.
            AttributeError: one of `entry_points` names a missing symbol.
        '''
        if not os.path.exists(lib_path):
            raise FileNotFoundError(f"JitRuntime: {lib_path} not found")
        self.lib_path = lib_path
        self._entry_points: dict[str, EntryPoint] = {}
        # Load through an absolute path. A bare filename such as
        # "libmyproj.so" contains no slash, so dlopen would look it up on
        # the system library search path instead of opening the file that
        # was just verified to exist relative to the working directory.
        #
        # RTLD_GLOBAL makes symbols available to subsequently loaded libs
        # — important if the user dlopens dependent .so's.
        self._cdll = ctypes.CDLL(os.path.abspath(lib_path), mode=mode)
        for ep in entry_points:
            self.bind(ep)

    def _require_open(self) -> Any:
        '''Return the live CDLL handle, or raise if there is none.

        Reads through `__dict__` so a missing attribute never re-enters
        `__getattr__`. Raises `AttributeError` (the type callers already
        see for failed lookups) with a message that names the real cause.
        '''
        cdll = self.__dict__.get("_cdll")
        if cdll is None:
            path = self.__dict__.get("lib_path")
            raise AttributeError(
                f"JitRuntime: library {path!r} is closed or was never loaded"
            )
        return cdll

    def bind(self, ep: EntryPoint) -> Any:
        '''Resolve `ep.name` in the library and apply argtypes / restype /
        errcheck. Returns the bound ctypes function.

        Raises:
            AttributeError: the symbol is not exported by the library, or
                the runtime has already been closed.
        '''
        cdll = self._require_open()
        try:
            fn = getattr(cdll, ep.name)
        except AttributeError as e:
            raise AttributeError(
                f"JitRuntime: symbol {ep.name!r} not found in {self.lib_path}"
            ) from e
        fn.argtypes = list(ep.argtypes)
        fn.restype = ep.restype
        # errcheck semantics:
        #   None     → apply default (raise-on-nonzero) when restype is c_int
        #   False    → explicitly opt out of any errcheck
        #   callable → use as-is
        if callable(ep.errcheck):
            fn.errcheck = ep.errcheck
        elif ep.errcheck is None and ep.restype is ctypes.c_int:
            fn.errcheck = _apply_errcheck_default
        # else: ep.errcheck is False or restype isn't c_int — leave
        # ctypes' default (no errcheck) in place.
        self._entry_points[ep.name] = ep
        return fn

    def __getattr__(self, name: str) -> Any:
        '''Attribute-style access to bound entry points.

        `rt.some_fn(args...)` is forwarded to the underlying CDLL, so a
        symbol configured by `bind` comes back with its typed signature.
        Unbound symbols still surface as untyped ctypes functions —
        mirrors CDLL's default behavior.

        Names starting with an underscore are never forwarded; this keeps
        lookups of missing private attributes from reaching the library.

        Raises:
            AttributeError: private name, unknown symbol, or the runtime
                has been closed.
        '''
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._require_open(), name)

    def close(self) -> None:
        '''Drop the ctypes reference.

        Actual dlclose timing depends on Python's garbage collector —
        ctypes doesn't expose direct dlclose, so this is best-effort.
        Afterwards, symbol access and `bind` raise `AttributeError`
        stating that the library is closed. Calling `close` again is a
        no-op.
        '''
        self._cdll = None  # type: ignore[assignment]
# ----------------------------------------------------------------------------- # Shim template generator # -----------------------------------------------------------------------------
def gen_runtime_shim_template(
    ruleset_name: str,
    db_blueprint_name: str,
    dest_relations: Sequence[tuple[str, str]] = (),
) -> str:
    '''Render a starter `runtime_shim.cpp` exposing `extern "C"` entry
    points for the Python loader to bind.

    Args:
      ruleset_name: prefix of the `_Runner` struct; the shim calls
        `<ruleset_name>_Runner::load_data` and `<ruleset_name>_Runner::run`.
      db_blueprint_name: user-declared blueprint type (e.g.,
        "TriangleDBBlueprint") used to instantiate the host database.
      dest_relations: `(symbol_suffix, cpp_type_name)` pairs. Each pair
        adds a `uint64_t srdatalog_size_<suffix>()` size query.

    Returns:
      The shim source as one newline-terminated string.

    The output is a STARTING POINT: the caller may hand-edit it (e.g. to
    add custom result extraction) before passing it to
    `cache.write_jit_project`. It pulls in the `_Runner` struct and the
    DB blueprint through `#include "main.cpp"`; builds that are sharded
    differently need to adjust that include.
    '''
    host_db_alias = "using HostDB = SRDatalog::AST::SemiNaiveDatabase<" + db_blueprint_name + ">;"

    # File prologue: includes plus the anonymous-namespace globals.
    out = [
        "// Auto-generated C-ABI shim for Python ctypes loader",
        "// Edit the bodies to suit your runtime shape.",
        "",
        '#include "srdatalog.h"',
        '#include "gpu/runtime/gpu.h"',
        '#include "gpu/runtime/query.h"',
        "",
        "// main.cpp provides the `_Runner` struct + DB blueprint.",
        '#include "main.cpp"',
        "",
        "namespace {",
        host_db_alias,
        "HostDB* g_host_db = nullptr;",
        "",
        "// Device DB is templated on DeviceRelationType; hold via `void*` + cast on use.",
        "void* g_device_db = nullptr;",
        "} // anon namespace",
        "",
    ]

    # Fixed entry points: init / run / shutdown.
    out.extend(
        [
            'extern "C" {',
            "",
            "int srdatalog_init() {",
            " try {",
            " SRDatalog::GPU::init_cuda();",
            " return 0;",
            " } catch (...) { return 1; }",
            "}",
            "",
            "int srdatalog_run(const char* data_dir) {",
            " try {",
            " if (g_host_db) { delete g_host_db; g_host_db = nullptr; }",
            " g_host_db = new HostDB();",
            f" {ruleset_name}_Runner::load_data(*g_host_db, std::string(data_dir));",
            " auto device_db = SRDatalog::GPU::copy_host_to_device(*g_host_db);",
            f" {ruleset_name}_Runner::run(device_db);",
            " return 0;",
            " } catch (const std::exception& e) {",
            ' std::cerr << "srdatalog_run: " << e.what() << std::endl;',
            " return 1;",
            " }",
            "}",
            "",
            "int srdatalog_shutdown() {",
            " if (g_host_db) { delete g_host_db; g_host_db = nullptr; }",
            " return 0;",
            "}",
        ]
    )

    # One size query per destination relation.
    for sym_suffix, rel_type in dest_relations:
        out.extend(
            (
                "",
                f"uint64_t srdatalog_size_{sym_suffix}() {{",
                " if (!g_host_db) return 0;",
                " try {",
                f" auto& rel = get_relation_by_schema<{rel_type}, FULL_VER>(*g_host_db);",
                " return rel.size();",
                " } catch (...) { return 0; }",
                "}",
            )
        )

    out.append("")
    out.append('} // extern "C"')
    return "\n".join(out) + "\n"
# ----------------------------------------------------------------------------- # One-shot build + load # -----------------------------------------------------------------------------
def build_and_load(
    project_result: JitProjectLayout,
    entry_points: Sequence[EntryPoint],
    compiler_config: Any | None = None,  # CompilerConfig; avoid circular import
    *,
    required_artifact: str | None = None,
) -> JitRuntime:
    '''Build the Phase-7 project tree with the Phase-8 compiler and
    dlopen the resulting .so.

    Args:
      project_result: layout handed to `compile_jit_project`.
      entry_points: specs bound on the returned runtime.
      compiler_config: compiler settings; a default `CompilerConfig()`
        is used when omitted.
      required_artifact: when given, this path is loaded in place of the
        artifact reported by the build (e.g. the runner library of a
        well-known runtime).

    Returns:
      A `JitRuntime` wrapping the loaded library.

    Raises:
      RuntimeError: compile or link failed; the message carries the
        captured stderr of every failing step — the common failure mode
        during dev.
      FileNotFoundError: the build reported success but the artifact is
        not on disk.
    '''
    # Imported lazily: a module-level import would be circular.
    from srdatalog.ir.codegen.cuda.build.compiler import CompilerConfig, compile_jit_project

    cfg = compiler_config if compiler_config else CompilerConfig()
    outcome = compile_jit_project(project_result, cfg)

    if not outcome.ok():
        # Collect one line per failing translation unit, then the link step.
        failures: list[str] = [
            f"[compile {unit.output}] {unit.stderr.strip() or '(no stderr)'}"
            for unit in outcome.compile_results
            if unit.returncode != 0
        ]
        if outcome.link_result and outcome.link_result.returncode != 0:
            failures.append(f"[link] {outcome.link_result.stderr.strip() or '(no stderr)'}")
        raise RuntimeError("build failed:\n" + "\n".join(failures))

    lib = required_artifact if required_artifact else outcome.artifact
    if not (lib and os.path.exists(lib)):
        raise FileNotFoundError(f"build succeeded but artifact missing: {lib!r}")
    return JitRuntime(lib, entry_points)