'''HIR Pass 1: Stratification.
Input: Program (DSL-level: relations + rules)
Output: HirProgram with strata populated (scc_members, is_recursive, stratum_rules)
Mirrors src/srdatalog/hir/stratification.nim. Key behavior: for a recursive SCC
that has some rules with no SCC dependency in their body (true base case) and
some with SCC dependency (recursive step), we emit TWO separate strata -- base
first, then recursive -- so the base runs once before the fixpoint loop.
The subsequent fusion pass merges consecutive non-recursive strata that have no
inter-dependency, enabling parallel execution at the MIR/codegen level.
NOTE: Python's set iteration order depends on the hash-random seed across runs.
To keep SCC ordering reproducible, we sort every set iteration that feeds into
downstream ordering (dependency edges, SCC membership iteration).
'''
from __future__ import annotations
from srdatalog.dsl import Agg, Atom, Negation, Rule
from srdatalog.ir.hir.pass_ import IRLevel, PassInfo, PassLevel
from srdatalog.ir.hir.types import HirProgram, HirStratum, RelationDecl
def _body_relations(rule: Rule) -> set[str]:
    '''Collect the names of all relations mentioned in ``rule``'s body.

    An Atom or Agg clause contributes its own ``rel``; a Negation contributes
    the ``rel`` of the atom it wraps. Every other clause kind (Filter / Let
    are inline predicates / bindings) references no relation and is ignored.
    '''
    referenced: set[str] = set()
    for clause in rule.body:
        if isinstance(clause, Negation):
            # A negated clause still reads the relation of its inner atom.
            referenced.add(clause.atom.rel)
            continue
        if isinstance(clause, (Atom, Agg)):
            referenced.add(clause.rel)
    return referenced
def _build_dep_graph(rules: list[Rule]) -> dict[str, set[str]]:
    '''Build the IDB dependency graph: head relation -> body relations.

    Only IDB relations (those that appear as the head of at least one rule)
    become nodes or edge targets; body relations that are never a head are
    dropped. A rule with several heads adds its body edges to every head.
    '''
    head_rels: set[str] = {head.rel for rule in rules for head in rule.heads}

    deps: dict[str, set[str]] = {}
    for rel in head_rels:
        deps[rel] = set()

    for rule in rules:
        # Keep only the body relations that are themselves derived by a rule.
        idb_deps = _body_relations(rule) & head_rels
        for head in rule.heads:
            deps[head.rel].update(idb_deps)
    return deps
def _compute_sccs(rules: list[Rule], graph: dict[str, set[str]]) -> list[set[str]]:
    '''Tarjan's SCC. Returns SCCs in reverse topological order (bottom-up).

    Starting nodes are taken in rule-definition order (each head relation of
    each rule, skipping names that are not graph nodes or were already
    visited), and neighbors are visited in sorted order, so the output is
    reproducible across runs.

    The depth-first search uses an explicit work stack instead of recursion,
    so long dependency chains cannot exceed the interpreter's recursion
    limit. Visit order and the resulting SCC order are the same as in the
    textbook recursive formulation.

    Args:
        rules: Rules whose head relations seed the traversal.
        graph: Adjacency map, relation name -> set of relation names it
            depends on. A neighbor missing from the map is treated as a
            node with no outgoing edges.

    Returns:
        One set of relation names per SCC; an SCC is emitted only after every
        SCC reachable from it has been emitted.
    '''
    index: dict[str, int] = {}
    lowlink: dict[str, int] = {}
    on_stack: set[str] = set()
    stack: list[str] = []
    sccs: list[set[str]] = []

    def open_node(node: str):
        # Number the node, push it on the Tarjan stack and return its DFS
        # frame: the node plus an iterator over its sorted neighbors.
        order = len(index)  # nodes are numbered 0, 1, 2, ... in visit order
        index[node] = order
        lowlink[node] = order
        stack.append(node)
        on_stack.add(node)
        return node, iter(sorted(graph.get(node, ())))

    for rule in rules:
        for head in rule.heads:
            start = head.rel
            if start not in graph or start in index:
                continue
            work = [open_node(start)]
            while work:
                v, neighbors = work[-1]
                for w in neighbors:
                    if w not in index:
                        # Descend; v's iterator resumes after w is finished.
                        work.append(open_node(w))
                        break
                    if w in on_stack:
                        lowlink[v] = min(lowlink[v], index[w])
                else:
                    # All neighbors of v are processed: v is finished.
                    work.pop()
                    if lowlink[v] == index[v]:
                        # v is the root of an SCC: pop its members.
                        component: set[str] = set()
                        while True:
                            member = stack.pop()
                            on_stack.discard(member)
                            component.add(member)
                            if member == v:
                                break
                        sccs.append(component)
                    if work:
                        # Propagate the finished child's lowlink to its parent.
                        parent = work[-1][0]
                        lowlink[parent] = min(lowlink[parent], lowlink[v])
    return sccs
def _merge_sccs_for_multi_head(sccs: list[set[str]], rules: list[Rule]) -> list[set[str]]:
    '''Fuse SCCs whose relations are heads of the same multi-head rule.

    Mirrors mergeSCCsForMultiHeadRules in src/srdatalog/hir/stratification.nim.
    A rule with more than one head must emit all of its heads from a single
    pipeline, so the SCCs of all its head relations are unioned into one.

    The result keeps the input order: each merged group sits at the position
    where its first constituent SCC appeared. Returned sets are new objects.
    '''
    # Relation name -> position of the SCC that contains it.
    owner: dict[str, int] = {rel: pos for pos, members in enumerate(sccs) for rel in members}

    # Union-find over SCC positions.
    leader = list(range(len(sccs)))

    def root_of(node: int) -> int:
        while leader[node] != node:
            leader[node] = leader[leader[node]]  # path halving
            node = leader[node]
        return node

    for rule in rules:
        if len(rule.heads) <= 1:
            continue
        slots = [owner[head.rel] for head in rule.heads if head.rel in owner]
        for other in slots[1:]:
            first_root, other_root = root_of(slots[0]), root_of(other)
            if first_root != other_root:
                leader[first_root] = other_root

    # Dict insertion order records the first occurrence of each root.
    groups: dict[int, set[str]] = {}
    for pos, members in enumerate(sccs):
        groups.setdefault(root_of(pos), set()).update(members)
    return list(groups.values())
def _is_recursive_scc(scc: set[str], graph: dict[str, set[str]]) -> bool:
    '''Tell whether ``scc`` is recursive.

    An SCC with more than one member is always reported as recursive. A
    singleton is recursive only when its relation lists itself among its own
    dependencies in ``graph`` (a self-loop). An empty SCC is not recursive.
    '''
    if len(scc) > 1:
        return True
    # At most one member is left to inspect here: look for a self-edge.
    return any(rel in graph.get(rel, set()) for rel in scc)
def _rules_for_scc(rules: list[Rule], scc: set[str]) -> list[Rule]:
    '''Return, in original order, the rules with at least one head relation in ``scc``.'''
    selected: list[Rule] = []
    for rule in rules:
        for head in rule.heads:
            if head.rel in scc:
                selected.append(rule)
                break  # one matching head is enough; never add a rule twice
    return selected
def _split_base_rec(scc_rules: list[Rule], scc: set[str]) -> tuple[list[Rule], list[Rule]]:
    '''Split an SCC's rules into ``(base, recursive)``.

    A rule is recursive when its body mentions at least one member of ``scc``;
    otherwise it is a base rule, which has no SCC dependency and can run once
    before the fixpoint loop. Relative rule order is kept inside each list.
    '''
    base: list[Rule] = []
    rec: list[Rule] = []
    for rule in scc_rules:
        bucket = rec if _body_relations(rule) & scc else base
        bucket.append(rule)
    return base, rec
def _stratum_depends_on(s: HirStratum, produced: set[str]) -> bool:
    '''Return True when any rule of stratum ``s`` reads a relation in ``produced``.'''
    return any(_body_relations(rule) & produced for rule in s.stratum_rules)
def _fuse_independent_strata(strata: list[HirStratum]) -> list[HirStratum]:
    '''Merge runs of consecutive non-recursive strata that do not depend on each other.

    A stratum can take part in fusion only if it is non-recursive and shares
    no member with any recursive stratum. The second condition keeps the BASE
    stratum of a recursive SCC unfused, so it stays pinned right before its
    recursive sibling and the evaluator runs them in sequence.

    Fusion is done in place on the first stratum of each run: it absorbs the
    members, rules and hooks of the strata that follow it, and its
    ``is_generated`` flag is and-ed with theirs. A run ends at the first
    stratum that is not fusable or whose rules read a relation produced by
    the run so far. Input of length 0 or 1 is returned unchanged.
    '''
    if len(strata) <= 1:
        return strata

    # Every relation that belongs to some recursive stratum.
    pinned: set[str] = set()
    for stratum in strata:
        if stratum.is_recursive:
            pinned.update(stratum.scc_members)

    def can_fuse(stratum: HirStratum) -> bool:
        if stratum.is_recursive:
            return False
        return not (stratum.scc_members & pinned)

    result: list[HirStratum] = []
    total = len(strata)
    pos = 0
    while pos < total:
        head = strata[pos]
        pos += 1
        if not can_fuse(head):
            result.append(head)
            continue

        # Greedily absorb the following strata into `head`.
        produced = set(head.scc_members)
        while pos < total:
            candidate = strata[pos]
            if not can_fuse(candidate) or _stratum_depends_on(candidate, produced):
                break
            head.scc_members.update(candidate.scc_members)
            head.stratum_rules.extend(candidate.stratum_rules)
            if candidate.before_hook:
                head.before_hook += candidate.before_hook
            if candidate.after_hook:
                head.after_hook += candidate.after_hook
            head.is_generated = head.is_generated and candidate.is_generated
            produced = set(head.scc_members)
            pos += 1
        result.append(head)
    return result
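# -----------------------------------------------------------------------------
# Entry point: (rules, decls) -> HirProgram.
# -----------------------------------------------------------------------------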
def stratify(rules: list[Rule], decls: list[RelationDecl]) -> HirProgram:
    '''HIR Pass 1 entry point: turn (rules, decls) into a stratified HirProgram.

    Mirrors stratify() in src/srdatalog/hir/stratification.nim. Expects the
    rules as they are after any rule-rewrite passes. This is the fixed entry
    of the HIR pipeline, the point where (rules, decls) becomes HirProgram;
    see `hir_pass.Pipeline.compile_to_hir`.

    Steps: build the IDB dependency graph, compute SCCs, merge SCCs joined by
    multi-head rules, emit one or two strata per SCC, then fuse independent
    non-recursive strata. A recursive SCC yields a non-recursive base stratum
    (rules whose body has no SCC member) followed by a recursive stratum;
    either one is omitted when it would have no rules.
    '''
    dep_graph = _build_dep_graph(rules)
    components = _merge_sccs_for_multi_head(_compute_sccs(rules, dep_graph), rules)

    strata: list[HirStratum] = []
    for members in components:
        member_rules = _rules_for_scc(rules, members)
        if not member_rules:
            continue

        if not _is_recursive_scc(members, dep_graph):
            strata.append(
                HirStratum(scc_members=set(members), is_recursive=False, stratum_rules=member_rules)
            )
            continue

        # Recursive SCC: base rules first, so they run once before the loop.
        base_rules, rec_rules = _split_base_rec(member_rules, members)
        for recursive, group in ((False, base_rules), (True, rec_rules)):
            if group:
                strata.append(
                    HirStratum(scc_members=set(members), is_recursive=recursive, stratum_rules=group)
                )

    return HirProgram(strata=_fuse_independent_strata(strata), relation_decls=decls)
# -----------------------------------------------------------------------------
# Pass wrapper — lets the Pipeline treat stratification uniformly even though
# its signature is distinct (it's the HIR entry, not a transform).
# -----------------------------------------------------------------------------
class StratificationPass:
    '''Pass wrapper around `stratify`.

    Carries the pass metadata in ``info`` and exposes stratification through
    a ``run`` method. Unlike a transform, ``run`` takes (rules, decls) and
    returns a new HirProgram, because stratification is the HIR entry.
    '''

    # Pass metadata: registered as the first (order=0) HIR-level pass.
    info = PassInfo(
        name="Stratification",
        level=PassLevel.HIR_TRANSFORM,
        order=0,
        source_dialect=IRLevel.HIR,
        target_dialect=IRLevel.HIR,
    )

    def run(self, rules: list[Rule], decls: list[RelationDecl]) -> HirProgram:
        '''Stratify ``rules`` and return the resulting HirProgram.'''
        return stratify(rules, decls)