Source code for srdatalog.ir.hir.index

'''HIR Pass 5: Index Selection.

Populates per-stratum `required_indices` / `canonical_index` and program-level
`global_index_map`. Required for downstream MIR lowering, which embeds the
canonical index columns into RebuildIndex / ComputeDelta / MergeIndex etc.

Mirrors src/srdatalog/hir/index_selection.nim.

Algorithm:
  1. First pass: collect every index used by any variant in any stratum
     into `hir.global_index_map`.
  2. Second pass: for each stratum, per SCC member:
     - Recursive: DELTA-version indices first, FULL-version second,
       global fallback, identity fallback.
     - Non-recursive: local patterns, global fallback, identity fallback.
  3. Canonical index: prefer a full-arity index that is also accessed as
     FULL in the stratum (so the maintained FULL-version index is actually
     read; avoids building a dead index).

SCC member iteration is sorted for reproducibility (Python set order is
hash-random); Nim relies on its deterministic string hash, but these
fields are internal to downstream passes and not emitted, so the order
only has to be stable, not byte-matched against Nim.
'''

from __future__ import annotations

from srdatalog.ir.hir.pass_ import IRLevel, PassInfo, PassLevel
from srdatalog.ir.hir.types import (
  HirProgram,
  HirRuleVariant,
  RelationDecl,
  Version,
)


def get_arity(rel_name: str, decls: list[RelationDecl]) -> int:
  '''Return the number of columns declared for relation `rel_name`.

  The first entry of `decls` whose `rel_name` matches decides the answer;
  its arity is the length of its `types` sequence. A relation with no
  matching declaration has arity 0.
  '''
  matching_arities = (len(decl.types) for decl in decls if decl.rel_name == rel_name)
  return next(matching_arities, 0)
def default_index(rel_name: str, decls: list[RelationDecl]) -> list[int]:
  '''Return the identity index `[0, 1, ..., arity - 1]` for `rel_name`.

  The arity is looked up with `get_arity`, so a relation that has no
  declaration in `decls` yields an empty list.
  '''
  arity = get_arity(rel_name, decls)
  return [*range(arity)]
def complete_index(idx: list[int], arity: int) -> list[int]:
  '''Extend a partial index to cover every column of an `arity`-column relation.

  The columns already present in `idx` keep their positions; every column in
  `range(arity)` that `idx` does not mention is appended afterwards in
  ascending order. `idx` itself is left untouched: a new list is returned.
  '''
  missing_cols = [col for col in range(arity) if col not in idx]
  return list(idx) + missing_cols
def canonical_index(
  rel_name: str,
  indices: list[list[int]],
  decls: list[RelationDecl],
  full_indices: dict[str, set[tuple[int, ...]]] | None = None,
) -> list[int]:
  '''Choose the canonical (full-arity) index for `rel_name`.

  Candidates are the entries of `indices` whose length equals the relation's
  arity, considered in the order given. Preference:

    1. the first candidate whose column tuple appears in
       `full_indices[rel_name]` (only when `full_indices` is non-empty and
       has an entry for the relation), so that the maintained FULL index is
       one that joins actually read;
    2. otherwise the first candidate;
    3. otherwise the identity index from `default_index`.

  A fresh list is returned in every case.
  '''
  arity = get_arity(rel_name, decls)
  candidates = [idx for idx in indices if len(idx) == arity]

  if full_indices and rel_name in full_indices:
    read_as_full = full_indices[rel_name]
    for candidate in candidates:
      if tuple(candidate) in read_as_full:
        return list(candidate)

  if candidates:
    return list(candidates[0])
  return default_index(rel_name, decls)
def _collect_indices_from_variants(
  variants: list[HirRuleVariant], version: Version | None = None
) -> dict[str, set[tuple[int, ...]]]:
  '''Group the index columns used by `variants` by relation name.

  Both the access patterns and the negation patterns of every variant are
  scanned. Each pattern contributes `tuple(pattern.index_cols)` to the set
  stored under `pattern.rel_name`. When `version` is not None, only patterns
  whose `version` is that same object are kept; None keeps every pattern.
  '''
  grouped: dict[str, set[tuple[int, ...]]] = {}
  for variant in variants:
    for pattern in (*variant.access_patterns, *variant.negation_patterns):
      if version is not None and pattern.version is not version:
        continue
      grouped.setdefault(pattern.rel_name, set()).add(tuple(pattern.index_cols))
  return grouped


def _append_unique_indices(
  dest: list[list[int]], source_tuples: set[tuple[int, ...]], arity: int
) -> None:
  '''Complete each tuple of `source_tuples` to `arity` columns and add it to `dest`.

  The tuples are visited in sorted order so the result does not depend on
  set iteration order. A completed index that `dest` already contains is
  skipped. `dest` is modified in place; nothing is returned.
  '''
  for cols in sorted(source_tuples):
    completed = complete_index(list(cols), arity)
    if completed in dest:
      continue
    dest.append(completed)
def _build_global_index_map(hir: HirProgram) -> None:
  '''First pass: record every index used by any variant of any stratum.

  Entries are added to `hir.global_index_map[rel_name]` as lists of the
  columns exactly as the patterns state them (not completed to full arity).
  Relation names and column tuples are visited in sorted order within each
  stratum, and an index already recorded for a relation is not added again.
  '''
  for stratum in hir.strata:
    variants = stratum.base_variants + stratum.recursive_variants
    used = _collect_indices_from_variants(variants, version=None)
    for rel_name in sorted(used):
      known = hir.global_index_map.setdefault(rel_name, [])
      for cols in sorted(used[rel_name]):
        idx = list(cols)
        if idx not in known:
          known.append(idx)


def _select_stratum_indices(stratum, hir: HirProgram) -> None:
  '''Second pass for one stratum: fill `required_indices` and `canonical_index`.

  `stratum` is one element of `hir.strata`. For every SCC member, candidate
  indices are gathered from the stratum's own patterns, then from
  `hir.global_index_map` if none were found, then from the identity index.
  '''
  decls = hir.relation_decls

  if stratum.is_recursive:
    # Order matters: DELTA-version indices are listed before FULL-version ones.
    delta_idx = _collect_indices_from_variants(stratum.recursive_variants, Version.DELTA)
    full_idx = _collect_indices_from_variants(stratum.recursive_variants, Version.FULL)
    local_sources = [delta_idx, full_idx]
  else:
    full_idx = None
    local_sources = [_collect_indices_from_variants(stratum.base_variants, version=None)]

  # Sorted so the outcome does not depend on set iteration order.
  for rel_name in sorted(stratum.scc_members):
    arity = get_arity(rel_name, decls)
    indices: list[list[int]] = []

    for source in local_sources:
      if rel_name in source:
        _append_unique_indices(indices, source[rel_name], arity)

    # Fallback 1: global index map (indices used by OTHER strata). The map's
    # insertion order is kept here, so this loop is not replaced by
    # _append_unique_indices, which would sort the entries.
    if not indices and rel_name in hir.global_index_map:
      for idx in hir.global_index_map[rel_name]:
        completed = complete_index(idx, arity)
        if completed not in indices:
          indices.append(completed)

    # Fallback 2: identity index.
    if not indices:
      indices.append(default_index(rel_name, decls))

    stratum.required_indices[rel_name] = indices
    # `full_idx` holds the patterns' raw column tuples while `indices` holds
    # completed ones, so the FULL preference inside canonical_index only
    # matches FULL accesses that already list every column.
    # NOTE(review): confirm this matches the intended (Nim) behavior.
    stratum.canonical_index[rel_name] = canonical_index(
      rel_name,
      indices,
      decls,
      full_indices=full_idx,
    )


def select_indices(hir: HirProgram) -> HirProgram:
  '''Pass 5 entry. Mutates and returns the HirProgram.

  Two passes over `hir.strata`:

    1. Populate `hir.global_index_map` with every index used by any base or
       recursive variant of any stratum.
    2. For each stratum and each of its SCC members, set
       `stratum.required_indices[rel_name]` and
       `stratum.canonical_index[rel_name]`.
       - Recursive stratum: DELTA-version indices of the recursive variants
         first, then FULL-version ones.
       - Non-recursive stratum: all indices of the base variants.
       - If that yields nothing: the relation's entries in the global map,
         and failing that the identity index.

  The first pass must finish before the second starts, because the global
  fallback may refer to indices contributed by a later stratum.

  Returns:
    The same `hir` object that was passed in.
  '''
  _build_global_index_map(hir)
  for stratum in hir.strata:
    _select_stratum_indices(stratum, hir)
  return hir
class IndexSelectionPass:
  '''Pass object that runs `select_indices` on an HIR program.

  `info` carries the pass metadata: an HIR-to-HIR transform named
  "IndexSelection" with order 300.
  '''

  info = PassInfo(
    name="IndexSelection",
    level=PassLevel.HIR_TRANSFORM,
    order=300,
    source_dialect=IRLevel.HIR,
    target_dialect=IRLevel.HIR,
  )

  def run(self, hir: HirProgram) -> HirProgram:
    '''Apply index selection to `hir` and return the result of `select_indices`.'''
    selected = select_indices(hir)
    return selected