# Source code for srdatalog.ir.codegen.cuda.pipeline_utils

'''Pipeline-shape utilities used by `complete_runner`.

Originally ported from
`src/srdatalog/codegen/target_jit/jit_emit_helpers.nim`. The legacy
Filter / ConstantBind / InsertInto emit procs (`jit_filter`,
`jit_constant_bind`, `jit_insert_into`) used to live here too; they were
retired along with the rest of the legacy
`ir/dialects/target/cuda/pipeline.py` chain, and the dialect now owns
those emits via `dialects.relation.sorted_array.lowerings._lower_inner_chain`
and `_lower_insert_into`.

What remains:
  - `has_balanced_scan` / `get_balanced_scan_info` — used to decide
    whether to emit a balanced-scan kernel variant.
  - `has_tiled_cartesian_eligible` — runner uses this to decide
    whether to enable the tiled-Cartesian materialize path.
  - `assign_handle_positions` / `count_handles_in_pipeline` — pipeline
    pre-pass and view-slot counting. Note: `dialects.target.cuda.envelope`
    has its own `assign_handle_positions` for the dialect path; this one
    stays for the runner-side pre-pass that mutates `mutable_pipe` before
    kernel emission.
'''

from __future__ import annotations

from dataclasses import dataclass, field

import srdatalog.ir.mir.types as m

# -----------------------------------------------------------------------------
# Balanced-partitioning detection
# -----------------------------------------------------------------------------


def has_balanced_scan(ops: list[m.MirNode]) -> bool:
    '''Return True when the pipeline's root (first) op is a `BalancedScan`.

    An empty pipeline has no root op, so it reports False.
    '''
    if not ops:
        return False
    return isinstance(ops[0], m.BalancedScan)
def has_tiled_cartesian_eligible(ops: list[m.MirNode]) -> bool:
    '''Report whether any op is a two-source, one-var-per-source CartesianJoin.

    Such a join qualifies for the atomic-free tiled/coalesced write
    optimization. Returns False for an empty pipeline or when no op matches.
    '''

    def _binary_single_var(join: m.CartesianJoin) -> bool:
        # Check the source count first so `var_from_source` is only
        # inspected for genuinely two-source joins.
        if len(join.sources) != 2:
            return False
        per_source = join.var_from_source
        if len(per_source) != 2:
            return False
        return len(per_source[0]) == 1 and len(per_source[1]) == 1

    return any(
        isinstance(op, m.CartesianJoin) and _binary_single_var(op) for op in ops
    )
@dataclass
class BalancedScanInfo:
    '''Plain record describing a root BalancedScan's grouping and sources.

    Produced by `get_balanced_scan_info`. A default-constructed instance
    (empty strings, empty lists, -1 handles) is the "no BalancedScan"
    sentinel.
    '''

    # Variable the BalancedScan groups on.
    group_var: str = field(default="")
    # First source: relation name, a copy of its `index` list, and its
    # `handle_start` (-1 when the source carried none).
    src1_rel_name: str = field(default="")
    src1_index: list[int] = field(default_factory=list)
    src1_handle_idx: int = field(default=-1)
    # Second source, same layout as the first.
    src2_rel_name: str = field(default="")
    src2_index: list[int] = field(default_factory=list)
    src2_handle_idx: int = field(default=-1)
def get_balanced_scan_info(ops: list[m.MirNode]) -> BalancedScanInfo:
    '''Summarize the root BalancedScan of `ops` as a `BalancedScanInfo`.

    Copies the group variable and, for each of the two sources, its
    `rel_name`, a fresh list of its `index`, and its `handle_start` (or -1
    if that attribute is missing).

    When the root op is not a BalancedScan (including an empty pipeline),
    returns a default-constructed `BalancedScanInfo`, the all-empty
    sentinel that matches the Nim implementation.
    '''
    if not has_balanced_scan(ops):
        return BalancedScanInfo()

    root = ops[0]
    assert isinstance(root, m.BalancedScan)  # narrows the type for checkers
    left, right = root.source1, root.source2
    return BalancedScanInfo(
        group_var=root.group_var,
        src1_rel_name=left.rel_name,
        src1_index=list(left.index),
        src1_handle_idx=getattr(left, "handle_start", -1),
        src2_rel_name=right.rel_name,
        src2_index=list(right.index),
        src2_handle_idx=getattr(right, "handle_start", -1),
    )
# -----------------------------------------------------------------------------
# Pipeline handle counting
# -----------------------------------------------------------------------------


def _assign_handle_positions_rec(node: m.MirNode, offset_box: list[int]) -> None:
    '''Stamp `handle_start` on `node` and its sources, advancing the counter.

    `offset_box` is a one-element list holding the next free handle index.
    Every recursive call shares this one list, so an increment made by a leaf
    is seen by the caller. A plain int argument would only be rebound locally,
    and the increment would be lost.

    Per node kind:
      - ColumnSource / Scan / Aggregate / Negation: take the current index
        and consume it (counter += 1).
      - ColumnJoin / CartesianJoin / BalancedScan: record the current index
        without consuming it, then recurse into their sources
        (`sources`, or `source1` then `source2`).
      - PositionedExtract: no `handle_start` of its own; only its `sources`
        are visited.
      - Anything else: left untouched.
    '''
    leaf_kinds = (m.ColumnSource, m.Scan, m.Aggregate, m.Negation)
    join_kinds = (m.ColumnJoin, m.CartesianJoin)

    if isinstance(node, leaf_kinds):
        node.handle_start = offset_box[0]
        offset_box[0] += 1
        return

    if isinstance(node, join_kinds):
        node.handle_start = offset_box[0]
        children = node.sources
    elif isinstance(node, m.BalancedScan):
        node.handle_start = offset_box[0]
        children = (node.source1, node.source2)
    elif isinstance(node, m.PositionedExtract):
        children = node.sources
    else:
        return

    for child in children:
        _assign_handle_positions_rec(child, offset_box)
def assign_handle_positions(ops: list[m.MirNode]) -> None:
    '''Number every source-bearing node in pipeline order, starting at 0.

    Writes `handle_start` onto the nodes reachable from `ops`; the list itself
    is not reordered or resized. Mirrors Nim's assignHandlePositions.
    '''
    # One shared counter threaded through every top-level op so numbering
    # continues across ops rather than restarting at 0 for each.
    next_slot = [0]
    for top_level_op in ops:
        _assign_handle_positions_rec(top_level_op, next_slot)
    return None
def count_handles_in_pipeline(ops: list[m.MirNode]) -> int:
    '''Return the number of view slots the kernel's `views[]` array needs.

    Computed as the maximum `handle_start + 1` over:
      - the sources of top-level ColumnJoin / CartesianJoin ops, and
      - top-level Scan / Negation / Aggregate ops themselves.
    A missing `handle_start` counts as -1 (contributing 0). The result is 0
    when nothing contributes. The caller should still allocate at least one
    slot in that case; returning 0 is deliberate, to match Nim.

    NOTE(review): top-level BalancedScan, PositionedExtract and ColumnSource
    ops are not inspected here, even though `_assign_handle_positions_rec`
    assigns handles to their sources. Confirm that callers account for this
    (e.g. a pipeline whose only handle-bearing op is a BalancedScan).
    '''

    def _slot_ends(op: m.MirNode):
        # Joins carry handles on their sources; scan-like ops carry their own.
        if isinstance(op, (m.ColumnJoin, m.CartesianJoin)):
            for src in op.sources:
                yield getattr(src, "handle_start", -1) + 1
        elif isinstance(op, (m.Scan, m.Negation, m.Aggregate)):
            yield getattr(op, "handle_start", -1) + 1

    slot_ends = [end for op in ops for end in _slot_ends(op)]
    # Seed with 0 so the result never drops below zero.
    return max([0, *slot_ends])