Source code for srdatalog.ir.mir.types

'''Python mirror of src/srdatalog/mir/mir_types.nim.

Pure IR data. Distinct from the existing python/mir_commands.py, which
emits C++ for mhk's codegen and carries codegen-side state (cursor slots,
program backref). The MirNode types here carry only what the Nim MIR
carries, which makes them suitable for S-expr printing and byte-diffing
against the Nim golden output.

Mapping to Nim:
  moColumnSource         -> ColumnSource
  moScan                 -> Scan
  moColumnJoin           -> ColumnJoin
  moCartesianJoin        -> CartesianJoin
  moFilter               -> Filter
  moNegation             -> Negation
  moInsertInto           -> InsertInto
  moExecutePipeline      -> ExecutePipeline
  moRebuildIndex         -> RebuildIndex
  moClearRelation        -> ClearRelation
  moCheckSize            -> CheckSize
  moComputeDelta         -> ComputeDelta
  moComputeDeltaIndex    -> ComputeDeltaIndex
  moMergeIndex           -> MergeIndex
  moMergeRelation        -> MergeRelation
  moRebuildIndexFromIndex -> RebuildIndexFromIndex
  moFixpointPlan         -> FixpointPlan
  moBlock                -> Block
  moProgram              -> Program

All registered Nim MIR op kinds are now covered. The advanced ops
(Aggregate, CreateFlatView, InnerPipeline, ProbeJoin, GatherColumn)
have types and emitters but are not yet produced by the Python lowering;
they are defined here so that downstream codegen bridges and future
lowering extensions have node shapes to target.
'''

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Union

from srdatalog.ir.hir.types import Version

# -----------------------------------------------------------------------------
# Leaf / pipeline ops
# -----------------------------------------------------------------------------


@dataclass
class ColumnSource:
    '''(column-source #:index (Rel cols...) #:ver V #:prefix (vars))'''

    rel_name: str
    version: Version
    index: list[int]  # column order
    prefix_vars: list[str] = field(default_factory=list)
    handle_start: int = -1  # codegen-internal; not emitted
    clause_idx: int = -1  # origin clause index (post-reorder pass)


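The docstrings give each node's S-expr shape. As a minimal sketch of how a ColumnSource could render in that shape, assuming a stand-in `Version` enum (only `FULL` appears in this module; `DELTA` is an assumption) and a hypothetical printer `column_source_sexpr`; the real emitter and `srdatalog.ir.hir.types.Version` may format differently:

```python
from dataclasses import dataclass, field
from enum import Enum


class Version(Enum):
    # Hypothetical stand-in for srdatalog.ir.hir.types.Version.
    FULL = "FULL"
    DELTA = "DELTA"  # assumed member, for illustration only


@dataclass
class ColumnSource:
    # Trimmed copy of the fields above.
    rel_name: str
    version: Version
    index: list[int]
    prefix_vars: list[str] = field(default_factory=list)
    handle_start: int = -1
    clause_idx: int = -1


def column_source_sexpr(node: ColumnSource) -> str:
    # Hypothetical printer in the docstring's field order. handle_start is
    # codegen-internal and not emitted; clause_idx is also left out because
    # the docstring shape does not include it (an assumption).
    index = " ".join([node.rel_name, *map(str, node.index)])
    prefix = " ".join(node.prefix_vars)
    return (
        f"(column-source #:index ({index}) "
        f"#:ver {node.version.value} #:prefix ({prefix}))"
    )


src = ColumnSource("Edge", Version.DELTA, [1, 0], ["y"], handle_start=3)
print(column_source_sexpr(src))
# (column-source #:index (Edge 1 0) #:ver DELTA #:prefix (y))
```

Setting `handle_start=3` has no effect on the output, which is the point of keeping codegen-only fields out of the printed form.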
@dataclass
class Scan:
    '''(scan #:vars (...) #:index (Rel cols...) #:ver V #:prefix (...))'''

    vars: list[str]
    rel_name: str
    version: Version
    index: list[int]
    prefix_vars: list[str] = field(default_factory=list)
    handle_start: int = -1


@dataclass
class ColumnJoin:
    '''(column-join #:var x #:sources (...))'''

    var_name: str
    sources: list[ColumnSource]
    handle_start: int = -1


@dataclass
class CartesianJoin:
    '''(cartesian-join #:vars (...) #:var-from-source ((...)) #:sources (...))'''

    vars: list[str]
    sources: list[ColumnSource]
    var_from_source: list[list[str]] = field(default_factory=list)
    handle_start: int = -1


@dataclass
class Filter:
    '''(filter #:vars (...) #:code "...")'''

    vars: list[str]
    code: str


@dataclass
class ConstantBind:
    '''(constant-bind #:var v #:code "..." #:deps (...))

    Generated by head-constant rewriting: a HIR LetClause lowers to this
    MIR node, which binds `var_name` to `code` once `deps` are bound.
    '''

    var_name: str
    code: str
    deps: list[str]


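ConstantBind fires once its `deps` are bound. A hypothetical helper (`bind_order`, not part of srdatalog) that computes a firing order for several binds, assuming that one bind's `var_name` may satisfy another bind's `deps`:

```python
from dataclasses import dataclass


@dataclass
class ConstantBind:
    # Trimmed stand-in mirroring the fields above.
    var_name: str
    code: str
    deps: list[str]


def bind_order(binds: list[ConstantBind], bound: set[str]) -> list[str]:
    """Return the order in which the binds can fire, given already-bound vars.

    A bind fires once every name in `deps` is bound; its own `var_name`
    then counts as bound, so later binds may depend on earlier ones.
    """
    bound = set(bound)  # copy so the caller's set is not mutated
    pending = list(binds)
    order: list[str] = []
    progress = True
    while pending and progress:
        progress = False
        for b in list(pending):
            if all(d in bound for d in b.deps):
                bound.add(b.var_name)
                order.append(b.var_name)
                pending.remove(b)
                progress = True
    if pending:
        missing = sorted({d for b in pending for d in b.deps} - bound)
        raise ValueError(f"unbound deps: {missing}")
    return order


binds = [
    ConstantBind("z", "x + y", ["x", "y"]),
    ConstantBind("w", "z * 2", ["z"]),
    ConstantBind("k", "42", []),
]
print(bind_order(binds, {"x", "y"}))  # ['z', 'w', 'k']
```

If some dependency can never be bound, the helper raises instead of silently dropping the bind.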
@dataclass
class Aggregate:
    '''(aggregate #:var cnt #:func AggCount #:index (Rel 0 1) #:ver FULL #:prefix (x y))

    Aggregation body clause lowered to MIR. `result_var` is the variable
    bound by the aggregate; `agg_func` is the C++ aggregation type
    (AggCount / AggSum / ... / custom). `prefix_vars` are the join-prefix
    vars read by the aggregate's index lookup.
    '''

    result_var: str
    agg_func: str
    rel_name: str
    version: Version
    index: list[int]
    prefix_vars: list[str] = field(default_factory=list)
    handle_start: int = -1


@dataclass
class CreateFlatView:
    '''(create-flat-view #:schema R #:index (cols...) #:ver V)

    Emitted by split-rule lowering to expose temp-relation intern columns
    as an unsorted view (avoids a GPU sort between Pipeline A and B).
    '''

    rel_name: str
    version: Version
    index: list[int]


@dataclass
class InnerPipeline:
    '''(inner-pipeline #:rule R #:bound-vars (...) #:handles (...) #:ops (...))

    JIT-generated inner device function for nested joins (Level 2+).
    Nim emits an explicit C++ functor rather than recursive template
    metaprogramming.
    '''

    rule_name: str
    input_handles: list[MirNode] = field(default_factory=list)
    inner_ops: list[MirNode] = field(default_factory=list)
    bound_vars: list[str] = field(default_factory=list)


@dataclass
class ProbeJoin:
    '''(probe-join ...): binary-join mode node.

    Performs a binary-search probe of `probe_rel` keyed on `join_key`,
    writing row-id pairs into `output_buffer`. Uses merge-path balancing
    for unbalanced output.
    '''

    probe_rel: str
    probe_version: Version
    probe_index: list[int]
    join_key: str
    input_buffer: str = ""  # empty for the first join in a pipeline
    output_buffer: str = ""


@dataclass
class GatherColumn:
    '''(gather-column ...): binary-join mode node.

    Dereferences the row IDs in `input_buffer` into values of `column`
    from `rel_name`, binding the result to `output_var`.
    '''

    rel_name: str
    rel_version: Version
    column: int
    output_var: str
    input_buffer: str = ""


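ProbeJoin and GatherColumn split a binary join into a probe step that produces row-id pairs and a gather step that turns row IDs into values. A CPU-only sketch of the two steps, assuming the probed relation's join column is sorted ascending and modeling the buffers as Python lists; `probe_join` and `gather_column` are hypothetical names, and neither merge-path balancing nor GPU execution is modeled:

```python
from bisect import bisect_left, bisect_right


def probe_join(left_keys: list[int], right_sorted_keys: list[int]) -> list[tuple[int, int]]:
    """Emit (left_row, right_row) pairs whose keys match.

    Each left key is located in the sorted right column with two binary
    searches; every row in the matching range yields one pair.
    """
    pairs: list[tuple[int, int]] = []
    for left_row, key in enumerate(left_keys):
        lo = bisect_left(right_sorted_keys, key)
        hi = bisect_right(right_sorted_keys, key)
        pairs.extend((left_row, right_row) for right_row in range(lo, hi))
    return pairs


def gather_column(row_ids: list[int], column: list[int]) -> list[int]:
    """Dereference row IDs into the values of one column."""
    return [column[r] for r in row_ids]


# Hypothetical data: relation R(a, b) sorted by a, plus probe keys.
r_a = [1, 2, 2, 5]
r_b = [10, 20, 21, 50]
probe_keys = [2, 3, 5]

pairs = probe_join(probe_keys, r_a)
print(pairs)  # [(0, 1), (0, 2), (2, 3)]
print(gather_column([right for _, right in pairs], r_b))  # [20, 21, 50]
```

Key 3 has no match and produces no pair, so the output size depends on the data; that variable fan-out is what the docstring's merge-path balancing addresses.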
@dataclass
class Negation:
    '''(negation #:schema R #:ver V #:index (R cols...) #:prefix (...))'''

    rel_name: str
    version: Version
    index: list[int]
    prefix_vars: list[str] = field(default_factory=list)
    const_args: list[tuple[int, int]] = field(default_factory=list)
    handle_start: int = -1


@dataclass
class InsertInto:
    '''(insert-into #:schema R #:ver V #:dedup-index (cols...) #:terms (vars))'''

    rel_name: str
    version: Version
    vars: list[str]
    index: list[int]  # dedup index columns


# -----------------------------------------------------------------------------
# Fixpoint maintenance ops (scalar, no children)
# -----------------------------------------------------------------------------


@dataclass
class RebuildIndex:
    '''(rebuild-index #:index (R cols...) #:ver V)'''

    rel_name: str
    version: Version
    index: list[int]


@dataclass
class ClearRelation:
    '''(clear-relation #:schema R #:ver V)'''

    rel_name: str
    version: Version


@dataclass
class CheckSize:
    '''(check-size #:schema R #:ver V)'''

    rel_name: str
    version: Version


@dataclass
class ComputeDelta:
    '''(compute-delta #:schema R)'''

    rel_name: str
    index: list[int] = field(default_factory=list)  # canonical index cols


@dataclass
class ComputeDeltaIndex:
    '''(compute-delta-index #:schema R #:canonical-index (cols...))'''

    rel_name: str
    canonical_index: list[int]


@dataclass
class MergeIndex:
    '''(merge-index #:index (R cols...))'''

    rel_name: str
    index: list[int]


@dataclass
class MergeRelation:
    '''(merge-relation #:schema R)'''

    rel_name: str


@dataclass
class RebuildIndexFromIndex:
    '''(rebuild-index-from-index #:source (R cols...) #:target (R cols...) #:ver V)'''

    rel_name: str
    source_index: list[int]
    target_index: list[int]
    version: Version


# -----------------------------------------------------------------------------
# Structural ops
# -----------------------------------------------------------------------------


@dataclass
class ExecutePipeline:
    '''(execute-pipeline #:rule N #:sources (tuple ...) #:dests (tuple ...) <body>)'''

    pipeline: list[MirNode]
    # column-source / scan / negation / aggregate leaves for scheduler
    source_specs: list[Union[ColumnSource, Scan, Negation, Aggregate]]
    dest_specs: list[InsertInto]  # insert-into targets
    rule_name: str = ""
    clause_order: list[int] = field(default_factory=list)
    use_fan_out: bool = False
    work_stealing: bool = False
    block_group: bool = False
    dedup_hash: bool = False
    count: bool = False
    concurrent_write: bool = False


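`source_specs` lists the column-source / scan / negation / aggregate leaves the scheduler sees. One hypothetical way to collect them from a pipeline, using trimmed stand-in classes and assuming a ColumnJoin contributes each of its `sources` in order; `collect_source_specs` is not part of srdatalog, and the actual lowering may build the list differently:

```python
from dataclasses import dataclass
from typing import Union


# Trimmed, hypothetical stand-ins: only the fields this sketch needs.
@dataclass
class ColumnSource:
    rel_name: str


@dataclass
class Scan:
    rel_name: str


@dataclass
class Negation:
    rel_name: str


@dataclass
class Aggregate:
    rel_name: str


@dataclass
class ColumnJoin:
    var_name: str
    sources: list[ColumnSource]


@dataclass
class Filter:
    code: str


Leaf = Union[ColumnSource, Scan, Negation, Aggregate]


def collect_source_specs(pipeline: list[object]) -> list[Leaf]:
    """Gather scheduler-visible leaves in pipeline order (hypothetical)."""
    specs: list[Leaf] = []
    for op in pipeline:
        if isinstance(op, ColumnJoin):
            specs.extend(op.sources)  # a join contributes each of its sources
        elif isinstance(op, (ColumnSource, Scan, Negation, Aggregate)):
            specs.append(op)
        # Filters and other non-leaf ops contribute nothing.
    return specs


pipeline = [
    Scan("A"),
    ColumnJoin("y", [ColumnSource("B"), ColumnSource("C")]),
    Filter("y > 0"),
    Negation("D"),
]
print([s.rel_name for s in collect_source_specs(pipeline)])  # ['A', 'B', 'C', 'D']
```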
@dataclass
class FixpointPlan:
    '''(fixpoint-plan <instructions...>)'''

    instructions: list[MirNode]
    schema_arities: list[tuple[str, int]] = field(default_factory=list)


@dataclass
class Block:
    '''(block <instructions...>)'''

    instructions: list[MirNode]


@dataclass
class BalancedScan:
    '''(balanced-scan #:group-var v #:source1 (...) #:source2 (...))

    Pre-computes a work-distribution histogram for a skewed join: partition
    work across (source1 × source2) pairs grouped by `group_var` (the
    "balanced root"). Emitted by lowering when a rule's plan specifies
    `balanced_root` + `balanced_sources`.
    '''

    group_var: str
    source1: ColumnSource
    source2: ColumnSource
    vars1: list[str] = field(default_factory=list)
    vars2: list[str] = field(default_factory=list)
    handle_start: int = -1


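The histogram behind BalancedScan can be illustrated on plain key lists: the work for a group key k is the number of (source1, source2) pairs sharing k. A hypothetical sketch (`balanced_partition` is not part of srdatalog) that assigns whole groups to workers by prefix sum; the C++ runtime's actual scheme is not shown here, and this sketch never splits a single group:

```python
from collections import Counter
from itertools import accumulate


def balanced_partition(keys1: list[int], keys2: list[int], workers: int) -> list[list[int]]:
    """Assign whole group keys to workers so pair counts are roughly even.

    The work for key k is count1[k] * count2[k], the size of k's slice of
    the source1 x source2 cross product. Groups are laid out in key order
    and placed by where their work starts in the prefix sum.
    """
    c1, c2 = Counter(keys1), Counter(keys2)
    groups = sorted(k for k in c1 if k in c2)  # keys present on both sides
    work = [c1[k] * c2[k] for k in groups]  # per-group histogram
    total = sum(work)
    buckets: list[list[int]] = [[] for _ in range(workers)]
    for key, start in zip(groups, accumulate(work, initial=0)):
        # start < total here, so the bucket index is always < workers
        buckets[start * workers // total].append(key)
    return buckets


print(balanced_partition([1, 1, 1, 2, 3, 3], [1, 1, 2, 2, 3], workers=2))
# [[1], [2, 3]]
```

In the example the per-key work is 6, 2 and 2 pairs, so the heavy key 1 gets a worker to itself while keys 2 and 3 share the other.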
@dataclass
class PositionedExtract:
    '''(positioned-extract #:var v #:sources ((...) (...)) #:bind (...))

    After a BalancedScan binds its group variable, any subsequent ColumnJoin
    for that variable becomes a positioned extract: point-lookup rather than
    iteration.
    '''

    sources: list[ColumnSource]
    var_name: str
    bind_vars: list[str] = field(default_factory=list)


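A positioned extract replaces iteration over a column with a point lookup once the group value is known. On a sorted column that lookup can be done with two binary searches, as in this hypothetical `extract_range` helper (an illustration of the technique, not the generated code):

```python
from bisect import bisect_left, bisect_right


def extract_range(sorted_col: list[int], value: int) -> range:
    """Row range holding `value` in a column sorted ascending.

    Two O(log n) binary searches instead of a scan over the column.
    """
    return range(bisect_left(sorted_col, value), bisect_right(sorted_col, value))


# Hypothetical: the BalancedScan has already bound the group variable to 2.
col_b = [1, 2, 2, 2, 7]
col_c = [2, 3, 4]
rows = [extract_range(col, 2) for col in (col_b, col_c)]
print([list(r) for r in rows])  # [[1, 2, 3], [0]]
```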
@dataclass
class ParallelGroup:
    '''(parallel-group <ops...>): independent ops that can run concurrently.'''

    ops: list[MirNode]


@dataclass
class InjectCppHook:
    '''(inject-cpp-hook #:rule R #:code "..."): raw C++ injection point.

    The Nim emitter prints #:code "..." without the actual code body (just
    the ellipsis), matching debug/tooling display; real content stays in
    the node's code field for codegen.
    '''

    code: str
    rule_name: str = ""


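The display rule for InjectCppHook is simple enough to sketch directly. `hook_sexpr` is a hypothetical stand-in for the Nim emitter, and how an empty `rule_name` prints is not specified by the docstring, so the sketch does not special-case it:

```python
from dataclasses import dataclass


@dataclass
class InjectCppHook:
    # Trimmed stand-in mirroring the fields above.
    code: str
    rule_name: str = ""


def hook_sexpr(node: InjectCppHook) -> str:
    # Display form only: the code body is replaced by a literal "..."
    # while node.code keeps the real C++ for codegen.
    return f'(inject-cpp-hook #:rule {node.rule_name} #:code "...")'


hook = InjectCppHook(code='std::puts("hook");', rule_name="TC")
print(hook_sexpr(hook))  # (inject-cpp-hook #:rule TC #:code "...")
```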
@dataclass
class PostStratumReconstructInternCols:
    '''(post-stratum-reconstruct-intern-cols #:rel R #:canonical-index (cols...))

    Emitted once per stratum per relation: a single cleanup step that
    replaces a per-index RebuildIndex loop, pushing complexity into the
    C++ runtime.
    '''

    rel_name: str
    canonical_index: list[int]


@dataclass
class Program:
    '''(program (step #:recursive b <plan>) ...)'''

    steps: list[tuple[MirNode, bool]]  # (plan, is_recursive)


# -----------------------------------------------------------------------------
# Union type and convenience
# -----------------------------------------------------------------------------

MirNode = Union[
    ColumnSource,
    Scan,
    ColumnJoin,
    CartesianJoin,
    Filter,
    ConstantBind,
    Negation,
    Aggregate,
    CreateFlatView,
    InnerPipeline,
    ProbeJoin,
    GatherColumn,
    InsertInto,
    RebuildIndex,
    ClearRelation,
    CheckSize,
    ComputeDelta,
    ComputeDeltaIndex,
    MergeIndex,
    MergeRelation,
    RebuildIndexFromIndex,
    ExecutePipeline,
    FixpointPlan,
    Block,
    ParallelGroup,
    BalancedScan,
    PositionedExtract,
    InjectCppHook,
    PostStratumReconstructInternCols,
    Program,
]
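Because MirNode is a plain `typing.Union`, consumers dispatch on node type with `isinstance`. A minimal depth-first traversal sketch over a few trimmed stand-in classes; `iter_nodes` is hypothetical, and a complete walker would also need to descend into `ExecutePipeline.pipeline`, `ParallelGroup.ops`, `InnerPipeline.inner_ops`, and the other list-valued fields:

```python
from dataclasses import dataclass, field
from typing import Iterator, Union


# Trimmed, hypothetical stand-ins for a few node shapes in this module.
@dataclass
class ClearRelation:
    rel_name: str


@dataclass
class CheckSize:
    rel_name: str


@dataclass
class Block:
    instructions: list["Node"]


@dataclass
class FixpointPlan:
    instructions: list["Node"]
    schema_arities: list[tuple[str, int]] = field(default_factory=list)


@dataclass
class Program:
    steps: list[tuple["Node", bool]]  # (plan, is_recursive)


Node = Union[ClearRelation, CheckSize, Block, FixpointPlan, Program]


def iter_nodes(node: Node) -> Iterator[Node]:
    """Yield `node` and every node nested under it, depth-first pre-order."""
    yield node
    if isinstance(node, Program):
        for plan, _is_recursive in node.steps:
            yield from iter_nodes(plan)
    elif isinstance(node, (Block, FixpointPlan)):
        for child in node.instructions:
            yield from iter_nodes(child)


prog = Program(steps=[
    (Block([ClearRelation("Path")]), False),
    (FixpointPlan([CheckSize("Path"), ClearRelation("Path")]), True),
])
print([type(n).__name__ for n in iter_nodes(prog)])
# ['Program', 'Block', 'ClearRelation', 'FixpointPlan', 'CheckSize', 'ClearRelation']
```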