# Source code for srdatalog.ir.hir.split

'''Split-rule HIR transforms: Pass 4.5 (temp relation decl synthesis) and
Pass 5.5 (temp index registration). Mirror the inline blocks in
src/srdatalog/hir/hir.nim's compileToHir.

Both passes run as regular HIR transforms, positioned around
IndexSelectionPass:
  JoinPlannerPass        (order 200)
    TempRelSynthesisPass    (order 250)  <-- this file
  IndexSelectionPass     (order 300)
    TempIndexRegistrationPass (order 350)  <-- this file

Split metadata (split_at, temp_vars, temp_rel_name) is populated by
JoinPlannerPass from hir_plan._plan_variant.
'''

from __future__ import annotations

from srdatalog.dsl import ArgKind, Atom
from srdatalog.ir.hir.pass_ import IRLevel, PassInfo, PassLevel
from srdatalog.ir.hir.types import HirProgram, RelationDecl


def _infer_temp_types(variant, decls: list[RelationDecl]) -> list[str]:
  '''Pick each temp var's type from the first above-split positive clause
  that contains it (source decl lookup). Falls back to "int". Mirrors
  the inference loop in compileToHir's Pass 4.5.
  '''
  rule = variant.original_rule
  decls_map = {d.rel_name: d for d in decls}
  types: list[str] = []
  for tv in variant.temp_vars:
    found = False
    for clause_idx in range(variant.split_at):
      clause = rule.body[clause_idx]
      if not isinstance(clause, Atom):
        continue
      for arg_idx, arg in enumerate(clause.args):
        if arg.kind is ArgKind.LVAR and arg.var_name == tv:
          if clause.rel in decls_map:
            orig = decls_map[clause.rel]
            if arg_idx < len(orig.types):
              types.append(orig.types[arg_idx])
              found = True
          break
      if found:
        break
    if not found:
      types.append("int")
  return types


class TempRelSynthesisPass:
  '''Pass 4.5: declare a temp relation for every split variant.

  Every base or recursive variant with a non-negative `split_at` and a
  non-empty `temp_rel_name` gets a RelationDecl appended to the HirProgram,
  unless a relation of that name is already declared. The new decl is
  flagged `is_generated=True, is_temp=True`, uses the "NoProvenance"
  semiring, and takes its column types from `_infer_temp_types`.
  '''

  info = PassInfo(
    name="TempRelSynthesis",
    level=PassLevel.HIR_TRANSFORM,
    order=250,
    source_dialect=IRLevel.HIR,
    target_dialect=IRLevel.HIR,
  )

  def run(self, hir: HirProgram) -> HirProgram:
    '''Append the missing temp RelationDecls to `hir` in place; return `hir`.'''
    # Names already declared, extended as temp decls are added so that a
    # temp relation shared by several variants is declared only once.
    declared = {decl.rel_name for decl in hir.relation_decls}
    for stratum in hir.strata:
      for variant in [*stratum.base_variants, *stratum.recursive_variants]:
        if (
          variant.split_at < 0
          or not variant.temp_rel_name
          or variant.temp_rel_name in declared
        ):
          continue
        column_types = _infer_temp_types(variant, hir.relation_decls)
        temp_decl = RelationDecl(
          rel_name=variant.temp_rel_name,
          types=column_types,
          semiring="NoProvenance",
          is_generated=True,
          is_temp=True,
        )
        hir.relation_decls.append(temp_decl)
        declared.add(variant.temp_rel_name)
    return hir


class TempIndexRegistrationPass:
  '''Pass 5.5: register the identity index for split variants' temp rels.

  For every base or recursive variant with a non-negative `split_at` and a
  non-empty `temp_rel_name`, the index [0..arity-1] (arity being the number
  of temp vars) is recorded in the enclosing stratum's `required_indices`
  and `canonical_index`. Runs after the main IndexSelectionPass so it only
  affects temp rels (which weren't seen by the selection pass because
  they're head-only here).
  '''

  info = PassInfo(
    name="TempIndexRegistration",
    level=PassLevel.HIR_TRANSFORM,
    order=350,
    source_dialect=IRLevel.HIR,
    target_dialect=IRLevel.HIR,
  )

  def run(self, hir: HirProgram) -> HirProgram:
    '''Fill in temp-relation index entries on each stratum; return `hir`.'''
    for stratum in hir.strata:
      for variant in [*stratum.base_variants, *stratum.recursive_variants]:
        if variant.split_at >= 0 and variant.temp_rel_name:
          identity = list(range(len(variant.temp_vars)))
          # NOTE(review): both assignments are treated as guarded by the
          # "not yet registered" check, so an existing entry is left
          # untouched — confirm that canonical_index is not meant to be
          # overwritten unconditionally.
          if variant.temp_rel_name not in stratum.required_indices:
            # The same list object is stored in both maps.
            stratum.required_indices[variant.temp_rel_name] = [identity]
            stratum.canonical_index[variant.temp_rel_name] = identity
    return hir