# Source code for srdatalog.ir.dialects.relation.sorted_array.ops

'''Sorted-array dialect ops — M1 subset.

Each op references its operands by *name* (string) for M1 pragmatism.
The legacy emitter passes view/handle variable names through string
keys; matching that lets the byte-equivalence gate compare directly.
A later milestone replaces this with lexical binding (D8) once the
gate has been validated end-to-end.

Operand naming convention:
  view_name   — name of a previously-declared view variable
                (`auto view_<rel>_<cols>_<ver> = views[<slot>];`).
  handle_name — name of a previously-bound handle variable (declared
                via iir.cf.Bind).
'''

from __future__ import annotations

from dataclasses import dataclass
from typing import final

from srdatalog.ir.core import Op


@final
@dataclass(frozen=True, slots=True)
class SaRoot(Op):
  '''Root handle covering the full row-range of a sorted-array view.

  CUDA lowering (target.cuda):
    `HandleType(0, <view_name>.num_rows_, 0)`

  Operands:
    view_name — name of the view variable whose `num_rows_` is used as
                the upper bound of the handle.
  '''

  view_name: str
@final
@dataclass(frozen=True, slots=True)
class SaValid(Op):
  '''Test whether a handle is non-empty / non-degenerate.

  CUDA lowering (target.cuda):
    `<handle_name>.valid()`

  Operands:
    handle_name — name of the handle variable being tested.
  '''

  handle_name: str
@final
@dataclass(frozen=True, slots=True)
class SaDegree(Op):
  '''Branching factor (row count) at the position a handle points to.

  CUDA lowering (target.cuda):
    `<handle_name>.degree()`

  Operands:
    handle_name — name of the handle variable being queried.
  '''

  handle_name: str
@final
@dataclass(frozen=True, slots=True)
class SaGetVal(Op):
  '''Read one cell of a view: column `col`, row `idx_var_name`.

  Appears inside a root scan, where every variable binding fetches the
  column it is bound to.

  CUDA lowering (target.cuda):
    `<view_name>.get_value(<col>, <idx_var_name>)`

  Operands:
    view_name    — name of the view variable being read.
    col          — column index passed as the first argument.
    idx_var_name — name of the variable holding the row index.
  '''

  view_name: str
  col: int
  idx_var_name: str
@final
@dataclass(frozen=True, slots=True)
class SaGetValAt(Op):
  '''Read the value of a narrowed handle's child at slot `idx_var_name`.

  Intended for nested ColumnJoin paths. M1 only declares this op for
  completeness; it is not lowered yet.

  CUDA lowering (target.cuda):
    `<view_name>.get_value_at(<handle_name>.begin(), <idx_var_name>)`

  Operands:
    handle_name  — name of the narrowed handle whose `begin()` is used.
    view_name    — name of the view variable being read.
    idx_var_name — name of the variable holding the child slot.

  NOTE(review): earlier wording of this docstring spoke of a column
  position `col`, but the op carries no column operand and the lowering
  above does not take one — confirm which is intended before lowering
  is implemented.
  '''

  handle_name: str
  view_name: str
  idx_var_name: str
@final
@dataclass(frozen=True, slots=True)
class SaHint(Op):
  '''Expression-shaped constructor for a range-narrowed root handle.

  CUDA lowering (target.cuda):
    `HandleType(<lo_var>, <hi_var>, <depth>)`

  Usually wrapped by SaPrefCoop, giving
  `HandleType(lo, hi, 0).prefix(root_val, tile, view)`.

  Operands:
    lo_var — name of the variable supplying the first constructor argument.
    hi_var — name of the variable supplying the second constructor argument.
    depth  — third constructor argument; defaults to 0.
  '''

  lo_var: str
  hi_var: str
  depth: int = 0
@final
@dataclass(frozen=True, slots=True)
class SaPrefCoop(Op):
  '''Cooperative prefix-narrowing applied to a parent handle expression.

  CUDA lowering (target.cuda):
    `<parent>.prefix(<key>, tile, <view>)`

  Emitted for multi-source root CJ, where 32 threads binary-search the
  parent handle for `key` together.

  Operands:
    parent    — op producing the handle expression to narrow.
    key_var   — name of the variable holding the key to search for.
    view_name — name of the view variable passed to `prefix`.
  '''

  parent: Op
  key_var: str
  view_name: str
@final
@dataclass(frozen=True, slots=True)
class SaPrefSeq(Op):
  '''Per-thread (sequential) prefix-narrowing applied to a parent handle.

  CUDA lowering (target.cuda):
    `<parent>.prefix_seq(<key>, <view>)`

  Emitted inside a Cartesian loop: every thread already owns its own
  (idx0, idx1, …) decomposition and narrows independently. The
  cooperative form would be wrong there because the threads do not
  agree on `key`.

  Operands:
    parent    — op producing the handle expression to narrow.
    key_var   — name of the variable holding this thread's key.
    view_name — name of the view variable passed to `prefix_seq`.
  '''

  parent: Op
  key_var: str
  view_name: str
@final
@dataclass(frozen=True, slots=True)
class SaIterators(Op):
  '''Iterator pair of a handle, in the form `intersect_handles` accepts.

  CUDA lowering (target.cuda):
    `<handle>.iterators(<view>)`

  Operands:
    handle_name — name of the handle variable to iterate.
    view_name   — name of the view variable passed to `iterators`.
  '''

  handle_name: str
  view_name: str
@final
@dataclass(frozen=True, slots=True)
class SaChildRange(Op):
  '''Child range narrowed out of a handle.

  CUDA lowering (target.cuda):
    `<handle>.child_range(<pos_expr>, <key_var>, tile, <view>)`

  Emitted inside an IntersectIter body to build the per-source child
  handles consumed by the next nesting level.

  Operands:
    handle_name — name of the handle variable being narrowed.
    pos_expr    — expression text passed as the position argument.
    key_var     — name of the variable holding the key.
    view_name   — name of the view variable passed to `child_range`.
  '''

  handle_name: str
  pos_expr: str
  key_var: str
  view_name: str
@final
@dataclass(frozen=True, slots=True)
class SaGetValAtPos(Op):
  '''Read a column value for a row addressed relative to a handle's begin.

  CUDA lowering (target.cuda):
    `<view>.get_value(<col>, <handle>.begin() + <idx_var>)`

  Emitted by the nested CartesianJoin to bind per-source variables from
  the flat-decomposed indices.

  Operands:
    view_name    — name of the view variable being read.
    col          — column index passed as the first argument.
    handle_name  — name of the narrowed handle whose `begin()` is the base.
    idx_var_name — name of the variable holding the offset from `begin()`.
  '''

  view_name: str
  col: int
  handle_name: str
  idx_var_name: str
@final
@dataclass(frozen=True, slots=True)
class SaTiledCartesian2D(Op):
  '''Two-source nested CartesianJoin: tiled smem pre-load + ballot writes.

  Lowers to the shape produced by the legacy `_emit_tiled_cartesian`:

      if (<total_var> > 32) {
        // Tiled Cartesian: smem pre-load reads, standard emit_direct writes
        for (uint32_t <t0_base> = 0; <t0_base> < <degree_var0>; <t0_base> += kCartTileSize) {
          uint32_t <t0_len> = min(<t0_base> + (uint32_t)kCartTileSize, <degree_var0>) - <t0_base>;
          for (uint32_t _ti = <lane_var>; _ti < <t0_len>; _ti += <group_size_var>)
            s_cart[warp_in_block][0][_ti] = <view_var0>.get_value(<col0>, <handle_var0>.begin() + <t0_base> + _ti);
          for (uint32_t <t1_base> = 0; <t1_base> < <degree_var1>; <t1_base> += kCartTileSize) {
            uint32_t <t1_len> = min(<t1_base> + (uint32_t)kCartTileSize, <degree_var1>) - <t1_base>;
            for (uint32_t _ti = <lane_var>; _ti < <t1_len>; _ti += <group_size_var>)
              s_cart[warp_in_block][1][_ti] = <view_var1>.get_value(<col1>, <handle_var1>.begin() + <t1_base> + _ti);
            <tile_var>.sync();
            uint32_t <tile_total> = <t0_len> * <t1_len>;
            for (uint32_t <batch_var> = 0; <batch_var> < <tile_total>; <batch_var> += <group_size_var>) {
              uint32_t <flat_idx_var> = <batch_var> + <lane_var>;
              bool <valid_var> = <flat_idx_var> < <tile_total>;
              auto <var_name0> = <valid_var> ? s_cart[warp_in_block][0][<flat_idx_var> / <t1_len>] : ValueType{0};
              auto <var_name1> = <valid_var> ? s_cart[warp_in_block][1][<flat_idx_var> % <t1_len>] : ValueType{0};
              <tiled_body>
            }
            <tile_var>.sync();
          }
        }
      } else {
        for (uint32_t <fb_batch_var> = 0; <fb_batch_var> < <total_var>; <fb_batch_var> += <group_size_var>) {
          uint32_t <flat_idx_var> = <fb_batch_var> + <lane_var>;
          bool <valid_var> = <flat_idx_var> < <total_var>;
          const bool <major_var> = (<degree_var1> >= <degree_var0>);
          uint32_t <idx0_var>, <idx1_var>;
          if (<major_var>) {
            <idx0_var> = <flat_idx_var> / <degree_var1>;
            <idx1_var> = <flat_idx_var> % <degree_var1>;
          } else {
            <idx1_var> = <flat_idx_var> / <degree_var0>;
            <idx0_var> = <flat_idx_var> % <degree_var0>;
          }
          auto <var_name0> = <view_var0>.get_value(<col0>, <handle_var0>.begin() + <idx0_var>);
          auto <var_name1> = <view_var1>.get_value(<col1>, <handle_var1>.begin() + <idx1_var>);
          <fallback_body>
        }
      }

  `body` is the trailing-InsertInto run rendered as a TiledBallotBlock
  (the InsertInto variant driven by `ctx.tiled_cartesian_valid_var`).
  The legacy `_emit_tiled_cartesian` emits `tiled_body` in both branches
  when it is set, so the dialect carries a single body and uses it for
  both `<tiled_body>` and `<fallback_body>`.

  The body is emitted at the indent of the surrounding scope. This is a
  legacy quirk: the body used to be rendered before the tiled wrapper
  was textually placed around it.

  NOTE(review): `<tile_var>` appears in the template but is not a field
  of this op; presumably it comes from the emitter's context, as would
  `kCartTileSize`, `s_cart` and `warp_in_block` — confirm against the
  lowering.
  '''

  # Per-source inputs (suffix 0 = first source, 1 = second source).
  view_var0: str
  view_var1: str
  handle_var0: str
  handle_var1: str
  col0: int
  col1: int

  # Names of the variables the join binds for the body.
  var_name0: str
  var_name1: str

  # Loop start offset (`lane_var`) and stride (`group_size_var`).
  lane_var: str
  group_size_var: str

  # Size inputs: `total_var` selects tiled vs. fallback (`> 32`); the
  # degrees bound the tile loops and drive the fallback decomposition.
  total_var: str
  degree_var0: str
  degree_var1: str

  # Flat index of the current element; declared in both branches.
  flat_idx_var: str

  # Tiled branch only: tile bases, tile lengths, per-tile total, batch counter.
  t0_base: str
  t1_base: str
  t0_len: str
  t1_len: str
  tile_total: str
  batch_var: str

  # In-range flag for `flat_idx_var`; declared in both branches.
  valid_var: str

  # Fallback branch only: batch counter, decomposition order, split indices.
  fb_batch_var: str
  major_var: str
  idx0_var: str
  idx1_var: str

  # Single body op, emitted in both branches.
  body: Op