'''View/handle slot mapping + view-declaration emission.
Port of src/srdatalog/codegen/target_jit/jit_view_management.nim.
Deduplicates pipeline source views by (rel_name, index, version),
computes view-slot offsets for multi-view sources (e.g. Device2LevelIndex
contributes >1 slot per source via the plugin), and emits the
`auto view_X = views[i];` block at the top of each generated kernel.
A "view spec" is the triple (rel_name, index cols, version) plus the
handle_idx of the first op that referenced it. Two uses of the same
relation with different index orderings are distinct views; two uses
with the same (rel, idx, ver) share one view slot even across nested
CJ / Cartesian handles.
'''
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
import srdatalog.ir.mir.types as m
from srdatalog.ir.codegen.cuda.context import (
CodeGenContext,
gen_view_access,
get_view_slot_base,
ind,
)
from srdatalog.ir.codegen.cuda.plugin import plugin_view_count
# -----------------------------------------------------------------------------
# Source-spec helpers
# -----------------------------------------------------------------------------
# Node kinds that carry `rel_name`, `index` and `version` directly on the
# node. The source-spec helpers in this module read those attributes only
# after an isinstance check against this tuple.
_SOURCE_SPEC_TYPES = (m.ColumnSource, m.Scan, m.Negation, m.Aggregate)
[docs]
def get_source_index(src_spec: m.MirNode) -> list[int]:
    '''Return the column ordering of a source-bearing node as a new list.

    Nodes that are not one of `_SOURCE_SPEC_TYPES` carry no index here and
    produce an empty list.
    '''
    if not isinstance(src_spec, _SOURCE_SPEC_TYPES):
        return []
    # Copy so callers never alias the node's own index sequence.
    return [*src_spec.index]
def _source_rel_version(src_spec: m.MirNode) -> tuple[str, str]:
    '''Return `(rel_name, version code)` for a source-bearing node.

    The version is rendered through `.code` (its Nim `*_VER` form) so spec
    keys match what the existing fixtures use. Nodes outside
    `_SOURCE_SPEC_TYPES` yield a pair of empty strings.
    '''
    if not isinstance(src_spec, _SOURCE_SPEC_TYPES):
        return "", ""
    rel_name = src_spec.rel_name
    version_code = src_spec.version.code
    return rel_name, version_code
[docs]
def source_spec_key(src_spec: m.MirNode) -> str:
    '''Build the dedup key `<relName>_<VERSION>_<cols_joined>` for a source.

    The index columns are part of the key, so the same relation read through
    different index orderings gets distinct keys; the version is part of the
    key, so DELTA and FULL views stay distinct as well.
    '''
    rel_name, version = _source_rel_version(src_spec)
    cols = "_".join(map(str, get_source_index(src_spec)))
    return "{}_{}_{}".format(rel_name, version, cols)
def _handle_start_of(src_spec: m.MirNode) -> int:
    '''Return the node's `handle_start`, or -1 when it has no such attribute.'''
    try:
        return src_spec.handle_start
    except AttributeError:
        return -1
# -----------------------------------------------------------------------------
# View-slot computation
# -----------------------------------------------------------------------------
[docs]
def compute_total_view_count(
    source_specs: Sequence[m.MirNode],
    rel_index_types: dict[str, str],
) -> int:
    '''Return the total number of view slots needed by all unique sources.

    Sources are deduplicated by `source_spec_key`, so repeated uses of the
    same (rel, version, index) — e.g. nested CJ/Cart handles — are counted
    once. Each unique source contributes `plugin_view_count(version,
    index_type)` slots.

    Args:
        source_specs: Source-bearing nodes, in pipeline order.
        rel_index_types: Relation name → index type name; relations missing
            from the map are looked up with the empty string.

    Returns:
        The sum of per-source slot counts over the unique sources.
    '''
    total = 0
    # A set keeps the membership test O(1); only membership matters here,
    # not the order in which keys were first seen.
    seen: set[str] = set()
    for src_spec in source_specs:
        key = source_spec_key(src_spec)
        if key in seen:
            continue
        seen.add(key)
        rel, ver = _source_rel_version(src_spec)
        total += plugin_view_count(ver, rel_index_types.get(rel, ""))
    return total
[docs]
def compute_view_slot_offsets(
    source_specs: Sequence[m.MirNode],
    rel_index_types: dict[str, str],
) -> dict[int, int]:
    '''Return a mapping from `handle_idx` to its base slot in `views[]`.

    Handles whose sources share the same relation+version+index key are
    given the same base slot, so nested CJ/Cart handles point at the same
    view as the first handle that introduced it. Sources without a
    non-negative `handle_start` are skipped and consume no slot.
    '''
    offsets: dict[int, int] = {}
    base_by_key: dict[str, int] = {}
    next_slot = 0
    for spec in source_specs:
        handle = _handle_start_of(spec)
        rel_name, version = _source_rel_version(spec)
        spec_id = source_spec_key(spec)
        if handle < 0:
            continue
        if spec_id not in base_by_key:
            # First sighting of this view: claim the next free slot range.
            base_by_key[spec_id] = next_slot
            next_slot += plugin_view_count(version, rel_index_types.get(rel_name, ""))
        offsets[handle] = base_by_key[spec_id]
    return offsets
[docs]
def register_pipeline_handles(
    offsets: dict[int, int],
    pipeline: list[m.MirNode],
    rel_index_types: dict[str, str],
    root_slots: dict[str, int],
) -> None:
    '''Register ColumnSource handles found in the pipeline into `offsets`.

    Looks at top-level ColumnSource nodes and at the ColumnSource entries of
    ColumnJoin / CartesianJoin `sources`. A handle that is not yet in
    `offsets` and whose spec key is present in `root_slots` is assigned that
    root slot. `offsets` is mutated in place (mirrors the var-param signature
    of Nim's `registerPipelineHandles`); nothing is returned.

    `rel_index_types` is accepted for signature parity and is not read here.
    '''
    for node in pipeline:
        if isinstance(node, m.ColumnSource):
            candidates = [node]
        elif isinstance(node, (m.ColumnJoin, m.CartesianJoin)):
            candidates = [s for s in node.sources if isinstance(s, m.ColumnSource)]
        else:
            continue

        for source in candidates:
            handle = source.handle_start
            if handle in offsets:
                # Already registered: existing entries are never overwritten.
                continue
            spec_id = source_spec_key(source)
            if spec_id in root_slots:
                offsets[handle] = root_slots[spec_id]
[docs]
def build_root_slot_map(
    source_specs: Sequence[m.MirNode],
    rel_index_types: dict[str, str],
) -> dict[str, int]:
    '''Return a mapping from `<relName>_<VER>_<cols>` to its view-slot base.

    The first occurrence of a key fixes its slot; later duplicates leave the
    table untouched and therefore share that slot.
    '''
    slots: dict[str, int] = {}
    cursor = 0
    for spec in source_specs:
        rel_name, version = _source_rel_version(spec)
        spec_id = source_spec_key(spec)
        if spec_id in slots:
            continue
        slots[spec_id] = cursor
        # Advance past however many slots this source's index type occupies.
        cursor += plugin_view_count(version, rel_index_types.get(rel_name, ""))
    return slots
# -----------------------------------------------------------------------------
# ViewSpec + unique view collection
# -----------------------------------------------------------------------------
[docs]
@dataclass
class ViewSpec:
    '''One deduplicated view: (rel_name, index, version, handle_idx).

    `handle_idx` is the handle of the FIRST op that referenced this view.
    '''

    # Name of the relation the view reads.
    rel_name: str
    # Index column ordering for the view.
    index: list[int]
    # Version string (the `.code` form when built by this module).
    version: str
    # Handle of the first op that referenced this view.
    handle_idx: int
[docs]
def spec_key(rel_name: str, index: list[int], version: str = "") -> str:
    '''Build the view key `Rel_<cols>_<VER>`.

    The version is appended as a suffix so that DELTA and FULL views of the
    same relation/index do not collapse onto one key. When `version` is
    empty the suffix (and its separator) is omitted.
    '''
    parts = [rel_name, "_".join(map(str, index))]
    if version:
        parts.append(version)
    return "_".join(parts)
def _record_spec(
    specs: list[ViewSpec],
    seen: set[str],
    rel: str,
    idx: list[int],
    ver: str,
    handle: int,
) -> None:
    '''Append a `ViewSpec` to `specs` unless its key was already recorded.

    `seen` holds the `spec_key` of every spec recorded so far; both `specs`
    and `seen` are mutated in place. The index is copied into the new spec.
    '''
    key = spec_key(rel, idx, ver)
    if key not in seen:
        seen.add(key)
        specs.append(ViewSpec(rel, [*idx], ver, handle))
[docs]
def collect_unique_view_specs(ops: list[m.MirNode]) -> list[ViewSpec]:
    '''Collect the de-duplicated `ViewSpec`s of a pipeline body.

    Specs are returned in first-occurrence order. Every op kind that
    references a view is covered: ColumnJoin, CartesianJoin, Scan, Negation,
    Aggregate, BalancedScan and PositionedExtract.
    '''
    specs: list[ViewSpec] = []
    seen: set[str] = set()
    for op in ops:
        for node in _view_bearing_nodes(op):
            _record_spec(
                specs,
                seen,
                node.rel_name,
                list(node.index),
                node.version.code,
                node.handle_start,
            )
    return specs


def _view_bearing_nodes(op: m.MirNode) -> list[m.MirNode]:
    '''Return, in order, the nodes of `op` that each reference one view.'''
    if isinstance(op, (m.ColumnJoin, m.CartesianJoin)):
        return [src for src in op.sources if isinstance(src, m.ColumnSource)]
    if isinstance(op, (m.Scan, m.Negation, m.Aggregate)):
        # These ops carry rel_name / index / version on the op itself.
        return [op]
    if isinstance(op, m.BalancedScan):
        return [op.source1, op.source2]
    if isinstance(op, m.PositionedExtract):
        return [src for src in op.sources if isinstance(src, m.ColumnSource)]
    return []
# -----------------------------------------------------------------------------
# View declaration emission
# -----------------------------------------------------------------------------
[docs]
def jit_emit_view_declarations(
    specs: list[ViewSpec],
    ops: list[m.MirNode],
    ep_source_specs: Sequence[m.MirNode],
    ctx: CodeGenContext,
) -> str:
    '''Emit the top-of-kernel `auto view_X = views[i];` block.

    Populates `ctx.view_vars` with both:
      - spec key (`rel_idx_VER`) → view_var
      - str(handle_idx) → view_var (so nested op emitters can resolve
        "which view is this handle referring to?" directly)

    Args:
        specs: Unique view specs to declare, in emission order.
        ops: Pipeline ops whose handles are mapped to the declared views.
        ep_source_specs: Accepted for signature parity; not read here.
        ctx: Codegen context; supplies indentation, the debug flag and the
            view-slot table, and receives the `view_vars` entries.

    Returns:
        The generated declaration block, or "" when `specs` is empty (in
        which case `ctx` is left untouched).
    '''
    if not specs:
        return ""

    pad = ind(ctx)
    chunks = [
        pad + "using ViewType = std::remove_cvref_t<decltype(views[0])>;\n",
        pad + "using HandleType = ViewType::NodeHandle;\n\n",
    ]
    if ctx.debug:
        chunks.append(
            pad + "// View declarations (deduplicated by spec, " + str(len(specs)) + " unique views)\n"
        )

    # spec key -> view variable, used below to resolve handles in O(1).
    # setdefault keeps the first entry for a key if `specs` has duplicates.
    view_var_by_key: dict[str, str] = {}
    for sp in specs:
        key = spec_key(sp.rel_name, sp.index, sp.version)
        idx_str = "_".join(str(v) for v in sp.index)
        view_var = f"view_{sp.rel_name}_{idx_str}" + (f"_{sp.version}" if sp.version else "")
        view_idx = get_view_slot_base(ctx, sp.handle_idx)
        chunks.append(pad + f"auto {view_var} = {gen_view_access(view_idx)};\n")
        view_var_by_key.setdefault(key, view_var)
        ctx.view_vars[key] = view_var

    # Map each op's handle_start to its view_var so later emitters can
    # resolve handle -> view quickly.
    # NOTE(review): only joins and Scan/Negation/Aggregate are mapped here;
    # BalancedScan / PositionedExtract handles are not, although
    # collect_unique_view_specs does collect their views — confirm intended.
    for op in ops:
        if isinstance(op, (m.ColumnJoin, m.CartesianJoin)):
            handle_nodes = [src for src in op.sources if isinstance(src, m.ColumnSource)]
        elif isinstance(op, (m.Scan, m.Negation, m.Aggregate)):
            handle_nodes = [op]
        else:
            continue
        for node in handle_nodes:
            k = spec_key(node.rel_name, list(node.index), node.version.code)
            view_var = view_var_by_key.get(k)
            if view_var is not None:
                ctx.view_vars[str(node.handle_start)] = view_var

    chunks.append("\n")
    return "".join(chunks)