Source code for srdatalog.ir.hir.plan

'''HIR Pass 4: Join Planning.

For each variant, compute:
  - clause_order: execution order of body clauses (heuristic: deps → delta
    priority → max-overlap with bound vars)
  - var_order: variable binding order (join vars first, starting from the
    delta clause for recursive rules; then remaining join vars in clause
    order; then non-join vars in clause order)
  - access_patterns: per-positive-clause AccessPattern (rel, version,
    access_order, prefix_len, index_cols, clause_idx)
  - negation_patterns: per-negation-clause AccessPattern (version forced
    to FULL)

Mirrors src/srdatalog/hir/join_planner.nim. Not yet ported: user-provided
rule.plans (DSL doesn't support them), split clauses (SplitClause),
balanced partitioning pragmas, IfClause/LetClause/AggClause handling.
'''

from __future__ import annotations

from dataclasses import dataclass, field

from srdatalog.dsl import Agg, ArgKind, Atom, Filter, Let, Negation, PlanEntry, Rule, Split
from srdatalog.ir.hir.pass_ import IRLevel, PassInfo, PassLevel
from srdatalog.ir.hir.types import AccessPattern, HirProgram, HirRuleVariant, Version

# -----------------------------------------------------------------------------
# Rule Analysis
# -----------------------------------------------------------------------------


[docs] @dataclass class RuleAnalysis: vars: set[str] = field(default_factory=set) clause_vars: list[set[str]] = field(default_factory=list) join_vars: set[str] = field(default_factory=set) # appear in >1 POSITIVE clause head_vars: set[str] = field(default_factory=set)
[docs] def analyze_rule(rule: Rule) -> RuleAnalysis: r = RuleAnalysis() positive_count: dict[str, int] = {} for body in rule.body: cvars: set[str] = set() if isinstance(body, Negation): # Negation vars are tracked but NOT counted for join_vars (it's a filter, not a join). for arg in body.atom.args: if arg.kind is ArgKind.LVAR and arg.var_name is not None: cvars.add(arg.var_name) r.vars.add(arg.var_name) elif isinstance(body, Atom): for arg in body.args: if arg.kind is ArgKind.LVAR and arg.var_name is not None: cvars.add(arg.var_name) r.vars.add(arg.var_name) positive_count[arg.var_name] = positive_count.get(arg.var_name, 0) + 1 elif isinstance(body, Agg): # Aggregate args count as positive (like RelClause), plus the # result_var is a produced positive var. Mirrors Nim analyzeRule. for arg in body.args: if arg.kind is ArgKind.LVAR and arg.var_name is not None: cvars.add(arg.var_name) r.vars.add(arg.var_name) positive_count[arg.var_name] = positive_count.get(arg.var_name, 0) + 1 cvars.add(body.result_var) r.vars.add(body.result_var) positive_count[body.result_var] = positive_count.get(body.result_var, 0) + 1 # Filter / Let: no relation args to analyze; leave cvars empty. r.clause_vars.append(cvars) for head in rule.heads: for arg in head.args: if arg.kind is ArgKind.LVAR and arg.var_name is not None: r.head_vars.add(arg.var_name) for v, c in positive_count.items(): if c > 1: r.join_vars.add(v) return r
# ----------------------------------------------------------------------------- # Clause ordering heuristic # ----------------------------------------------------------------------------- def _clause_lvar_names(body) -> list[str]: '''Ordered LVar names for a body clause. Column order for Atom/Negation; for Filter the `vars` field (what the filter reads); for Let the single `var_name` (what it binds); for Agg: all arg LVars then the result_var last (matches Nim's extractClauseVars). Split contributes nothing. ''' if isinstance(body, Atom): return [a.var_name for a in body.args if a.kind is ArgKind.LVAR and a.var_name is not None] if isinstance(body, Negation): return [a.var_name for a in body.atom.args if a.kind is ArgKind.LVAR and a.var_name is not None] if isinstance(body, Filter): return list(body.vars) if isinstance(body, Let): return [body.var_name] if isinstance(body, Agg): names = [a.var_name for a in body.args if a.kind is ArgKind.LVAR and a.var_name is not None] names.append(body.result_var) return names return [] # Split, unknown def _get_dependencies(body) -> set[str]: '''Vars that must be bound before this clause is runnable. - Atom, Agg: no deps (both are generators). - Negation: safe-negation -> all args must be bound. - Filter: every var it references. - Let: every var in its `deps` list (NOT the var it binds). ''' if isinstance(body, Negation): return {a.var_name for a in body.atom.args if a.kind is ArgKind.LVAR and a.var_name is not None} if isinstance(body, Filter): return set(body.vars) if isinstance(body, Let): return set(body.deps) return set() def _get_produced_vars(body) -> set[str]: '''Vars newly bound by this clause. - Atom: all its LVar args. - Let / Agg: the single bound var (`var_name` / `result_var`). - Negation / Filter: nothing. ''' if isinstance(body, Atom): return {a.var_name for a in body.args if a.kind is ArgKind.LVAR and a.var_name is not None} if isinstance(body, Let): return {body.var_name} if isinstance(body, Agg): return {body.result_var} return set()
[docs] def compute_clause_order(rule: Rule, delta_idx: int = -1) -> list[int]: '''Pick body clause execution order using the Nim heuristic: 1. Dependency-gated runnable set (sorted by source idx for tie-breaking). 2. Delta clause first when runnable (for recursive variants). 3. Max overlap of clause vars with currently-bound vars. ''' bound: set[str] = set() scheduled: list[int] = [] remaining: set[int] = set(range(len(rule.body))) while remaining: runnable = sorted(i for i in remaining if _get_dependencies(rule.body[i]) <= bound) if not runnable: # Deadlock fallback (shouldn't happen for stratified DSL input, but # matches Nim behavior): prefer an Atom, else lowest index. atoms = [i for i in sorted(remaining) if isinstance(rule.body[i], Atom)] runnable = [atoms[0] if atoms else sorted(remaining)[0]] # Priority 2.1: Filters get first crack so the planner pushes selection # DOWN (drops ineligible bindings ASAP). best = -1 for r in runnable: if isinstance(rule.body[r], Filter): best = r break # Priority 2.2: delta clause (seed for recursive variants). if best == -1 and delta_idx in runnable: best = delta_idx elif best == -1: # Priority 2.3: max overlap of clause args with already-bound vars. max_overlap = -1 for r in runnable: body = rule.body[r] if isinstance(body, Atom): overlap = sum(1 for v in _clause_lvar_names(body) if v in bound) else: overlap = 0 if overlap > max_overlap: max_overlap = overlap best = r if best == -1: best = runnable[0] scheduled.append(best) remaining.discard(best) bound.update(_get_produced_vars(rule.body[best])) return scheduled
# ----------------------------------------------------------------------------- # Variable ordering heuristic # -----------------------------------------------------------------------------
[docs] def compute_var_order_from_clauses( rule: Rule, clause_order: list[int], join_vars: set[str], delta_idx: int = -1 ) -> list[str]: '''Derive variable binding order from clause execution order. Pass 1 (recursive variants only): join vars appearing in the delta clause. Pass 2: remaining join vars, in clauseOrder. Pass 3: non-join vars, in clauseOrder. ''' result: list[str] = [] seen: set[str] = set() if 0 <= delta_idx < len(rule.body): for v in _clause_lvar_names(rule.body[delta_idx]): if v in join_vars and v not in seen: seen.add(v) result.append(v) for idx in clause_order: for v in _clause_lvar_names(rule.body[idx]): if v in join_vars and v not in seen: seen.add(v) result.append(v) for idx in clause_order: for v in _clause_lvar_names(rule.body[idx]): if v not in join_vars and v not in seen: seen.add(v) result.append(v) return result
# ----------------------------------------------------------------------------- # Access pattern computation # -----------------------------------------------------------------------------
[docs] def compute_access_pattern( body, version: Version, join_vars: set[str], var_order: list[str], clause_idx: int ) -> AccessPattern: '''Build the AccessPattern for one body clause (Atom or Negation). - access_order = var_order filtered to clause vars, plus any remaining clause vars appended in a deterministic (sorted) order. - prefix_len: Atoms → count of join vars; Negations → count of non-wildcard vars. - index_cols: maps access_order to column positions; then completes to full arity. - For Negations: prepend constant-column indices and force version=FULL (caller). ''' if isinstance(body, (Atom, Negation)): atom_args = body.atom.args if isinstance(body, Negation) else body.args rel = body.atom.rel if isinstance(body, Negation) else body.rel else: # IfClause/LetClause etc. — not supported yet. return AccessPattern(rel_name="", version=version, clause_idx=clause_idx) clause_vars: set[str] = set() const_args: list[tuple[int, int]] = [] for col, arg in enumerate(atom_args): if arg.kind is ArgKind.LVAR and arg.var_name is not None: clause_vars.add(arg.var_name) elif arg.kind is ArgKind.CONST and arg.const_value is not None: const_args.append((col, arg.const_value)) access_order: list[str] = [v for v in var_order if v in clause_vars] seen = set(access_order) # Remaining clause vars (shouldn't happen for auto-generated var_order, but # can with user-provided plans that have wildcard holes). Sort for determinism. for v in sorted(clause_vars - seen): access_order.append(v) if isinstance(body, Atom): prefix_len = sum(1 for v in access_order if v in join_vars) else: # Negation: wildcard vars (starting with "_") don't count toward prefix. prefix_len = sum(1 for v in access_order if not v.startswith("_")) # index_cols: position-in-atom for each var in access_order, then any # missing columns appended to complete full arity. index_cols: list[int] = [] for v in access_order: for col, arg in enumerate(atom_args): if arg.kind is ArgKind.LVAR and arg.var_name == v: index_cols.append(col) break for col in range(len(atom_args)): if col not in index_cols: index_cols.append(col) if isinstance(body, Negation): const_cols = [c for c, _ in const_args] index_cols = const_cols + [c for c in index_cols if c not in const_cols] return AccessPattern( rel_name=rel, version=version, access_order=access_order, index_cols=index_cols, prefix_len=prefix_len, clause_idx=clause_idx, const_args=const_args, )
# ----------------------------------------------------------------------------- # Derive clause order from user-provided var_order (when the plan omits it) # -----------------------------------------------------------------------------
[docs] def derive_clause_order_from_var_order( rule: Rule, var_order: list[str], delta_idx: int = -1 ) -> list[int]: '''Mirror deriveClauseOrderFromVarOrder in join_planner.nim. Walks var_order left to right; for each not-yet-bound variable, picks a runnable clause that introduces it (preferring the delta clause when applicable; source-order ties otherwise). Any unscheduled clauses (filters, disconnected negations) are appended at the end in a runnable sweep. ''' scheduled: list[int] = [] remaining: set[int] = set(range(len(rule.body))) bound: set[str] = set() def clause_vars(body) -> set[str]: return set(_clause_lvar_names(body)) def can_schedule(idx: int) -> bool: return _get_dependencies(rule.body[idx]) <= bound for target in var_order: if target in bound: continue candidates = [ idx for idx in remaining if can_schedule(idx) and target in clause_vars(rule.body[idx]) ] if not candidates: continue if delta_idx in candidates: picked = delta_idx else: candidates.sort() picked = candidates[0] scheduled.append(picked) remaining.discard(picked) bound.update(clause_vars(rule.body[picked])) # Sweep remaining (filters/negations not referenced by var_order). # Prefer Filter clauses first (push-down policy consistent with the main # heuristic). Nim's sweep iterates a HashSet whose hash-bucket order # happens to surface Filter before Negation in common cases; the # explicit priority here matches that behavior deterministically. while remaining: scheduled_this_round = False for idx in sorted(remaining): if can_schedule(idx) and isinstance(rule.body[idx], Filter): scheduled.append(idx) remaining.discard(idx) bound.update(clause_vars(rule.body[idx])) scheduled_this_round = True break if scheduled_this_round: continue for idx in sorted(remaining): if can_schedule(idx): scheduled.append(idx) remaining.discard(idx) bound.update(clause_vars(rule.body[idx])) scheduled_this_round = True break if not scheduled_this_round: # Deadlock fallback: force-pick lowest remaining. idx = min(remaining) scheduled.append(idx) remaining.discard(idx) bound.update(clause_vars(rule.body[idx])) return scheduled
# ----------------------------------------------------------------------------- # Main entry # ----------------------------------------------------------------------------- def _find_plan(rule: Rule, delta_idx: int) -> PlanEntry | None: for p in rule.plans: if p.delta == delta_idx: return p return None
[docs] def detect_split(rule: Rule) -> int: '''Body index of the Split clause, or -1 if absent.''' for i, clause in enumerate(rule.body): if isinstance(clause, Split): return i return -1
[docs] def compute_temp_vars(rule: Rule, split_at: int) -> list[str]: '''Variables that cross a rule's split boundary: bound above the split AND used below (body clauses or head). Mirrors Nim's computeTempVars. Order: first the vars shared with below-split body clauses (join vars for Pipeline B), then head-only vars. Within each group, vars are walked in clause-walk insertion order (first occurrence in clauses 0..split_at-1, position-by-position). Why insertion order: Nim uses `HashSet[string]` iteration which is hash-bucket order (`hash & (cap-1)`, ascending slot, with linear probe on collision). For variable names that fit our codebase's typical mnemonic style (e.g., `blk`, `blockUsed`, `varr`, `varp`), the FarmHash slots happen to come out in roughly insertion order. Insertion order is byte-equivalent to Nim's hash-bucket order on every fixture in the test set, more deterministic across Python / Nim runs, and avoids needing to bit-port Nim's `hashFarm` for strings. F1 fix; see ddisasm StackLiveVarBlockEnd1_D0_split{A,B}. ''' vars_above_ordered: list[str] = [] vars_above_seen: set[str] = set() for i in range(split_at): clause = rule.body[i] if isinstance(clause, (Atom, Negation)): for v in _clause_lvar_names(clause): if v not in vars_above_seen: vars_above_ordered.append(v) vars_above_seen.add(v) below_body_vars: set[str] = set() for i in range(split_at + 1, len(rule.body)): clause = rule.body[i] if isinstance(clause, Atom): # Nim only counts RelClause for belowBodyVars below_body_vars.update(_clause_lvar_names(clause)) vars_below: set[str] = set(below_body_vars) # Also include head args (every head for multi-head rules) for head in rule.heads: for a in head.args: if a.kind is ArgKind.LVAR and a.var_name is not None: vars_below.add(a.var_name) # Include negation clauses' vars from below (Nim does this) for i in range(split_at + 1, len(rule.body)): clause = rule.body[i] if isinstance(clause, Negation): vars_below.update(_clause_lvar_names(clause)) result: list[str] = [] # Pass 1: vars_above ∩ below_body_vars ∩ vars_below (join vars) for v in vars_above_ordered: if v in below_body_vars and v in vars_below: result.append(v) # Pass 2: vars_above ∩ vars_below minus below_body_vars (head-only) for v in vars_above_ordered: if v in vars_below and v not in below_body_vars: result.append(v) return result
def _plan_variant(v: HirRuleVariant) -> None: rule = v.original_rule analysis = analyze_rule(rule) d = v.delta_idx # -1 for base variants plan = _find_plan(rule, d) if plan is not None and plan.var_order: var_order = list(plan.var_order) if plan.clause_order: clause_order = list(plan.clause_order) else: clause_order = derive_clause_order_from_var_order(rule, var_order, delta_idx=d) else: clause_order = compute_clause_order(rule, delta_idx=d) var_order = compute_var_order_from_clauses(rule, clause_order, analysis.join_vars, delta_idx=d) # Propagate pragma flags whenever a plan is attached — the pragma # branch above is gated on `plan.var_order`, but pragmas like # `work_stealing: true` frequently appear on plan entries that have # no custom var_order (e.g. Polonius subset_trans). Using the flags # regardless of var_order matches Nim's planJoins which copies the # pragma set unconditionally. if plan is not None: v.fanout = plan.fanout v.work_stealing = plan.work_stealing v.block_group = plan.block_group v.dedup_hash = plan.dedup_hash v.balanced_root = list(plan.balanced_root) v.balanced_sources = list(plan.balanced_sources) v.count = rule.count v.clause_order = clause_order v.var_order = var_order v.join_vars = analysis.join_vars # Split-rule metadata. Auto-detect the `split` marker in the body and # attach temp-rel fields so downstream passes (temp decl synthesis, # temp index registration, split-aware lowering) can key off them. split_at = detect_split(rule) v.split_at = split_at if split_at >= 0 and rule.name: v.temp_vars = compute_temp_vars(rule, split_at) v.temp_rel_name = f"_temp_{rule.name}" for k, body in enumerate(rule.body): version = v.clause_versions[k] pattern = compute_access_pattern(body, version, analysis.join_vars, var_order, k) if not pattern.rel_name: continue if isinstance(body, Negation): pattern.version = Version.FULL v.negation_patterns.append(pattern) else: v.access_patterns.append(pattern)
[docs] def plan_joins(hir: HirProgram) -> HirProgram: '''HIR Pass 4 entry. Mutates and returns the HirProgram.''' for stratum in hir.strata: for v in stratum.base_variants: _plan_variant(v) for v in stratum.recursive_variants: _plan_variant(v) return hir
[docs] class JoinPlannerPass: info = PassInfo( name="JoinPlanning", level=PassLevel.HIR_TRANSFORM, order=200, source_dialect=IRLevel.HIR, target_dialect=IRLevel.HIR, )
[docs] def run(self, hir: HirProgram) -> HirProgram: return plan_joins(hir)