'''Public compile entry point for the dialect-based codegen.
`compile_pipeline(ep, target='cuda')` emits a complete C++ JIT batch
file by routing the pipeline through the IIR-sorted-array dialect +
target.cuda emit, wrapped in the dialect's envelope helpers (file
prelude, banner, functor struct, view declarations, footer).
`compile_kernel_body(ep, ...)` is the lower-level entry: emits just
the operator() body (view_decls + dialect-emitted kernel logic),
parameterized by phase (count vs materialize) and output-var bindings.
The runner emit (Phase N3) calls into this for each kernel it wraps.
Pipeline shapes the dialect doesn't (yet) handle raise loudly via
`lower_scan_pipeline` — `_supported_pipeline()` is the authoritative
scope statement. Adding coverage for a new shape means adding a
lowering rule, not a fallback.
The byte-equivalence harnesses:
- tests/test_byte_equivalence_jit.py — materialize-phase kernel
functor against the upstream Nim goldens.
- tests/test_count_phase_byte_equivalence.py — count-phase body
against the legacy `jit_pipeline` count emit (the only
spec for count-phase shape, since the runner files contain
count bodies but no isolated count-only goldens exist).
See:
- docs/stage2_emitter_audit.md — the per-milestone migration plan.
- docs/ir_lowering_semantics.md — the formal lowering rules.
- docs/design_principles.md — discipline rules for the rewrite.
'''
from __future__ import annotations
from typing import Literal
import srdatalog.ir.mir.types as m
# Closed set of supported codegen targets; only CUDA emission exists today.
# Extend this Literal (and the target check in compile_pipeline) to add more.
Target = Literal['cuda']
def compile_pipeline(ep: m.ExecutePipeline, *, target: Target = 'cuda') -> str:
    '''Compile an MIR ExecutePipeline to target C++ source via the dialect.

    Parameters:
        ep: the MIR pipeline to compile into a standalone jit_batch file.
        target: codegen target; only ``'cuda'`` is supported.

    Returns:
        The complete C++ JIT batch file as a single string.

    Raises ValueError on unsupported targets. Raises (via
    `lower_scan_pipeline`) on pipeline shapes the dialect doesn't
    cover — there is no legacy fallback.
    '''
    # Validate target before doing any work so callers get a clean error.
    if target != 'cuda':
        raise ValueError(f'compile_pipeline: unsupported target {target!r}')
    # Imported lazily so merely importing this module stays cheap and does
    # not pull in the full codegen stack.
    from srdatalog.ir.codegen.cuda.envelope import (
        assign_handle_positions,
        emit_full_file,
    )
    pipeline = list(ep.pipeline)
    assign_handle_positions(pipeline)
    # Standalone jit_batch shape uses handle_idx-based view slots; runner
    # contexts (compile_runner / compile_kernel_body) default to positional.
    body = compile_kernel_body(ep, is_counting=False, slot_mode='handle_idx')
    return emit_full_file(ep, body)
def compile_runner(
    ep: m.ExecutePipeline,
    db_type_name: str,
    rel_index_types: dict[str, str] | None = None,
) -> str:
    '''Compile an ExecutePipeline to its full per-rule runner — the
    `JitRunner_<rule>` struct + kernel definitions + out-of-line phase
    methods + execute(). Production output: this is what
    `jit_runner.<rule>.cpp` golden files capture.

    The dialect's `codegen.cuda.runner` module owns the runner emission
    surface. Most pieces (phase methods, execute, BG variants, fused
    kernel) currently delegate to legacy helpers in
    `ir.codegen.cuda.complete_runner`; later milestones (N2/N4/N5/N6/N8)
    collapse them into native dialect emission.

    Kernel *bodies* (count + materialize) already route through
    `compile_kernel_body` when `_dialect_safe_kernel` holds — see the
    swap inside `complete_runner._gen_kernel_count` /
    `_gen_kernel_materialize`.

    The byte-equivalence gate (`tests/test_runner_byte_equivalence.py`)
    anchors this entry point to the upstream goldens throughout the
    migration.

    Parameters:
        ep: the MIR pipeline for one rule.
        db_type_name: C++ name of the database type the runner binds to.
        rel_index_types: optional per-relation custom index type map,
            forwarded to the runner emitter.
    '''
    # Lazy import: keeps this module's import side-effect-free and avoids
    # cycles with the codegen package.
    from srdatalog.ir.codegen.cuda.runner import emit_runner_full
    return emit_runner_full(ep, db_type_name, rel_index_types=rel_index_types)
def compile_kernel_body(
    ep: m.ExecutePipeline,
    *,
    is_counting: bool,
    output_var_name: str = 'output',
    output_vars: dict[str, str] | None = None,
    slot_mode: str = 'positional',
    rel_index_types: dict[str, str] | None = None,
    tiled_cartesian: bool = False,
    bg_enabled: bool = False,
) -> str:
    '''Emit the operator() body for one kernel — view_decls followed by
    the dialect-emitted kernel logic. Caller is responsible for the
    envelope (file prelude, kernel signature, OutputContext setup).

    Parameters mirror the legacy `_make_kernel_ctx` knobs the runner
    emit (complete_runner.py) twiddles per kernel:

    is_counting: True selects count-phase emit (`emit_direct()` with
        no args, AddCount-style increments).
    output_var_name: name of the OutputContext variable used by the
        single-output InsertInto path (legacy default 'output';
        runner uses 'output_ctx' in count phase, 'output_ctx_0' in
        materialize).
    output_vars: per-relation output-var override map. Multi-head
        rules use this so each InsertInto resolves to its own dest's
        OutputContext. Pass `{rel_name: '__skip_counting__'}` to
        suppress count-phase emission for secondary outputs.
    slot_mode: 'positional' (default, matches `jit_runner.<rule>.cpp`
        production goldens) or 'handle_idx' (matches the standalone
        `jit_batch.<rule>.cpp` test fixtures emitted via
        `compile_pipeline`). See `emit_view_declarations` docstring.
    rel_index_types: per-relation custom index type (e.g.,
        `Device2LevelIndex`). Used to compute per-spec view_counts via
        `relation.d2l` (and any future index dialect) so positional
        slots advance by 2 per FULL_VER D2L source — matching legacy
        `compute_view_slot_offsets`. Pass {} or None for plain DSAI.
    tiled_cartesian: forwarded to the lowering context; enables the
        tiled cartesian-product emission variant.
    bg_enabled: forwarded to the lowering context; enables the BG
        emission variant.
    '''
    # All codegen machinery is imported lazily so importing this module
    # never drags in the dialect/emitter stack.
    from srdatalog.ir.codegen.cuda.emit import EmitCtx, emit
    from srdatalog.ir.codegen.cuda.envelope import (
        assign_handle_positions,
        collect_unique_view_specs,
        emit_view_declarations,
    )
    from srdatalog.ir.dialects.relation.d2l import view_counts_for_specs
    from srdatalog.ir.dialects.relation.sorted_array.lowerings import (
        LoweringCtx,
        lower_scan_pipeline,
    )

    pipeline = list(ep.pipeline)
    assign_handle_positions(pipeline)
    view_specs = collect_unique_view_specs(pipeline)
    view_counts = view_counts_for_specs(view_specs, rel_index_types or {})
    view_decls, view_vars = emit_view_declarations(
        view_specs,
        pipeline,
        slot_mode=slot_mode,
        view_counts=view_counts,
    )
    # Split the combined view_vars dict into name-only (handle_idx ->
    # view_var) and base-slot (handle_idx -> slot) maps. The envelope
    # emits both into the same dict via a `__base__<idx>` sentinel.
    name_map = {k: v for k, v in view_vars.items() if k.isdigit()}
    base_map = {
        k.removeprefix('__base__'): int(v)
        for k, v in view_vars.items()
        if k.startswith('__base__')
    }
    lower_ctx = LoweringCtx(
        view_var_names=name_map,
        is_counting=is_counting,
        output_var=output_var_name,
        # Defensive copies: callers' dicts must not be aliased by the ctx.
        output_var_overrides=dict(output_vars) if output_vars else {},
        rel_index_types=dict(rel_index_types) if rel_index_types else {},
        view_slot_bases=base_map,
        dedup_hash=ep.dedup_hash,
        tiled_cartesian=tiled_cartesian,
        bg_enabled=bg_enabled,
    )
    iir = lower_scan_pipeline(pipeline, lower_ctx)
    emit_ctx = EmitCtx(indent_level=4)
    return view_decls + emit(iir, emit_ctx)
# Explicit public API surface (alphabetical); everything else is internal.
__all__ = ['Target', 'compile_kernel_body', 'compile_pipeline', 'compile_runner']