# Source code for srdatalog.ir.hir.rule_rewrite

'''Rule-rewrite passes (HIR Pass 0 / Pass 1). Mirror
src/srdatalog/hir/constant_rewriting.nim and head_constant_rewriting.nim.

Semi-join optimization (Pass 1.5, mirroring semi_join_optimization.nim)
lives at the bottom of this module; it only fires on rules that set
`semi_join` and have more than two body clauses.

Counters are per-pass-invocation (reset each call). Nim uses a
`compileTime` counter that persists across a single macro compilation —
for fixtures the tool compiles once, so starting at 0 matches.
'''

from __future__ import annotations

import dataclasses

from srdatalog.dsl import (
  ArgKind,
  Atom,
  ClauseArg,
  Filter,
  Let,
  Negation,
  PlanEntry,
  Rule,
)
from srdatalog.ir.hir.pass_ import IRLevel, PassInfo, PassLevel
from srdatalog.ir.hir.provenance import compiler_gen
from srdatalog.ir.hir.types import RelationDecl


def _atom_has_const(atom: Atom) -> bool:
  '''Return True when at least one argument of `atom` is a constant.'''
  for candidate in atom.args:
    if candidate.kind is ArgKind.CONST:
      return True
  return False


def rewrite_wildcards(
  rules: list[Rule],
  decls: list[RelationDecl],
) -> tuple[list[Rule], list[RelationDecl]]:
  '''Pre-pass: give every `_`-prefixed logic variable a unique `_gen<N>` name.

  Mirrors `parseClauseArg` in src/srdatalog/syntax.nim, which replaces each
  `_` (or `_`-prefixed) identifier with a fresh name from a monotonically
  increasing counter at macro parse time. Every occurrence gets its own
  name, so two `_` slots in the same atom stay independent variables.

  The counter is bumped before it is used, so the first name issued is
  `_gen1` (matching Nim's `wildcardCounter.inc` ordering). Heads are
  renamed before body clauses within each rule. Only atoms and the atom
  inside a negation are touched; every other clause kind is passed through.

  Returns the rewritten rules together with `decls`, unchanged.
  '''
  issued = 0  # how many fresh names have been handed out so far

  def fresh_arg(arg: ClauseArg) -> ClauseArg:
    nonlocal issued
    if arg.kind is not ArgKind.LVAR:
      return arg
    name = arg.var_name
    if not name or not name.startswith("_"):
      return arg
    issued += 1
    return dataclasses.replace(arg, var_name=f"_gen{issued}")

  def fresh_atom(atom: Atom) -> Atom:
    renamed = tuple(fresh_arg(a) for a in atom.args)
    return dataclasses.replace(atom, args=renamed)

  def fresh_clause(clause):
    if isinstance(clause, Atom):
      return fresh_atom(clause)
    if isinstance(clause, Negation):
      return dataclasses.replace(clause, atom=fresh_atom(clause.atom))
    return clause

  out: list[Rule] = []
  for rule in rules:
    # Heads first, then body: the numbering depends on this visiting order.
    heads = tuple(fresh_atom(h) for h in rule.heads)
    body = tuple(fresh_clause(c) for c in rule.body)
    out.append(dataclasses.replace(rule, heads=heads, body=body))
  return out, decls
def rewrite_constants(
  rules: list[Rule],
  decls: list[RelationDecl],
) -> tuple[list[Rule], list[RelationDecl]]:
  '''Pass 0: lift constants out of positive body atoms.

  Every constant argument of a positive body atom is swapped for a fresh
  logic variable `_cN`. Directly after the rewritten atom a single Filter
  clause is inserted whose code is the conjunction of `_cN == <const expr>`
  for all constants of that atom. The counter starts at 0 on each call and
  runs across all rules, which matches Nim's output byte-for-byte.

  Clauses that are not atoms, and atoms without constants, are kept as-is.
  Returns the rewritten rules together with `decls`, unchanged.
  '''
  next_id = 0
  rewritten: list[Rule] = []
  for rule in rules:
    body: list = []
    for clause in rule.body:
      if not (isinstance(clause, Atom) and _atom_has_const(clause)):
        body.append(clause)
        continue

      args: list[ClauseArg] = []
      bound: list[str] = []      # fresh variable names, in argument order
      conjuncts: list[str] = []  # one equality test per lifted constant
      for arg in clause.args:
        if arg.kind is not ArgKind.CONST:
          args.append(arg)
          continue
        name = f"_c{next_id}"
        next_id += 1
        args.append(ClauseArg(kind=ArgKind.LVAR, var_name=name))
        bound.append(name)
        conjuncts.append(f"{name} == {arg.const_cpp_expr}")

      body.append(Atom(rel=clause.rel, args=tuple(args)))
      code = "return " + " && ".join(conjuncts) + ";"
      body.append(Filter(vars=tuple(bound), code=code))
    rewritten.append(dataclasses.replace(rule, body=tuple(body)))
  return rewritten, decls
def rewrite_head_constants(
  rules: list[Rule],
  decls: list[RelationDecl],
) -> tuple[list[Rule], list[RelationDecl]]:
  '''Pass 1: lift constants out of rule heads.

  Each constant head argument becomes a fresh logic variable `_hcN`, and a
  Let clause binding that variable to the constant's C++ expression is
  appended to the end of the rule body. The counter starts at 0 on each
  call and runs across all rules. Rules without head constants are returned
  as the very same objects. Intended to run AFTER body-constant rewriting.

  Returns the rewritten rules together with `decls`, unchanged.
  '''
  next_id = 0
  result: list[Rule] = []
  for rule in rules:
    bindings: list[Let] = []  # one Let per lifted constant, in visiting order
    heads: list[Atom] = []
    for head in rule.heads:
      args: list[ClauseArg] = []
      for arg in head.args:
        if arg.kind is not ArgKind.CONST:
          args.append(arg)
          continue
        name = f"_hc{next_id}"
        next_id += 1
        args.append(ClauseArg(kind=ArgKind.LVAR, var_name=name))
        assert arg.const_cpp_expr is not None
        bindings.append(Let(var_name=name, code=arg.const_cpp_expr, deps=()))
      heads.append(Atom(rel=head.rel, args=tuple(args)))

    # A Let is emitted exactly when a constant was lifted, so an empty
    # `bindings` list means the rule needs no rewriting.
    if not bindings:
      result.append(rule)
      continue
    body = tuple(rule.body) + tuple(bindings)
    result.append(dataclasses.replace(rule, heads=tuple(heads), body=body))
  return result, decls
# -----------------------------------------------------------------------------
# Pipeline pass wrappers
# -----------------------------------------------------------------------------
class WildcardRewritePass:
  '''Pipeline wrapper around `rewrite_wildcards` (rule-rewrite order -1).'''

  info = PassInfo(
    name="WildcardRewrite",
    level=PassLevel.RULE_REWRITE,
    order=-1,
    source_dialect=IRLevel.HIR,
    target_dialect=IRLevel.HIR,
  )

  def run(self, rules, decls):
    '''Apply wildcard renaming and return the `(rules, decls)` pair.'''
    outcome = rewrite_wildcards(rules, decls)
    return outcome
class ConstantRewritePass:
  '''Pipeline wrapper around `rewrite_constants` (rule-rewrite order 0).'''

  info = PassInfo(
    name="ConstantRewrite",
    level=PassLevel.RULE_REWRITE,
    order=0,
    source_dialect=IRLevel.HIR,
    target_dialect=IRLevel.HIR,
  )

  def run(self, rules, decls):
    '''Apply body-constant rewriting and return the `(rules, decls)` pair.'''
    outcome = rewrite_constants(rules, decls)
    return outcome
class HeadConstantRewritePass:
  '''Pipeline wrapper around `rewrite_head_constants` (rule-rewrite order 1).'''

  info = PassInfo(
    name="HeadConstantRewrite",
    level=PassLevel.RULE_REWRITE,
    order=1,
    source_dialect=IRLevel.HIR,
    target_dialect=IRLevel.HIR,
  )

  def run(self, rules, decls):
    '''Apply head-constant rewriting and return the `(rules, decls)` pair.'''
    outcome = rewrite_head_constants(rules, decls)
    return outcome
# -----------------------------------------------------------------------------
# Semi-join optimization (Pass 1.5 in Nim's compileToHir)
# -----------------------------------------------------------------------------


def _clause_vars(clause) -> set[str]:
  '''Collect the logic-variable names of an Atom or Negation.

  Any other clause kind (e.g. Filter, Let) yields an empty set.
  '''
  if isinstance(clause, Atom):
    args = clause.args
  elif isinstance(clause, Negation):
    args = clause.atom.args
  else:
    return set()
  return {a.var_name for a in args if a.kind is ArgKind.LVAR and a.var_name is not None}


def _lvar_name(arg: ClauseArg) -> str:
  '''Name of `arg` if it is a named logic variable, otherwise "".'''
  is_named_lvar = arg.kind is ArgKind.LVAR and arg.var_name is not None
  return arg.var_name if is_named_lvar else ""


def _is_semi_join_candidate(filt, target) -> bool:
  '''Decide whether `filt` can act as a semi-join filter for `target`.

  Both must be positive relation clauses (Atoms), `filt` must mention at
  least one variable, and its variable set must be strictly smaller than,
  and contained in, the variable set of `target`.
  '''
  if not (isinstance(filt, Atom) and isinstance(target, Atom)):
    return False
  filt_vars = _clause_vars(filt)
  target_vars = _clause_vars(target)
  if not filt_vars:
    return False
  if len(filt_vars) >= len(target_vars):
    return False
  return filt_vars.issubset(target_vars)
def _sj_build_generator(rule, target, filt, kept_idx, new_rel_name, decls_map):
  '''Build the declaration and generator rule for one `_SJ_` relation.

  The generator has the shape
  `_SJ_X(v_kept...) :- Target(v0..vN), Filter(<target-position vars>)`.
  Returns `(decl, generator_rule)`.
  '''
  # Semiring / column types are inherited from the target's declaration when
  # one is known; otherwise fall back to "NoProvenance" and no types.
  semiring = "NoProvenance"
  types: list[str] = []
  if target.rel in decls_map:
    orig = decls_map[target.rel]
    semiring = orig.semiring
    for k in kept_idx:
      if k < len(orig.types):
        types.append(orig.types[k])
  decl = RelationDecl(
    rel_name=new_rel_name,
    types=types,
    semiring=semiring,
    is_generated=True,
  )

  # One fresh positional variable per target column.
  fresh = [f"v{k}" for k in range(len(target.args))]
  gen_head = Atom(
    rel=new_rel_name,
    args=tuple(ClauseArg(kind=ArgKind.LVAR, var_name=fresh[k]) for k in kept_idx),
  )
  gen_target = Atom(
    rel=target.rel,
    args=tuple(ClauseArg(kind=ArgKind.LVAR, var_name=name) for name in fresh),
  )

  # Map each filter arg to the fresh var of the first target column that
  # carries the same variable name.
  filter_args: list[ClauseArg] = []
  for f_arg in filt.args:
    fv = _lvar_name(f_arg)
    for ti, t_arg in enumerate(target.args):
      if _lvar_name(t_arg) == fv:
        filter_args.append(ClauseArg(kind=ArgKind.LVAR, var_name=fresh[ti]))
        break
  gen_filter = Atom(rel=filt.rel, args=tuple(filter_args))

  prov = compiler_gen(
    parent_rule=rule.name or "",
    derived_from=target.rel,
    transform_pass="semi_join",
  )
  gen_rule = Rule(
    heads=(gen_head,),
    body=(gen_target, gen_filter),
    name=f"{new_rel_name}_Gen",
    is_generated=True,
    prov=prov,
  )
  return decl, gen_rule


def _sj_translate_plans(plans, idx_map, body_vars):
  '''Re-index plan entries (delta / clause_order / var_order) after a rewrite.

  `idx_map` maps old body indices to new ones; negative values mark removed
  clauses. A plan whose non-negative delta points at an unmapped or removed
  clause is dropped. `clause_order` loses removed clauses and `var_order`
  loses variables that no longer occur in the rewritten body.
  '''
  translated: list[PlanEntry] = []
  for plan in plans:
    new_delta = plan.delta
    if plan.delta >= 0:
      mapped = idx_map.get(plan.delta, -1)
      if mapped < 0:
        continue
      new_delta = mapped
    new_clause_order = tuple(
      idx_map[k] for k in plan.clause_order if k in idx_map and idx_map[k] >= 0
    )
    new_var_order = tuple(v for v in plan.var_order if v in body_vars)
    translated.append(
      dataclasses.replace(
        plan,
        delta=new_delta,
        clause_order=new_clause_order,
        var_order=new_var_order,
      )
    )
  return translated


def optimize_semi_joins(
  rules: list[Rule],
  decls: list[RelationDecl],
) -> tuple[list[Rule], list[RelationDecl]]:
  '''Single semi-join optimization pass.

  A rule is considered only if it has `semi_join` set and more than two
  body clauses. For each candidate target/filter pair, an
  `_SJ_<Target>_<Filter>_<keptIndices>` relation is synthesised whose
  generator rule is `Target ⋈ Filter`; in the original rule the target is
  replaced by the new relation and the filter clause is dropped.

  Each target is optimised at most once per call. A clause that has already
  been consumed as a filter is never used as a target, and a clause that
  has already been rewritten is never consumed as a filter: either
  combination would discard a rewrite while still removing its filter,
  silently dropping a join constraint. Such chains are simply left in place
  for a later invocation.

  Mirrors src/srdatalog/hir/semi_join_optimization.nim. Fixed-point
  iteration lives in SemiJoinPass.run (Nim does it in compileToHir).

  Returns `(generated_rules + rewritten_rules, decls + generated_decls)`.
  '''
  new_rules: list[Rule] = []
  generated_rules: list[Rule] = []
  generated_decls: list[RelationDecl] = []
  generated_cache: set[str] = set()  # _SJ_ names already emitted in this call
  decls_map = {d.rel_name: d for d in decls}

  for rule in rules:
    if not rule.semi_join or len(rule.body) <= 2:
      new_rules.append(rule)
      continue

    body = list(rule.body)
    replaced: set[int] = set()  # body indices removed
    rewrites: dict[int, tuple[Atom, int]] = {}  # target_idx -> (new_atom, filter_idx)

    # Variables used by ANY head: a column feeding a head must be kept.
    head_vars = {_lvar_name(a) for head in rule.heads for a in head.args}
    # Per-clause variable sets, computed once instead of inside the loops.
    # NOTE(review): Filter/Let clauses contribute no variables here, so a
    # variable referenced only by a Filter/Let is not treated as shared —
    # confirm this matches the Nim implementation.
    clause_vars = [_clause_vars(c) for c in body]

    # Pass 1: find each target -> filter opportunity.
    for i, target in enumerate(body):
      if not isinstance(target, Atom):
        continue
      if target.rel.startswith("_SJ_"):
        continue  # never double-optimise
      if i in replaced:
        continue  # already consumed as a filter; it will not survive
      for j, filt in enumerate(body):
        if i == j or not _is_semi_join_candidate(filt, target):
          continue
        if j in rewrites:
          continue  # removing a rewritten target would lose its own filter
        assert isinstance(filt, Atom)  # _is_semi_join_candidate guarantees this
        filter_vars = clause_vars[j]

        # Vars shared with a head or with other body clauses.
        shared: set[str] = {v for v in filter_vars if v in head_vars}
        for k, other_vars in enumerate(clause_vars):
          if k in (i, j):
            continue
          shared |= filter_vars & other_vars

        # Keep target columns whose var is NOT a filter-only var.
        kept_idx: list[int] = []
        for arg_idx, arg in enumerate(target.args):
          v = _lvar_name(arg)
          if v not in filter_vars or v in shared:
            kept_idx.append(arg_idx)

        suffix = "".join(f"_{k}" for k in kept_idx)
        new_rel_name = f"_SJ_{target.rel}_{filt.rel}{suffix}"

        # Generate decl + rule only once for this relation name.
        if new_rel_name not in generated_cache:
          generated_cache.add(new_rel_name)
          gen_decl, gen_rule = _sj_build_generator(
            rule, target, filt, kept_idx, new_rel_name, decls_map
          )
          generated_decls.append(gen_decl)
          generated_rules.append(gen_rule)

        # Build the replacement clause for the original rule.
        prov = compiler_gen(
          parent_rule=rule.name or "",
          derived_from=target.rel,
          transform_pass="semi_join",
        )
        new_atom = Atom(
          rel=new_rel_name,
          args=tuple(target.args[k] for k in kept_idx),
          prov=prov,
        )
        rewrites[i] = (new_atom, j)
        replaced.add(j)
        break  # one opt per target

    # Pass 2: build the rewritten body + clause-index mapping.
    new_body: list = []
    idx_map: dict[int, int] = {}
    filter_to_target = {fi: ti for ti, (_, fi) in rewrites.items()}
    next_idx = 0
    for i, clause in enumerate(body):
      if i in replaced:
        # Removed filter: provisionally -2. It is resolved to the _SJ_
        # clause's index below only when its target appears later in the
        # body; otherwise it stays negative and counts as removed.
        idx_map[i] = -2 if i in filter_to_target else -1
        continue
      if i in rewrites:
        idx_map[i] = next_idx
        fi = rewrites[i][1]
        if fi in idx_map and idx_map[fi] == -2:
          idx_map[fi] = next_idx
        next_idx += 1
        new_body.append(rewrites[i][0])
      else:
        idx_map[i] = next_idx
        next_idx += 1
        new_body.append(clause)

    # Pass 3: translate plans (delta / clause_order / var_order).
    body_vars: set[str] = set()
    for c in new_body:
      body_vars |= _clause_vars(c)
    translated_plans = _sj_translate_plans(rule.plans, idx_map, body_vars)

    new_rules.append(
      dataclasses.replace(
        rule,
        body=tuple(new_body),
        plans=tuple(translated_plans),
      )
    )

  # Generated rules go FIRST so stratification schedules them upstream of
  # their consumers. Generated decls append to the existing list.
  return generated_rules + new_rules, list(decls) + generated_decls
class SemiJoinPass:
  '''Drives `optimize_semi_joins` until it stops producing new rules.

  Nim performs this outer loop in compileToHir; here it lives inside the
  pass so it lines up with the other rule-rewrite passes.
  '''

  info = PassInfo(
    name="SemiJoinOpt",
    level=PassLevel.RULE_REWRITE,
    order=2,
    source_dialect=IRLevel.HIR,
    target_dialect=IRLevel.HIR,
  )

  def run(self, rules, decls):
    '''Re-run the optimization until the rule count is stable.

    Returns the `(rules, decls)` pair from the last invocation, i.e. the
    first one that left the number of rules unchanged.
    '''
    before = len(rules)
    rules, decls = optimize_semi_joins(rules, decls)
    while len(rules) != before:
      before = len(rules)
      rules, decls = optimize_semi_joins(rules, decls)
    return rules, decls