Source code for srdatalog.ir.hir.semi_naive
'''HIR Pass 3: Semi-Naive Variant Generation.
Populates `base_variants` / `recursive_variants` on each HirStratum.
For each stratum:
- Non-recursive: one HirRuleVariant per rule (base, delta_idx=-1).
- Recursive: for each rule, one DELTA variant per body clause whose
relation is an SCC member. The variant pins `delta_idx` to that clause
index and sets `clause_versions[idx] = DELTA` (others FULL).
Mirrors `generateVariants` in src/srdatalog/hir/semi_naive.nim.
Negations are NOT delta candidates (only positive occurrences of SCC
members get incrementalized).
'''
from __future__ import annotations
from srdatalog.dsl import Atom, Rule
from srdatalog.ir.hir.pass_ import IRLevel, PassInfo, PassLevel
from srdatalog.ir.hir.types import HirProgram, HirRuleVariant, Version
[docs]
def find_scc_clause_indices(rule: Rule, scc_members: set[str]) -> list[int]:
'''Indices of positive body clauses (Atoms) whose relation is in scc_members.'''
return [i for i, b in enumerate(rule.body) if isinstance(b, Atom) and b.rel in scc_members]
[docs]
def create_base_variant(rule: Rule) -> HirRuleVariant:
return HirRuleVariant(
original_rule=rule,
delta_idx=-1,
clause_versions=[Version.FULL] * len(rule.body),
)
[docs]
def create_delta_variant(rule: Rule, delta_idx: int) -> HirRuleVariant:
cvs = [Version.FULL] * len(rule.body)
cvs[delta_idx] = Version.DELTA
return HirRuleVariant(
original_rule=rule,
delta_idx=delta_idx,
clause_versions=cvs,
)
[docs]
def generate_variants(hir: HirProgram) -> HirProgram:
'''Populate base_variants and recursive_variants per stratum. Mutates
and returns the same HirProgram.
'''
for stratum in hir.strata:
scc_members = stratum.scc_members
if stratum.is_recursive:
for rule in stratum.stratum_rules:
for idx in find_scc_clause_indices(rule, scc_members):
stratum.recursive_variants.append(create_delta_variant(rule, idx))
else:
for rule in stratum.stratum_rules:
stratum.base_variants.append(create_base_variant(rule))
return hir
[docs]
class SemiNaiveVariantPass:
'''Pipeline wrapper. Runs right after stratify (order=100) so downstream
HIR transforms that care about variants can assume they exist.
'''
info = PassInfo(
name="SemiNaiveVariants",
level=PassLevel.HIR_TRANSFORM,
order=100,
source_dialect=IRLevel.HIR,
target_dialect=IRLevel.HIR,
)
[docs]
def run(self, hir: HirProgram) -> HirProgram:
return generate_variants(hir)