'''Materialized binary-join code generation.
Port of src/srdatalog/codegen/target_jit/jit_materialized.nim.
Materialized joins take a different shape from fused joins (kernel
functors with nested loops). They emit host-side C++ that:
1. Materializes row-id pairs in buffers between join stages
2. Uses binary-search probes against already-indexed relations
3. Applies merge-path load balancing for unbalanced outputs
The generated C++ uses Thrust primitives (thrust::lower_bound /
upper_bound, exclusive_scan, gather) plus three CUDA kernels declared
in `gen_materialized_join_helpers()`.
Public API:
is_materialized_pipeline(ops) -> bool
True iff any op is a ProbeJoin. `jit_complete_runner` uses this
as an early-dispatch check; when True it emits a materialized
runner instead of a kernel functor.
gen_materialized_runner(node, db_type_name) -> str
The main entry. Given an ExecutePipeline MIR node, emits a host-
side Thrust executor:
struct JitRunner_<rule> {
using DB = <dbTypeName>;
using FirstSchema = ...;
...
static void execute(DB& db, uint32_t iteration = 0) { ... }
};
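gen_materialized_join_helpers() -> str
Returns the source of three __global__ CUDA kernel templates
(probe_count_matches_kernel, probe_materialize_pairs_kernel,
gather_column_kernel). The runner's execute() body does not launch
these by name; it calls probe_count_matches, probe_materialize_pairs
and compute_output_offsets (presumably host wrappers from
SRDatalog::GPU::JIT — confirm how they relate to these kernels).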
gen_materialized_join_kernel(ops, rule_name, ctx) -> str
Legacy in-kernel variant. Not used by Nim's live code and not
exercised by our fixtures — ported for completeness.
'''
from __future__ import annotations
import srdatalog.ir.mir.types as m
from srdatalog.ir.codegen.cuda.context import CodeGenContext, ind
from srdatalog.ir.hir.types import Version
# -----------------------------------------------------------------------------
# Detection
# -----------------------------------------------------------------------------
[docs]
def is_materialized_pipeline(ops: list[m.MirNode]) -> bool:
    '''Report whether ``ops`` contains at least one ``m.ProbeJoin``.

    A single ProbeJoin anywhere in the pipeline is enough to route code
    generation to the host-side materialized runner instead of a fused
    kernel functor. An empty pipeline yields False.
    '''
    for candidate in ops:
        if isinstance(candidate, m.ProbeJoin):
            return True
    return False
# -----------------------------------------------------------------------------
# Helper kernels (fixed string, matches Nim)
# -----------------------------------------------------------------------------
[docs]
def gen_materialized_join_helpers() -> str:
    '''Return the fixed CUDA source for the materialized-join helper kernels.

    The text starts with a newline and a banner comment, followed by three
    ``__global__`` templates: ``probe_count_matches_kernel``,
    ``probe_materialize_pairs_kernel`` and ``gather_column_kernel``. It is
    meant to be pasted verbatim into the generated batch file whenever a
    materialized pipeline is present, so the text must not be edited
    casually (it mirrors the Nim original).

    NOTE(review): the banner above ``probe_materialize_pairs_kernel`` says
    "merge-path partitioning", but its body is a plain one-thread-per-input
    loop writing at precomputed ``offsets``. ``probe_count_matches_kernel``
    takes ``input_rowids`` without reading it, and ``gather_column_kernel``
    never uses its ``Schema``/``Col`` template parameters.
    '''
    helper_source = """
// ==========================================================================
// Materialized Join Helpers (Thrust-based)
// ==========================================================================
// Count matches for each input key (for load balancing)
template<typename KeyT, typename ViewT>
__global__ void probe_count_matches_kernel(
const uint32_t* input_rowids,
const KeyT* input_keys,
ViewT view,
uint32_t* counts,
size_t n
) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= n) return;
KeyT key = input_keys[tid];
auto range = view.get_range(key);
counts[tid] = range.second - range.first;
}
// Materialize (left_rowid, right_rowid) pairs with merge-path partitioning
template<typename KeyT, typename ViewT>
__global__ void probe_materialize_pairs_kernel(
const uint32_t* input_rowids,
const KeyT* input_keys,
ViewT view,
const uint32_t* offsets,
uint32_t* out_left,
uint32_t* out_right,
size_t n
) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= n) return;
uint32_t left_rowid = input_rowids[tid];
KeyT key = input_keys[tid];
auto range = view.get_range(key);
uint32_t out_offset = offsets[tid];
for (uint32_t i = range.first; i < range.second; ++i) {
out_left[out_offset] = left_rowid;
out_right[out_offset] = i; // Right side row ID
++out_offset;
}
}
// Gather column values using row IDs
template<typename Schema, int Col, typename T>
__global__ void gather_column_kernel(
const uint32_t* rowids,
const T* column_data,
T* output,
size_t n
) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= n) return;
output[tid] = column_data[rowids[tid]];
}
"""
    return helper_source
# -----------------------------------------------------------------------------
# Version normalization
# -----------------------------------------------------------------------------
def _version_cpp(ver: str | Version) -> str:
    '''Translate an MIR version tag into the C++ ``*_VER`` constant name.

    ``Version`` instances supply their own ``code``. Anything else is
    stringified and matched against the short, CamelCase and C++ spellings
    of DELTA / FULL / NEW. Unrecognized spellings fall back to
    ``"FULL_VER"``, mirroring the Nim ``case`` on the stringly-typed version.
    '''
    if isinstance(ver, Version):
        return ver.code
    spelling = str(ver)
    alias_table = (
        ("DELTA_VER", ("DELTA", "DeltaVer", "DELTA_VER")),
        ("FULL_VER", ("FULL", "FullVer", "FULL_VER")),
        ("NEW_VER", ("NEW", "NewVer", "NEW_VER")),
    )
    return next(
        (cpp_name for cpp_name, aliases in alias_table if spelling in aliases),
        "FULL_VER",
    )
# -----------------------------------------------------------------------------
# Legacy in-kernel variant (gen_materialized_join_kernel)
# -----------------------------------------------------------------------------
[docs]
def gen_materialized_join_kernel(
    ops: list[m.MirNode],
    rule_name: str,
    ctx: CodeGenContext,
) -> str:
    '''Emit the legacy in-kernel form of a materialized join.

    Nim's live codegen no longer uses this form (the host-side runner
    replaced it); it is kept as a port for completeness. Every emitted
    line except the blank separator after Phase 3 is prefixed with
    ``ind(ctx)``.

    Args:
        ops: Pipeline ops. The last ``Scan`` seen supplies the base
            relation; ``ProbeJoin`` / ``GatherColumn`` / ``InsertInto`` ops
            are emitted in pipeline order within their phase; any other op
            kind is ignored.
        rule_name: Only appears in the banner comment.
        ctx: Codegen context, consulted solely through ``ind(ctx)``.

    Returns:
        The generated C++ fragment, or a single indented ``// ERROR``
        comment line when ``ops`` contains no ``Scan``.
    '''
    pad = ind(ctx)
    base_scan: m.Scan | None = None
    probes: list[m.ProbeJoin] = []
    gathers: list[m.GatherColumn] = []
    inserts: list[m.InsertInto] = []
    for op in ops:
        if isinstance(op, m.Scan):
            base_scan = op
        elif isinstance(op, m.ProbeJoin):
            probes.append(op)
        elif isinstance(op, m.GatherColumn):
            gathers.append(op)
        elif isinstance(op, m.InsertInto):
            inserts.append(op)

    if base_scan is None:
        return pad + "// ERROR: Materialized join requires moScan as first op\n"

    chunks: list[str] = []

    def emit(text: str) -> None:
        # Every generated line carries the context's indentation prefix.
        chunks.append(pad + text)

    # Phase 1: view over the scanned relation plus an identity row-ID buffer.
    base_rel = base_scan.rel_name
    base_view = "view_{}_{}".format(base_rel, "_".join(map(str, base_scan.index)))
    base_cols = ", ".join(map(str, base_scan.index))
    rows_buf = f"{base_rel}_rowids"
    rows_n = f"{base_rel}_size"
    emit(f"// ===== Materialized Join: {rule_name} =====\n")
    emit("// Phase 1: Scan first relation\n")
    emit(f"auto {base_view} = db.get_view<{base_rel}, {base_cols}>();\n")
    emit(f"auto {rows_n} = {base_view}.size();\n")
    emit(f"if ({rows_n} == 0) return;\n\n")
    emit("// Initial row IDs are just [0, N)\n")
    emit(f"thrust::device_vector<uint32_t> {rows_buf}({rows_n});\n")
    emit(f"thrust::sequence({rows_buf}.begin(), {rows_buf}.end());\n\n")

    # Phase 2: one count / scan / materialize round per ProbeJoin; each
    # round's left buffer and total become the input of the next round.
    for stage, probe in enumerate(probes, start=1):
        target = probe.probe_rel
        key = probe.join_key
        stem = probe.output_buffer
        probe_view = "view_{}_{}".format(target, "_".join(map(str, probe.probe_index)))
        probe_cols = ", ".join(map(str, probe.probe_index))
        counts = f"{stem}_counts"
        offsets = f"{stem}_offsets"
        total = f"{stem}_total"
        left = f"{stem}_left"
        right = f"{stem}_right"
        emit(f"// Phase 2.{stage}: Probe {target} on {key}\n")
        emit(f"auto {probe_view} = db.get_view<{target}, {probe_cols}>();\n")
        emit("// Count matches per input row (for load balancing)\n")
        emit(f"thrust::device_vector<uint32_t> {counts}({rows_n});\n")
        emit(f"probe_count_matches({rows_buf}, {key}_gather, {probe_view}, {counts}.data());\n")
        emit(f"thrust::device_vector<uint32_t> {offsets}({rows_n} + 1);\n")
        emit(
            f"thrust::exclusive_scan({counts}.begin(), {counts}.end(), "
            f"{offsets}.begin(), 0u);\n"
        )
        # The trailing slot turns the n exclusive-scan offsets into n+1 bounds.
        emit(f"{offsets}[{rows_n}] = {offsets}[{rows_n}-1] + {counts}[{rows_n}-1];\n")
        emit(f"uint32_t {total} = {offsets}[{rows_n}];\n")
        emit(f"if ({total} == 0) return;\n\n")
        emit("// Materialize row ID pairs\n")
        emit(f"thrust::device_vector<uint32_t> {left}({total});\n")
        emit(f"thrust::device_vector<uint32_t> {right}({total});\n")
        emit(
            f"probe_materialize_pairs({rows_buf}, {key}_gather, {probe_view}, "
            f"{offsets}.data(), {left}.data(), {right}.data());\n\n"
        )
        rows_buf, rows_n = left, total

    # Phase 3: gather each requested column through its buffer's right row IDs.
    if gathers:
        emit("// Phase 3: Gather output columns\n")
        for gather in gathers:
            out_var = gather.output_var
            emit(f"thrust::device_vector<int64_t> {out_var}_gather({rows_n});\n")
            emit(
                f"gather_column<{gather.rel_name}, {gather.column}>("
                f"{gather.input_buffer}_right, {out_var}_gather.data());\n"
            )
        chunks.append("\n")

    # Phase 4: one insert_gathered_tuples call per InsertInto op.
    if inserts:
        emit("// Phase 4: Insert results\n")
        for insert in inserts:
            dest = insert.rel_name
            emit(f"// Insert into {dest}\n")
            emit(f"auto {dest}_dest = db.get_new<{dest}>();\n")
            gathered = ", ".join(f"{var}_gather" for var in insert.vars)
            emit(f"insert_gathered_tuples({dest}_dest, {gathered});\n")

    return "".join(chunks)
# -----------------------------------------------------------------------------
# Materialized runner (the main entry)
# -----------------------------------------------------------------------------
[docs]
def gen_materialized_runner(
    node: m.ExecutePipeline,
    db_type_name: str,
) -> str:
    '''Emit a host-side Thrust executor for a materialized-join pipeline.

    Produces ``struct JitRunner_<rule_name>`` with a static
    ``execute(DB&, uint32_t iteration = 0)`` method that walks
    Scan -> ProbeJoin+ -> GatherColumn* -> InsertInto using Thrust
    primitives plus helpers pulled in via ``using namespace
    SRDatalog::GPU::JIT``. Intended to match Nim's
    ``genMaterializedRunner`` byte-for-byte (modulo the usual
    clang-format whitespace).

    Args:
        node: The ExecutePipeline MIR node. Its ``pipeline`` ops are
            bucketed by kind; ``rule_name`` and ``dest_specs`` feed the
            struct header.
        db_type_name: C++ type emitted as ``using DB = ...``.

    Returns:
        The generated C++ source, or a one-line ``// ERROR`` comment when
        the pipeline contains no Scan op.
    '''
    assert isinstance(node, m.ExecutePipeline)
    scan_op, probe_join_ops, gather_ops, insert_ops = _runner_split_ops(list(node.pipeline))
    if scan_op is None:
        return "// ERROR: Materialized join requires moScan as first op\n"
    return "".join(
        (
            _runner_prologue(node, db_type_name, scan_op),
            _runner_probe_phase(probe_join_ops),
            _runner_gather_phase(gather_ops),
            _runner_insert_phase(insert_ops, len(probe_join_ops)),
            # Close execute() and the struct; the leading "\n" leaves a
            # blank line before the final nvtxRangePop().
            "\n nvtxRangePop();\n",
            " }\n",
            "};\n\n",
        )
    )


def _runner_split_ops(
    pipeline: list[m.MirNode],
) -> tuple[m.Scan | None, list[m.ProbeJoin], list[m.GatherColumn], list[m.InsertInto]]:
    '''Bucket pipeline ops by kind, keeping pipeline order within each bucket.

    When several Scan ops are present the last one wins; op kinds other
    than Scan / ProbeJoin / GatherColumn / InsertInto are ignored.
    '''
    scan_op: m.Scan | None = None
    probe_join_ops: list[m.ProbeJoin] = []
    gather_ops: list[m.GatherColumn] = []
    insert_ops: list[m.InsertInto] = []
    for op in pipeline:
        if isinstance(op, m.Scan):
            scan_op = op
        elif isinstance(op, m.ProbeJoin):
            probe_join_ops.append(op)
        elif isinstance(op, m.GatherColumn):
            gather_ops.append(op)
        elif isinstance(op, m.InsertInto):
            insert_ops.append(op)
    return scan_op, probe_join_ops, gather_ops, insert_ops


def _runner_prologue(node: m.ExecutePipeline, db_type_name: str, scan_op: m.Scan) -> str:
    '''Emit the banner, struct aliases, the opening of ``execute()`` and Phase 1.

    Phase 1 indexes the scanned relation as ``view_0`` and seeds
    ``rowids_0`` with the sequence ``[0, n_0)``; ``current_n`` tracks the
    live row count for later phases.
    '''
    rule_name = node.rule_name
    first_schema = scan_op.rel_name
    first_version = _version_cpp(scan_op.version)
    scan_idx_comma = ", ".join(str(c) for c in scan_op.index)
    parts = [
        "// =============================================================\n",
        f"// JIT-Generated Materialized Runner: {rule_name}\n",
        "// Host-side Thrust executor (no CUDA kernels)\n",
        "// =============================================================\n\n",
        f"struct JitRunner_{rule_name} {{\n",
        f" using DB = {db_type_name};\n",
        f" using FirstSchema = {first_schema};\n",
    ]
    # DestSchema is only declared when the first dest spec is an InsertInto.
    if node.dest_specs and isinstance(node.dest_specs[0], m.InsertInto):
        parts.append(f" using DestSchema = {node.dest_specs[0].rel_name};\n")
    parts += [
        " using ValueType = typename FirstSchema::intern_value_type;\n",
        " using RelType = std::decay_t<decltype(get_relation_by_schema<FirstSchema, "
        f"{first_version}>(std::declval<DB&>()))>;\n",
        " using IndexType = typename RelType::IndexTypeInst;\n",
        " using ViewType = typename IndexType::NodeView;\n\n",
        " static void execute(DB& db, uint32_t iteration = 0) {\n",
        f' nvtxRangePushA("{rule_name}");\n\n',
        " using namespace SRDatalog::GPU::JIT;\n\n",
        " // Phase 1: Get initial data from first relation\n",
        f" auto& rel_0 = get_relation_by_schema<{first_schema}, {first_version}>(db);\n",
        # Quadruple braces render as the literal "{{...}}" initializer.
        f" auto& idx_0 = rel_0.ensure_index(SRDatalog::IndexSpec{{{{{scan_idx_comma}}}}}, false);\n",
        " auto view_0 = idx_0.view();\n",
        " size_t n_0 = view_0.num_rows_;\n",
        " if (n_0 == 0) { nvtxRangePop(); return; }\n\n",
        " // Current row ID buffer (starts as simple [0..N) sequence)\n",
        " thrust::device_vector<uint32_t> rowids_0(n_0);\n",
        " thrust::sequence(rowids_0.begin(), rowids_0.end());\n",
        " size_t current_n = n_0;\n\n",
    ]
    return "".join(parts)


def _runner_probe_phase(probe_join_ops: list[m.ProbeJoin]) -> str:
    '''Emit Phase 2: one gather/count/offset/materialize stage per ProbeJoin.

    Stage k (1-based) indexes the probed relation as ``view_k``, gathers
    join keys from column 0 of ``view_{k-1}`` through the current left
    row-ID buffer (``rowids_0`` for k == 1, ``left_{k-1}`` afterwards) and
    materializes its matches into ``left_k`` / ``right_k``.

    NOTE(review): for k >= 2 the keys come from ``view_{k-1}`` indexed by
    ``left_{k-1}`` (not ``right_{k-1}``); confirm that buffer/view pairing
    is intended.
    '''
    parts: list[str] = []
    current_left_rowids = "rowids_0"
    prev_rel_idx = 0
    for rel_idx, probe_op in enumerate(probe_join_ops, start=1):
        probe_rel = probe_op.probe_rel
        probe_idx_comma = ", ".join(str(c) for c in probe_op.probe_index)
        join_key = probe_op.join_key
        probe_version = _version_cpp(probe_op.probe_version)
        pair_left = f"left_{rel_idx}"
        pair_right = f"right_{rel_idx}"
        parts += [
            f" // Phase 2.{rel_idx}: Probe {probe_rel} on {join_key}\n",
            f" auto& rel_{rel_idx} = get_relation_by_schema<{probe_rel}, {probe_version}>(db);\n",
            f" auto& idx_{rel_idx} = rel_{rel_idx}.ensure_index("
            f"SRDatalog::IndexSpec{{{{{probe_idx_comma}}}}}, false);\n",
            f" auto view_{rel_idx} = idx_{rel_idx}.view();\n\n",
            # Gather join key from current rowids.
            f" // Gather join key '{join_key}' from current buffer\n",
            f" thrust::device_vector<ValueType> keys_{rel_idx}(current_n);\n",
            " // Access column 0 via col_data_ (col_data_ + 0 * stride_)\n",
            f" thrust::gather(thrust::device, {current_left_rowids}.begin(), "
            f"{current_left_rowids}.end(),\n",
            f" view_{prev_rel_idx}.col_data_, keys_{rel_idx}.begin());\n\n",
            # Count matches.
            " // Count matches per input row\n",
            f" thrust::device_vector<uint32_t> counts_{rel_idx}(current_n);\n",
            f" probe_count_matches({current_left_rowids}, keys_{rel_idx}, "
            f"view_{rel_idx}, counts_{rel_idx});\n\n",
            # Offsets; an empty join result ends execute() early.
            " // Compute output offsets\n",
            f" thrust::device_vector<uint32_t> offsets_{rel_idx}(current_n);\n",
            f" uint32_t total_{rel_idx} = compute_output_offsets(counts_{rel_idx}, "
            f"offsets_{rel_idx});\n",
            f" if (total_{rel_idx} == 0) {{ nvtxRangePop(); return; }}\n\n",
            # Materialize pairs.
            " // Materialize (left, right) row ID pairs\n",
            f" thrust::device_vector<uint32_t> {pair_left}(total_{rel_idx});\n",
            f" thrust::device_vector<uint32_t> {pair_right}(total_{rel_idx});\n",
            f" probe_materialize_pairs({current_left_rowids}, keys_{rel_idx}, view_{rel_idx},\n",
            f" offsets_{rel_idx}, {pair_left}, {pair_right});\n\n",
            " // Update current state\n",
            f" current_n = total_{rel_idx};\n\n",
        ]
        current_left_rowids = pair_left
        prev_rel_idx = rel_idx
    parts.append("\n")
    return "".join(parts)


def _runner_gather_phase(gather_ops: list[m.GatherColumn]) -> str:
    '''Emit Phase 3: placeholder buffers for GatherColumn ops.

    Mirrors the Nim stub: each op only allocates ``<output_var>_data`` and
    leaves TODO comments; nothing is gathered into those buffers here.
    Returns "" when there are no gather ops.
    '''
    if not gather_ops:
        return ""
    parts = [" // Phase 3: Gather output columns\n"]
    for gather_op in gather_ops:
        gather_var = gather_op.output_var
        parts += [
            f" thrust::device_vector<ValueType> {gather_var}_data(current_n);\n",
            f" // Note: Gathering {gather_var} from {gather_op.rel_name} "
            f"column {gather_op.column}\n",
            " // Using simplified gather from last right-side row IDs\n",
            " // TODO: Track which buffer corresponds to which relation\n\n",
        ]
    parts.append("\n")
    return "".join(parts)


def _runner_insert_phase(insert_ops: list[m.InsertInto], num_probes: int) -> str:
    '''Emit Phase 4 for the first InsertInto only; later ones are ignored.

    Grows the NEW_VER destination by ``current_n`` rows and allocates one
    ``out_col_k`` per destination variable. Column 0 is filled by chasing
    the left buffers back through ``rowids_0`` into ``view_0`` column 0.
    Column 1 is filled from column 2 of the last probed view via its right
    buffer. Every ``out_col_k`` is then copied into ``dest`` after
    ``old_size``; columns 2+ are copied but never written here.

    NOTE(review): with ``num_probes == 0`` the column-1 code references
    ``right_0``, which no phase declares. ``is_materialized_pipeline``
    implies at least one ProbeJoin, so confirm callers never bypass it.
    Returns "" when there are no insert ops.
    '''
    if not insert_ops:
        return ""
    first_insert = insert_ops[0]
    dest_rel = first_insert.rel_name
    dest_vars = list(first_insert.vars)
    parts = [
        f" // Phase 4: Insert into {dest_rel}\n",
        f" auto& dest = get_relation_by_schema<{dest_rel}, NEW_VER>(db);\n",
        " size_t old_size = dest.size();\n",
        " size_t new_size = old_size + current_n;\n",
        " dest.resize_interned_columns(new_size);\n\n",
        " // Gather output columns into device vectors\n",
    ]
    parts += [
        f" thrust::device_vector<ValueType> out_col_{col_idx}(current_n);\n"
        for col_idx in range(len(dest_vars))
    ]
    parts.append("\n")

    # First output column: chain back through the left buffers to rowids_0,
    # then gather from view_0 col 0.
    if dest_vars:
        parts += [
            " // Output column 0 (e.g. invocation): gather from view_0 col 0 via chained left row IDs\n",
            " // Chained gather: left_N -> left_N-1 -> ... -> rowids_0\n",
        ]
        if num_probes == 0:
            parts += [
                " thrust::gather(thrust::device, rowids_0.begin(), rowids_0.begin() + current_n,\n",
                " view_0.col_data_, out_col_0.begin());\n",
            ]
        else:
            parts += [
                " // Final left buffer chains back to view_0\n",
                " thrust::device_vector<uint32_t> chained_rowids(current_n);\n",
                f" thrust::copy(left_{num_probes}.begin(), left_{num_probes}.end(), "
                "chained_rowids.begin());\n",
            ]
            for back_idx in range(num_probes - 1, 0, -1):
                parts += [
                    " thrust::gather(thrust::device, chained_rowids.begin(), chained_rowids.end(),\n",
                    f" left_{back_idx}.begin(), chained_rowids.begin());\n",
                ]
            parts += [
                " thrust::gather(thrust::device, chained_rowids.begin(), chained_rowids.end(),\n",
                " rowids_0.begin(), chained_rowids.begin());\n",
                " thrust::gather(thrust::device, chained_rowids.begin(), chained_rowids.end(),\n",
                " view_0.col_data_, out_col_0.begin());\n",
            ]
        parts.append("\n")

    # Second output column: hard-coded to column 2 of the last probed view.
    if len(dest_vars) >= 2:
        last_rel_idx = num_probes
        parts += [
            " // Output column 1 (e.g. toMeth): from last view via right row IDs\n",
            f" thrust::gather(thrust::device, right_{last_rel_idx}.begin(), "
            f"right_{last_rel_idx}.end(),\n",
            f" view_{last_rel_idx}.col_data_ + 2 * view_{last_rel_idx}.stride_,\n",
            " out_col_1.begin());\n\n",
        ]

    # Copy gathered columns to destination.
    parts.append(" // Copy to destination relation columns\n")
    for col_idx in range(len(dest_vars)):
        parts += [
            f" thrust::copy(out_col_{col_idx}.begin(), out_col_{col_idx}.end(),\n",
            f" dest.template interned_column<{col_idx}>() + old_size);\n",
        ]
    parts.append("\n")
    return "".join(parts)