'''Pass kinds, registration decorators, and driver.
A pass is a transformation on the IR. There are two flavors:

  Lowering — matches an op of one dialect and produces ops of another
             (or of a target dialect). Each Lowering carries the op
             class it matches, an `apply` callable that performs the
             transformation, declared `consumes` / `produces` dialect
             names for dependency validation, and a name for diagnostics.

  Rewrite  — matches an op and produces ops of the *same* dialect.
             Used for internal optimizations like the IIR sorted-array
             count-as-product or hint-narrowing rules.
Per the project memory note (`feedback_decorator_registries.md`),
registration uses Triton-style decorators rather than imperative
`register_X(OpClass, fn)` calls or class-based dispatch:

    from srdatalog.ir.core import lowering, rewrite, verifier
    from srdatalog.ir.dialects.relation.sorted_array import DIALECT

    @lowering(DIALECT, mir.ExecutePipeline,
              consumes=('mir',), produces=('iir.cf', 'relation.sorted_array'))
    def _lower_execute_pipeline(ep, ctx):
        ...

    @rewrite(DIALECT, SaPrefCoop)
    def _hint_introduction(op, ctx):
        ...

    @verifier(DIALECT)
    def _verify_sorted_array(prog):
        return []  # list of VerificationError

The decorators mutate `dialect.lowerings` / `dialect.rewrites` /
`dialect.verifier` in place. Decoration is the only registration
path — there is no imperative API exposed.
The `PassDriver` walks `compiler.dialects` to validate dependencies
(every Lowering's `consumes` must be in the registered dialect set;
otherwise raise `PassDependencyError`). Actual op-level dispatch is
left to callers for now (production code calls lowering functions
directly); the registry exists so external consumers can introspect
"who lowers what" and so future stages can add a tree-walking dispatcher.
See docs/ir_lowering_semantics.md, section 21.
'''
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from srdatalog.ir.core.dialect import Compiler, Dialect
@dataclass
class Lowering:
    '''Registry record for a rule that lowers ops from one dialect to another.

    Instances are normally built by the `lowering` decorator rather than by
    hand. The record is plain data; it does not apply itself.

    Fields:
      matches  — the Op subclass this rule applies to (e.g. MirColumnJoin).
      apply    — callable invoked as apply(op, context); returns the
                 replacement IR (a single op or a list of ops, as dictated
                 by the target dialect's contract).
      name     — short label for diagnostics and pass tracing; '' by default.
      consumes — names of the dialects whose ops this rule reads.
                 PassDriver.validate_dependencies checks that each of them
                 is registered with the Compiler.
      produces — names of the dialects whose ops this rule emits. Intended
                 for ordering multi-pass pipelines (a producer of dialect D
                 runs before any consumer of D); the current PassDriver does
                 not read this field.
    '''

    matches: type
    apply: Callable[[Any, Any], Any]
    name: str = ''
    # Dependency declarations; both start out as an empty tuple.
    consumes: tuple[str, ...] = field(default_factory=tuple)
    produces: tuple[str, ...] = field(default_factory=tuple)
@dataclass
class Rewrite:
    '''Registry record for a rule that rewrites ops within one dialect.

    Field-for-field the same as Lowering. By convention `apply` emits ops
    of the same dialect that `matches` belongs to, so `consumes` /
    `produces` usually just name that dialect. They are still checked by
    PassDriver.validate_dependencies, which matters when a rewrite reads
    ops from a sibling dialect.
    '''

    matches: type
    apply: Callable[[Any, Any], Any]
    name: str = ''
    # Dependency declarations; both start out as an empty tuple.
    consumes: tuple[str, ...] = field(default_factory=tuple)
    produces: tuple[str, ...] = field(default_factory=tuple)
# -----------------------------------------------------------------------------
# Decorator-style registration
# -----------------------------------------------------------------------------
def lowering(
    dialect: Dialect,
    matches: type,
    *,
    consumes: tuple[str, ...] | str = (),
    produces: tuple[str, ...] | str = (),
    name: str = '',
) -> Callable[[Callable[[Any, Any], Any]], Callable[[Any, Any], Any]]:
    '''Decorator: wrap fn as a Lowering and register it on dialect.lowerings.

    Usage:

        @lowering(MY_DIALECT, mir.SomeOp,
                  consumes=('mir',),
                  produces=('iir.cf', 'relation.sorted_array'))
        def _lower_some_op(op, ctx):
            return ...

    Args:
      dialect  — dialect whose `lowerings` list receives the new rule.
      matches  — the Op subclass the rule matches.
      consumes — dialect names the rule reads. A bare string is treated as
                 one name: ``consumes='mir'`` (or the easy typo
                 ``consumes=('mir')``, which is just a string) becomes
                 ``('mir',)`` instead of ``('m', 'i', 'r')``.
      produces — dialect names the rule emits; normalized the same way.
      name     — diagnostic name; falls back to the function's __name__.

    Returns the original function unchanged (so other decorators can stack).
    '''
    # tuple('mir') would split a bare string into characters, producing
    # bogus one-letter dialect dependencies; wrap a str as a single name.
    consumes_names = (consumes,) if isinstance(consumes, str) else tuple(consumes)
    produces_names = (produces,) if isinstance(produces, str) else tuple(produces)

    def _wrap(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
        dialect.lowerings.append(
            Lowering(
                matches=matches,
                apply=fn,
                name=name or fn.__name__,
                consumes=consumes_names,
                produces=produces_names,
            )
        )
        return fn

    return _wrap
def rewrite(
    dialect: Dialect,
    matches: type,
    *,
    consumes: tuple[str, ...] | str = (),
    produces: tuple[str, ...] | str = (),
    name: str = '',
) -> Callable[[Callable[[Any, Any], Any]], Callable[[Any, Any], Any]]:
    '''Decorator: wrap fn as a Rewrite and register it on dialect.rewrites.

    Usage:

        @rewrite(MY_DIALECT, SomeOp)
        def _hint_introduction(op, ctx):
            return ...

    Args:
      dialect  — dialect whose `rewrites` list receives the new rule.
      matches  — the Op subclass the rule matches.
      consumes — dialect names the rule reads. A bare string is treated as
                 one name (``consumes='mir'`` means ``('mir',)``), so a
                 missing trailing comma in ``('mir')`` does not split the
                 name into characters.
      produces — dialect names the rule emits; normalized the same way.
      name     — diagnostic name; falls back to the function's __name__.

    Returns the original function unchanged (so other decorators can stack).
    '''
    # tuple('mir') would split a bare string into characters, producing
    # bogus one-letter dialect dependencies; wrap a str as a single name.
    consumes_names = (consumes,) if isinstance(consumes, str) else tuple(consumes)
    produces_names = (produces,) if isinstance(produces, str) else tuple(produces)

    def _wrap(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
        dialect.rewrites.append(
            Rewrite(
                matches=matches,
                apply=fn,
                name=name or fn.__name__,
                consumes=consumes_names,
                produces=produces_names,
            )
        )
        return fn

    return _wrap
def verifier(dialect: Dialect) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
    '''Decorator factory: install the decorated function as `dialect.verifier`.

    Usage:

        @verifier(MY_DIALECT)
        def _verify(prog):
            return []  # list of VerificationError; an empty list means OK

    A dialect holds at most one verifier. The check runs when the returned
    decorator is applied, not when `verifier(...)` itself is called:
    applying it to a dialect whose `verifier` is already set raises
    ValueError. The decorated function is returned unchanged.
    '''

    def _install(check: Callable[[Any], Any]) -> Callable[[Any], Any]:
        already = dialect.verifier
        if already is None:
            dialect.verifier = check
            return check
        raise ValueError(f'verifier already registered on {dialect.name!r}')

    return _install
# -----------------------------------------------------------------------------
# Pass dependency error
# -----------------------------------------------------------------------------
class PassDependencyError(Exception):
    '''A registered pass declared `consumes=(D, ...)` but dialect D is
    not registered with the Compiler. Raised by
    PassDriver.validate_dependencies (and therefore by PassDriver.run,
    which calls it first) before any pass executes.

    The recommended posture (per docs/stage3a_execution_plan.md §9) is
    loud failure over silent fallback: a pipeline opting out of a
    dialect's passes does so by not registering those passes, not by
    not registering the dialect.

    Attributes:
      pass_name       — name of the Lowering / Rewrite with the unmet dependency.
      missing_dialect — the consumed dialect name that is not registered.
      in_dialect      — name of the dialect the pass is registered on.
    '''

    def __init__(self, pass_name: str, missing_dialect: str, in_dialect: str) -> None:
        self.pass_name = pass_name
        self.missing_dialect = missing_dialect
        self.in_dialect = in_dialect
        super().__init__(
            f'pass {pass_name!r} in dialect {in_dialect!r} declares '
            f'consumes={missing_dialect!r}, but that dialect is not registered '
            f'with the Compiler. Either register the dialect or unregister the pass.'
        )

    def __reduce__(self) -> tuple[Any, ...]:
        # BaseException's default __reduce__ rebuilds the error as
        # cls(*self.args) == cls(message), which does not fit this
        # three-argument __init__, so pickle (e.g. across multiprocessing
        # workers) and copy.copy would fail with TypeError. Rebuild from the
        # original constructor arguments and carry over instance attributes.
        return (
            type(self),
            (self.pass_name, self.missing_dialect, self.in_dialect),
            self.__dict__,
        )
# -----------------------------------------------------------------------------
# PassDriver
# -----------------------------------------------------------------------------
class PassDriver:
    '''Dependency checker and verifier dispatcher over a Compiler's dialects.

    The driver holds no dialect-specific knowledge; it only reads what each
    registered dialect exposes (`name`, `lowerings`, `rewrites`,
    `verifier`). Op-level dispatch (walking the IR and applying the
    matching Lowering to each op) is not implemented yet: callers invoke
    lowering functions directly and treat the registries as introspection
    metadata.
    '''

    def __init__(self, compiler: Compiler) -> None:
        # Held by reference; `compiler.dialects` is re-read on every call.
        self._compiler = compiler

    def validate_dependencies(self) -> None:
        '''Ensure every dialect named in a pass's `consumes` is registered.

        Walks the dialects in order, checking each dialect's lowerings
        before its rewrites, and raises PassDependencyError for the first
        consumed dialect name that is not registered. Returns None when
        every dependency resolves.
        '''
        known = {dialect.name for dialect in self._compiler.dialects}
        for owner in self._compiler.dialects:
            for rule in [*owner.lowerings, *owner.rewrites]:
                for dep in rule.consumes:
                    if dep in known:
                        continue
                    raise PassDependencyError(
                        pass_name=rule.name,
                        missing_dialect=dep,
                        in_dialect=owner.name,
                    )

    def verify_all(self, prog: Any) -> list[Any]:
        '''Run each registered dialect's verifier (if any) on `prog`.

        Returns the concatenation of everything the verifiers return, in
        dialect order; an empty list means every verifier passed.
        '''
        collected: list[Any] = []
        for dialect in self._compiler.dialects:
            check = dialect.verifier
            if check is None:
                continue
            collected.extend(check(prog))
        return collected

    def run(self, prog: Any) -> Any:
        '''Validate dependencies and verify `prog`; return it unchanged.

        No passes are applied yet (see the class docstring). Raises
        PassDependencyError on an unmet `consumes`, and RuntimeError listing
        the collected verification errors if any verifier reports one.
        '''
        self.validate_dependencies()
        problems = self.verify_all(prog)
        if not problems:
            return prog
        raise RuntimeError(f'Verification failed: {problems}')
# Public surface of this module: the two rule records, the dependency
# error, the driver, and the three registration decorators.
__all__ = [
    'Lowering', 'PassDependencyError', 'PassDriver', 'Rewrite',
    'lowering', 'rewrite', 'verifier',
]