'''target.cuda — full-file envelope emission.
The dialect-emitted kernel body sits inside a fixed-shape envelope:
JIT_FILE_PRELUDE (constant header)
+ banner(rule_name, num_handles)
+ functor_start(rule_name, ...)
+ [optional DedupTable struct]
+ view_declarations + <BODY> (operator() body)
+ functor_end()
+ "\\n"
+ JIT_FILE_FOOTER (constant footer)
Pure string emitters — no algorithm dispatch, no feature-flag-driven
branching beyond the dedup-hash struct injection. Lives in the dialect
because the envelope shape is target-specific (CUDA cooperative-groups
signature, `__device__` qualifier, etc.); a `target.cpp_tbb` envelope
would emit a different shape from the same MIR pipeline.
The legacy `ir/dialects/target/cuda/` modules still own their own copies of these
helpers during the Stage 2 transition (they are imported by the
byte-equivalence harness which compares the legacy emitter against
the dialect path). Once the legacy inner-body emitters are deleted,
the legacy copies go with them.
'''
from __future__ import annotations
from dataclasses import dataclass
import srdatalog.ir.mir.types as m
# -----------------------------------------------------------------------------
# File prelude / footer (byte-identical to the legacy strings)
# -----------------------------------------------------------------------------
# NOTE: both strings below must stay byte-identical to the legacy emitter's
# copies — the Stage 2 byte-equivalence harness diffs full generated files.
JIT_FILE_PRELUDE = """\
// JIT-Generated Rule Kernel Batch
// This file is auto-generated - do not edit
#define SRDATALOG_JIT_BATCH // Guard: exclude host-side helpers from JIT compilation
// Main project header - includes all necessary boost/hana, etc.
#include "srdatalog.h"
#include <cstdint>
#include <cooperative_groups.h>
// JIT-specific headers (relative to generalized_datalog/)
#include "gpu/device_sorted_array_index.h"
#include "gpu/runtime/output_context.h"
#include "gpu/runtime/jit/intersect_handles.h"
#include "gpu/runtime/jit/jit_executor.h"
#include "gpu/runtime/jit/materialized_join.h"
#include "gpu/runtime/jit/ws_infrastructure.h" // WCOJTask, WCOJTaskQueue, ChunkedOutputContext
#include "gpu/runtime/query.h" // For DeviceRelationType
namespace cg = cooperative_groups;
// Make JIT helpers visible without full namespace qualification
using SRDatalog::GPU::JIT::intersect_handles;
"""
# Appended after the final functor by `emit_full_file`; its leading newline
# separates it from the last emitted kernel body.
JIT_FILE_FOOTER = """
// End of JIT batch file
"""
# -----------------------------------------------------------------------------
# Pipeline walks (handle assignment + counting)
# -----------------------------------------------------------------------------
def _assign_handle_positions_rec(node: m.MirNode, offset_box: list[int]) -> None:
    """Depth-first walk that stamps `handle_start` onto every
    handle-bearing node under `node`.

    `offset_box` is a single-element list acting as a shared mutable
    counter across recursive calls (Python closures can't reassign a
    captured int).
    """
    if isinstance(node, m.ColumnSource | m.Scan | m.Aggregate | m.Negation):
        # Leaf-like ops: each consumes exactly one handle slot.
        node.handle_start = offset_box[0]
        offset_box[0] += 1
    elif isinstance(node, m.ColumnJoin | m.CartesianJoin):
        # Joins record where their sources' handle run begins, then let
        # the sources claim slots in order.
        node.handle_start = offset_box[0]
        for child in node.sources:
            _assign_handle_positions_rec(child, offset_box)
    elif isinstance(node, m.BalancedScan):
        node.handle_start = offset_box[0]
        for child in (node.source1, node.source2):
            _assign_handle_positions_rec(child, offset_box)
    elif isinstance(node, m.PositionedExtract):
        # The extract node itself holds no handle; only its sources do.
        for child in node.sources:
            _assign_handle_positions_rec(child, offset_box)
def assign_handle_positions(ops: list[m.MirNode]) -> None:
    """Stamp `handle_start` on every source-bearing node, numbering in
    pipeline order starting at 0. The nodes in `ops` are mutated in place.
    """
    counter = [0]  # shared mutable counter threaded through the recursion
    for node in ops:
        _assign_handle_positions_rec(node, counter)
def _count_handles_rec(node: m.MirNode) -> int:
    """Max `handle_start + 1` over `node` and every source reachable from
    it. Mirrors the traversal shape of `_assign_handle_positions_rec` so
    every assigned handle is seen."""
    # Nodes that never received a handle report -1 -> contribute 0.
    best = getattr(node, 'handle_start', -1) + 1
    if isinstance(node, m.ColumnJoin | m.CartesianJoin | m.PositionedExtract):
        for src in node.sources:
            best = max(best, _count_handles_rec(src))
    elif isinstance(node, m.BalancedScan):
        best = max(best, _count_handles_rec(node.source1), _count_handles_rec(node.source2))
    return best


def count_handles(ops: list[m.MirNode]) -> int:
    """Number of `views[]` slots needed by the kernel — `max(handle_start) + 1`.

    Fix: the previous version only looked at the *direct* sources of
    ColumnJoin/CartesianJoin plus bare Scan/Negation/Aggregate ops, so
    handles that `assign_handle_positions` gives to BalancedScan sources,
    PositionedExtract sources, or nested join sources were never counted,
    undersizing `views[]`. This version recurses over exactly the node
    shapes the assigner walks. Returns 0 for an empty pipeline.
    """
    return max((_count_handles_rec(op) for op in ops), default=0)
def first_dest_arity(ops: list[m.MirNode]) -> int:
    """Arity of the first InsertInto's column set; sizes the DedupTable's
    hash function (one v0..vN-1 column per parameter). Returns 0 when the
    pipeline contains no InsertInto."""
    return next((len(op.vars) for op in ops if isinstance(op, m.InsertInto)), 0)
# -----------------------------------------------------------------------------
# View declarations
# -----------------------------------------------------------------------------
@dataclass
class ViewSpec:
    """One deduplicated `views[]` entry: (rel_name, index, version,
    handle_idx). `handle_idx` is the handle position of the FIRST op that
    referenced this view spec."""
    rel_name: str     # relation this view reads from
    index: list[int]  # index column order for the view
    version: str      # version code suffix (e.g. DELTA / FULL; '' for none)
    handle_idx: int   # handle slot of the first op referencing this spec
def _spec_key(rel_name: str, index: list[int], version: str = '') -> str:
'''`Rel_<cols>_<VER>` — version-suffixed so DELTA / FULL share-key
doesn't collapse them.'''
base = rel_name + '_' + '_'.join(str(c) for c in index)
return base + '_' + version if version else base
def _record_spec(
    specs: list[ViewSpec],
    seen: set[str],
    rel: str,
    idx: list[int],
    ver: str,
    handle: int,
) -> None:
    """Append a `ViewSpec` to `specs` unless its key is already in `seen`.

    `seen` is the caller's dedup set, keyed by `_spec_key`. The index list
    is copied so later mutation of the caller's list can't leak in.
    """
    key = _spec_key(rel, idx, ver)
    if key not in seen:
        seen.add(key)
        specs.append(ViewSpec(rel_name=rel, index=list(idx), version=ver, handle_idx=handle))
def collect_unique_view_specs(ops: list[m.MirNode]) -> list[ViewSpec]:
    """Collect one `ViewSpec` per distinct (relation, index, version)
    triple, in first-occurrence order across the pipeline.

    Every view-referencing op shape is visited: join sources (ColumnSource
    instances only), bare Scan/Negation/Aggregate ops, both sides of a
    BalancedScan, and PositionedExtract sources (ColumnSource only).
    """
    specs: list[ViewSpec] = []
    seen: set[str] = set()

    def record(node) -> None:
        # Any node with rel_name/index/version/handle_start qualifies.
        _record_spec(
            specs,
            seen,
            node.rel_name,
            list(node.index),
            node.version.code,
            node.handle_start,
        )

    for op in ops:
        if isinstance(op, m.ColumnJoin | m.CartesianJoin):
            for src in op.sources:
                if isinstance(src, m.ColumnSource):
                    record(src)
        elif isinstance(op, m.Scan | m.Negation | m.Aggregate):
            record(op)
        elif isinstance(op, m.BalancedScan):
            record(op.source1)
            record(op.source2)
        elif isinstance(op, m.PositionedExtract):
            for src in op.sources:
                if isinstance(src, m.ColumnSource):
                    record(src)
    return specs
def emit_view_declarations(
    specs: list[ViewSpec],
    pipeline: list[m.MirNode],
    *,
    indent_level: int = 4,
    debug: bool = True,
    slot_mode: str = 'handle_idx',
    view_counts: list[int] | None = None,
) -> tuple[str, dict[str, str]]:
    """Emit the top-of-kernel `auto view_X = views[i];` declaration block.

    Returns `(decls_string, view_vars)` where `view_vars` maps both:
    - spec key (`<rel>_<cols>_<VER>`) -> view variable name, and
    - `str(handle_start)` -> view variable name, plus a parallel
      `__base__<handle>` -> base-slot entry, so handle-bearing ops can
      resolve which view their handle names.

    `slot_mode` picks the `views[]` index per spec:
    - `'handle_idx'`: `sp.handle_idx` directly (standalone-kernel
      `jit_batch.<rule>.cpp` goldens, no slot compaction).
    - `'positional'`: cumulative sum of `view_counts` (production
      `jit_runner.<rule>.cpp` goldens via `compute_view_slot_offsets`).

    `view_counts` is parallel to `specs` and gives the number of physical
    slots each spec consumes (default: one per spec, DSAI). D2L FULL_VER
    specs consume 2 slots (HEAD + FULL); only the BASE view at the first
    slot is declared here. A length mismatch raises ValueError. Empty
    `specs` yields `('', {})`.
    """
    view_vars: dict[str, str] = {}
    if not specs:
        # Early-out: note this also skips the view_counts validation below.
        return '', view_vars
    if view_counts is None:
        view_counts = [1] * len(specs)
    elif len(view_counts) != len(specs):
        raise ValueError(
            f'emit_view_declarations: view_counts length {len(view_counts)} '
            f'does not match specs length {len(specs)}'
        )
    pad = ' ' * indent_level
    chunks: list[str] = [
        pad + 'using ViewType = std::remove_cvref_t<decltype(views[0])>;\n',
        pad + 'using HandleType = ViewType::NodeHandle;\n\n',
    ]
    if debug:
        chunks.append(
            pad + f'// View declarations (deduplicated by spec, {len(specs)} unique views)\n'
        )
    key_to_var: dict[str, str] = {}
    key_to_slot: dict[str, int] = {}
    next_positional = 0
    for sp, slots_used in zip(specs, view_counts, strict=True):
        if slot_mode == 'positional':
            slot = next_positional
            next_positional += slots_used
        else:
            slot = sp.handle_idx
        key = _spec_key(sp.rel_name, sp.index, sp.version)
        cols = '_'.join(str(v) for v in sp.index)
        suffix = f'_{sp.version}' if sp.version else ''
        var = f'view_{sp.rel_name}_{cols}{suffix}'
        chunks.append(pad + f'auto {var} = views[{slot}];\n')
        # setdefault: first occurrence wins for handle lookups, matching
        # the legacy first-match list scan; the spec-key map keeps the
        # legacy last-wins assignment.
        key_to_var.setdefault(key, var)
        key_to_slot.setdefault(key, slot)
        view_vars[key] = var

    def bind_handle(rel: str, idx: list[int], ver: str, handle: int) -> None:
        # Map a handle number to its declared view var and base slot,
        # if (and only if) a declaration for this spec exists.
        k = _spec_key(rel, idx, ver)
        if k in key_to_var:
            view_vars[str(handle)] = key_to_var[k]
            view_vars[f'__base__{handle}'] = str(key_to_slot[k])

    for op in pipeline:
        if isinstance(op, m.ColumnJoin | m.CartesianJoin):
            for src in op.sources:
                bind_handle(src.rel_name, list(src.index), src.version.code, src.handle_start)
        elif isinstance(op, m.Scan | m.Negation | m.Aggregate):
            bind_handle(op.rel_name, list(op.index), op.version.code, op.handle_start)
    chunks.append('\n')
    return ''.join(chunks), view_vars
# -----------------------------------------------------------------------------
# Functor envelope
# -----------------------------------------------------------------------------
def emit_dedup_table_struct(arity: int) -> str:
    """Emit the DedupTable struct nested inside a rule's kernel scope.

    A GPU hash table over `arity` 32-bit columns: `try_insert` claims a
    slot via atomicCAS during the count phase, `check_winner` reads it
    back during the materialize phase. Linear probing, up to 128 slots
    from an FNV-1a hash; capacity must be a power of 2 (the host-side
    runner zero-initializes hash_slots between phases).
    """
    def param_lines() -> list[str]:
        # One `uint32_t vN` parameter per column; last one closes the list.
        return [
            f" uint32_t v{c}{',' if c < arity - 1 else ')'}\n"
            for c in range(arity)
        ]

    def hash_call() -> str:
        # `compute_hash(v0, v1, ..., vN-1);` as a single emitted line.
        pieces = [' unsigned long long h = compute_hash(']
        for c in range(arity):
            pieces.append(f'v{c}' + (', ' if c < arity - 1 else ');\n'))
        return ''.join(pieces)

    parts: list[str] = [
        ' // GPU dedup hash table: full 64-bit hash + separate thread_id array\n',
        ' struct DedupTable {\n',
        ' unsigned long long* hash_slots; // full 64-bit hash per slot\n',
        ' uint32_t* tid_slots; // winner thread_id per slot\n',
        ' uint32_t capacity; // must be power of 2\n\n',
        ' __device__ __forceinline__ unsigned long long compute_hash(\n',
    ]
    parts.extend(param_lines())
    parts.append(' {\n')
    parts.append(' uint64_t h = 14695981039346656037ULL;\n')
    parts.extend(
        f' h ^= (uint64_t)v{c}; h *= 1099511628211ULL;\n' for c in range(arity)
    )
    parts.append(' return h | 1ULL; // ensure non-zero\n')
    parts.append(' }\n\n')
    # -- try_insert: atomicCAS-claim a slot during the count phase --------
    parts.append(' __device__ __forceinline__ bool try_insert(\n')
    parts.append(' uint32_t thread_id,\n')
    parts.extend(param_lines())
    parts.append(' {\n')
    parts.append(hash_call())
    parts.extend([
        ' uint32_t base = (uint32_t)(h ^ (h >> 32)) & (capacity - 1);\n',
        ' for (uint32_t p = 0; p < 128; p++) {\n',
        ' uint32_t s = (base + p) & (capacity - 1);\n',
        ' unsigned long long old = atomicCAS(&hash_slots[s], 0ULL, h);\n',
        ' if (old == 0ULL) { tid_slots[s] = thread_id; return true; } // claimed\n',
        ' if (old == h) return false; // same hash = duplicate\n',
        ' // old != h: collision with different tuple -> probe next\n',
        ' }\n',
        ' return true; // probe overflow -> emit (safe)\n',
        ' }\n\n',
    ])
    # -- check_winner: read-only replay during the materialize phase ------
    parts.append(' __device__ __forceinline__ bool check_winner(\n')
    parts.append(' uint32_t thread_id,\n')
    parts.extend(param_lines())
    parts.append(' {\n')
    parts.append(hash_call())
    parts.extend([
        ' uint32_t base = (uint32_t)(h ^ (h >> 32)) & (capacity - 1);\n',
        ' for (uint32_t p = 0; p < 128; p++) {\n',
        ' uint32_t s = (base + p) & (capacity - 1);\n',
        ' unsigned long long stored = hash_slots[s];\n',
        ' if (stored == h) return tid_slots[s] == thread_id; // found: am I winner?\n',
        ' if (stored == 0ULL) return true; // not found -> probe overflow, emit\n',
        ' // different hash -> probe next (collision resolution)\n',
        ' }\n',
        ' return true; // probe overflow -> emit\n',
        ' }\n',
    ])
    parts.append(' };\n\n')
    return ''.join(parts)
def emit_functor_start(
    rule_name: str,
    *,
    scalar_mode: bool = False,
    dedup_hash: bool = False,
) -> str:
    """Open `struct Kernel_<rule> { ... operator()(...) const {`.

    `scalar_mode` selects a 1-thread-per-row group (vs the 32-thread warp
    group). With `dedup_hash=True`, operator() gains an extra
    `DedupTable dedup_table` parameter.
    """
    if scalar_mode:
        group_size = 1
        mode_comment = '// SCALAR MODE: 1 thread per row, sequential search'
    else:
        group_size = 32
        mode_comment = '// WARP MODE: 32 threads share 1 row, cooperative search'
    pieces = [
        mode_comment,
        '\n',
        f'struct Kernel_{rule_name} {{\n',
        ' static constexpr int kBlockSize = 256;\n',
        f' static constexpr int kGroupSize = {group_size};\n\n',
        ' template <typename Tile, typename Views, typename ValueType, typename Output>\n',
        ' __device__ void operator()(\n',
        ' Tile& tile,\n',
        ' const Views* views,\n',
        ' const ValueType* __restrict__ root_unique_values,\n',
        ' uint32_t num_unique_root_keys,\n',
        ' uint32_t num_root_keys,\n',
        ' uint32_t warp_id,\n',
        ' uint32_t num_warps,\n',
    ]
    if dedup_hash:
        pieces.append(' DedupTable dedup_table,\n')
    pieces.append(' Output& output\n')
    pieces.append(' ) const {\n')
    return ''.join(pieces)
def emit_functor_end() -> str:
    """Close operator() and the Kernel_<rule> struct."""
    closing_lines = [' }', '};', '']
    return '\n'.join(closing_lines)
# -----------------------------------------------------------------------------
# Top-level: wrap a body in the file envelope
# -----------------------------------------------------------------------------
def _inject_dedup_struct(header: str, arity: int) -> str:
    """Splice the DedupTable struct into `header` right after the blank
    line that follows the kGroupSize constexpr — i.e. between the constexpr
    decls and the operator() signature, matching the legacy line-for-line
    placement. Returns `header` unchanged if the anchor is not found."""
    anchor = header.find('static constexpr int kGroupSize = ')
    if anchor == -1:
        return header
    gap = header.find('\n\n', anchor)
    if gap == -1:
        return header
    cut = gap + 2
    return header[:cut] + emit_dedup_table_struct(arity) + header[cut:]


def emit_full_file(
    ep: m.ExecutePipeline,
    body: str,
    *,
    scalar_mode: bool = False,
) -> str:
    """Wrap a dialect-emitted operator() body in the standard file envelope.

    `body` is everything between `operator() {` and its closing `}` —
    the view declarations plus the dialect-emitted kernel logic. The
    caller composes `view_decls + emit(iir, emit_ctx)`.
    """
    pipeline = list(ep.pipeline)
    handle_count = count_handles(pipeline)
    banner = (
        '// =============================================================\n'
        f'// JIT-Generated Kernel Functor: {ep.rule_name}\n'
        f'// Handles: {handle_count}\n'
        '// =============================================================\n\n'
    )
    header = banner + emit_functor_start(
        ep.rule_name,
        scalar_mode=scalar_mode,
        dedup_hash=ep.dedup_hash,
    )
    if ep.dedup_hash:
        header = _inject_dedup_struct(header, first_dest_arity(pipeline))
    return JIT_FILE_PRELUDE + header + body + emit_functor_end() + '\n' + JIT_FILE_FOOTER
# Public API of the envelope module (kept alphabetical, constants first).
__all__ = [
    'JIT_FILE_FOOTER',
    'JIT_FILE_PRELUDE',
    'ViewSpec',
    'assign_handle_positions',
    'collect_unique_view_specs',
    'count_handles',
    'emit_dedup_table_struct',
    'emit_full_file',
    'emit_functor_end',
    'emit_functor_start',
    'emit_view_declarations',
    'first_dest_arity',
]