'''Strategy combinators for IR rewrite walks.
Lifted from Stratego (Visser, *Program Transformation with Stratego/XT*).
The idea: a *strategy* is a partial function `Op -> Op | None` where
`None` means "the strategy didn't apply at this op." Strategies compose
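via combinators — `seq`, `choice`, `try_`, `repeat`, `top_down`,
`bottom_up`, `all_`, `one`, `some`. Composition is deterministic and
pure: combinators built from pure strategies are themselves pure
strategies. Composition does not make a strategy total — `seq(s1, s2)`
fails wherever `s1` or `s2` does — so wrap in `try_` when a strategy must
always succeed.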
Why strategies and not ad-hoc walkers:
- **Composable**: rewrite logic is built from named combinators
rather than re-implementing tree traversal in every pass.
- **Pure**: each strategy is a function. No mutable rewriter
objects. Unlike MLIR's `PatternRewriter`, strategies don't carry
hidden state.
- **Confluence-friendly**: applying `top_down(repeat(rule))` is a
standard idiom; correctness arguments lift directly from the
term-rewriting literature.
Children of an Op are discovered generically: any field whose value is
an Op or a list/tuple of Ops is considered a child. This is the "scrap
your boilerplate" lesson — generic traversal beats per-op visitor
methods.
References:
Visser, "Strategic pattern matching", RTA 1999.
Visser & Benaissa, "A core language for rewriting", FroCoS 1998.
Lämmel & Peyton Jones, "Scrap your boilerplate", TLDI 2003.
'''
from __future__ import annotations
import dataclasses
from collections.abc import Callable
from typing import TypeVar
from srdatalog.ir.core.ops import Op
# Type variable bounded by Op. It is not referenced by the code in this
# section — presumably kept for callers or other parts of the file (verify).
T = TypeVar('T', bound=Op)
# A strategy maps an Op to a transformed Op or None ("did not apply").
# NOTE: this alias is an ordinary runtime expression, not an annotation, so
# `from __future__ import annotations` does not defer it; evaluating
# `Op | None` here requires PEP 604 runtime unions (Python 3.10+).
Strategy = Callable[[Op], Op | None]
# -----------------------------------------------------------------------------
# Atomic strategies
# -----------------------------------------------------------------------------
[docs]
def id_(op: Op) -> Op:
    '''Identity strategy: applies to every op and hands it back untouched.'''
    unchanged = op
    return unchanged
[docs]
def fail(_op: Op) -> Op | None:
    '''Strategy that never applies: every call reports "did not apply".'''
    no_match: Op | None = None
    return no_match
# -----------------------------------------------------------------------------
# Composition combinators
# -----------------------------------------------------------------------------
[docs]
def try_(s: Strategy) -> Strategy:
    '''Make `s` total.

    The returned strategy uses the result of `s` when it applies and falls
    back to the input op, unchanged, when `s` returns None. It never fails.
    '''

    def attempt(op: Op) -> Op:
        rewritten = s(op)
        if rewritten is None:
            return op
        return rewritten

    return attempt
[docs]
def seq(s1: Strategy, s2: Strategy) -> Strategy:
    '''Sequential composition: run `s1`, then feed its output to `s2`.

    The composite returns None as soon as either stage does. `s2` is not
    invoked at all when `s1` fails.
    '''

    def both(op: Op) -> Op | None:
        first = s1(op)
        return None if first is None else s2(first)

    return both
[docs]
def choice(s1: Strategy, s2: Strategy) -> Strategy:
    '''Left-biased choice between two strategies.

    `s1` is preferred. Only when it returns None is `s2` tried, and it is
    applied to the same original op that `s1` saw.
    '''

    def either(op: Op) -> Op | None:
        first = s1(op)
        if first is not None:
            return first
        return s2(op)

    return either
[docs]
def repeat(s: Strategy, max_iters: int = 1024) -> Strategy:
    '''Rewrite with `s` until a fixpoint is reached.

    Each round feeds the previous result back into `s`. The loop stops and
    returns the current op as soon as `s` fails (None) or yields an op equal
    (`==`) to its input, so the returned strategy never fails.

    `max_iters` caps the number of applications of `s`. If every one of them
    still produced a different op, RuntimeError is raised instead of looping
    forever; a non-positive `max_iters` therefore makes every call raise.
    Term-rewriting presentations leave this bound implicit. Here it is
    explicit because nothing guarantees an arbitrary Python rule terminates.
    '''

    def loop(op: Op) -> Op:
        current = op
        for _attempt in range(max_iters):
            nxt = s(current)
            converged = nxt is None or nxt == current
            if converged:
                return current
            current = nxt
        raise RuntimeError(f'repeat: did not converge in {max_iters} iterations')

    return loop
# -----------------------------------------------------------------------------
# Traversal combinators
# -----------------------------------------------------------------------------
[docs]
def all_(s: Strategy) -> Strategy:
    '''All-or-nothing congruence (Stratego `all`).

    The returned strategy rewrites every immediate child of an op with `s`.
    It yields the op with its rewritten children when `s` succeeds on every
    child, and None as soon as `s` fails on any one of them. An op with no
    children trivially succeeds.
    '''

    def every_child(op: Op) -> Op | None:
        return _map_children(op, s, all_or_nothing=True)

    return every_child
[docs]
def one(s: Strategy) -> Strategy:
    '''Rewrite a single immediate child (Stratego `one`).

    `s` is tried on the children left to right, and only the first child
    on which it succeeds is replaced. The result is None when `s` succeeds
    on no child.
    '''

    def first_child(op: Op) -> Op | None:
        return _map_one_child(op, s)

    return first_child
[docs]
def some(s: Strategy) -> Strategy:
    '''Rewrite every immediate child that `s` succeeds on (Stratego `some`).

    Children on which `s` fails are kept as they are. The result is None
    only when `s` succeeds on no child at all.
    '''

    def succeeding_children(op: Op) -> Op | None:
        return _map_some_children(op, s)

    return succeeding_children
[docs]
def top_down(s: Strategy) -> Strategy:
    '''Apply `s` at the root, then recursively at every descendant (preorder).

    A failure of `s` at a node is treated as identity at that node. The walk
    always descends into the children of the (possibly rewritten) op, even
    when `s` did not apply at the root. The resulting strategy never fails.

    Equivalent to `seq(try_(s), all_(top_down(s)))`. It is *not*
    `try_(seq(s, all_(top_down(s))))`, which would skip a subtree whenever
    `s` fails at its root.

    Children are taken from the op *after* `s` rewrote it. A rule whose
    output contains a fresh match of its own input pattern (e.g. X -> Wrap(X))
    therefore makes the walk recurse without bound (RecursionError).
    '''

    def go(op: Op) -> Op:
        r = s(op)
        op = op if r is None else r
        # Lenient mode: child failures are identity, so this is never None;
        # the None check is purely defensive.
        children_result = _map_children(op, go, all_or_nothing=False)
        return op if children_result is None else children_result

    return go
[docs]
def bottom_up(s: Strategy) -> Strategy:
    '''Post-order rewrite: descendants first, then the root.

    Every child subtree is rewritten recursively. `s` is then tried on the
    rebuilt root, and a failure there leaves the root as-is. The strategy
    never fails. Equivalent to `seq(all_(bottom_up(s)), try_(s))`.
    '''

    def visit(op: Op) -> Op:
        rebuilt = _map_children(op, visit, all_or_nothing=False)
        if rebuilt is not None:
            op = rebuilt
        attempt = s(op)
        if attempt is None:
            return op
        return attempt

    return visit
# -----------------------------------------------------------------------------
# Internal: child traversal via dataclass field walking
# -----------------------------------------------------------------------------
def _map_children(op: Op, transform: Strategy, *, all_or_nothing: bool) -> Op | None:
    '''Apply `transform` to every immediate Op-valued child of `op`.

    Children are discovered by walking the dataclass fields of `op`. A
    field holds children if its value is an Op, or a list/tuple (possibly
    nested) containing Ops. Any other value, including dicts, is passed
    through untouched.

    Fields declared with ``init=False`` are skipped. ``dataclasses.replace``
    refuses to set them (ValueError), so a rewritten value could never be
    installed.

    When `all_or_nothing=True`, returns None if `transform` fails on any
    child (Stratego `all` semantics).
    When `all_or_nothing=False`, each child failure is treated as identity
    for that child (Stratego `all(try(s))` semantics, used by the traversal
    walkers). In that mode the result is never None.

    Returns `op` itself (same object) when no child changed. Otherwise
    returns a copy built with ``dataclasses.replace`` that passes only the
    changed fields.
    '''
    # Op subclasses are always dataclasses by D4 (see design_principles.md).
    changes: dict[str, object] = {}
    for f in dataclasses.fields(op):
        if not f.init:
            continue
        new_val, child_changed, ok = _transform_field(
            getattr(op, f.name), transform, all_or_nothing
        )
        if not ok:
            return None
        if child_changed:
            changes[f.name] = new_val
    if not changes:
        return op
    # Passing only the changed fields lets replace() copy the rest itself.
    return dataclasses.replace(op, **changes)
def _map_one_child(op: Op, transform: Strategy) -> Op | None:
    '''Apply `transform` to exactly one Op-valued child of `op`.

    The child chosen is the first one on which `transform` succeeds. Fields
    are visited in order; within list/tuple fields, which may be nested,
    elements are visited left to right, depth first.

    Success means a non-None result, as everywhere else in this module. A
    child that `transform` hands back unchanged still counts, and in that
    case `op` itself is returned. Fields declared with ``init=False`` are
    skipped because ``dataclasses.replace`` cannot set them.

    Returns None if `transform` succeeds on no child (Stratego `one`
    semantics).
    '''
    # Op subclasses are always dataclasses by D4 (see design_principles.md).
    for f in dataclasses.fields(op):
        if not f.init:
            continue
        val = getattr(op, f.name)
        new_val, ok = _transform_first_child(val, transform)
        if not ok:
            continue
        if new_val is val:
            # Succeeded without changing anything: no need to rebuild `op`.
            return op
        return dataclasses.replace(op, **{f.name: new_val})
    return None


def _transform_first_child(val: object, transform: Strategy) -> tuple[object, bool]:
    '''Helper for `one`: apply `transform` to the first Op inside `val` on
    which it succeeds.

    Returns (new_val, ok). `ok` is False when no Op inside `val` succeeded.
    `new_val is val` whenever nothing changed.
    '''
    if isinstance(val, Op):
        r = transform(val)
        return (val, False) if r is None else (r, True)
    if isinstance(val, (list, tuple)):
        for i, x in enumerate(val):
            new_x, ok = _transform_first_child(x, transform)
            if not ok:
                continue
            if new_x is x:
                return val, True
            items = list(val)
            items[i] = new_x
            # Rebuild with the original container type (list/tuple).
            return type(val)(items), True
    return val, False
def _map_some_children(op: Op, transform: Strategy) -> Op | None:
    '''Apply `transform` to every Op-valued child, keeping the successes.

    Each child failure is treated as identity for that child. Returns None
    if `transform` succeeds on no child (Stratego `some` semantics).

    Otherwise returns a copy of `op` built with ``dataclasses.replace``.
    Only fields containing at least one successful child are replaced; the
    others are carried over as-is. Fields declared with ``init=False`` are
    skipped because ``dataclasses.replace`` cannot set them.
    '''
    # Op subclasses are always dataclasses by D4 (see design_principles.md).
    changes: dict[str, object] = {}
    for f in dataclasses.fields(op):
        if not f.init:
            continue
        new_val, ok = _transform_field_some(getattr(op, f.name), transform)
        if ok:
            changes[f.name] = new_val
    if not changes:
        return None
    return dataclasses.replace(op, **changes)
def _transform_field(
    val: object,
    transform: Strategy,
    all_or_nothing: bool,
) -> tuple[object, bool, bool]:
    '''Rewrite the Ops inside one field value.

    Returns a triple (new_val, changed, ok):

    - An Op is passed to `transform`. On success the result is returned, and
      `changed` reports whether it is a different object. On failure the Op
      is kept and `ok` is False only in all-or-nothing mode.
    - A list or tuple is processed element by element, recursively, and
      rebuilt as a plain list or tuple. The first failing element in
      all-or-nothing mode aborts with (val, False, False).
    - Anything else is returned untouched with `ok` True.
    '''
    if isinstance(val, Op):
        rewritten = transform(val)
        if rewritten is not None:
            return rewritten, rewritten is not val, True
        return val, False, not all_or_nothing
    if not isinstance(val, (list, tuple)):
        return val, False, True
    items: list[object] = []
    any_changed = False
    for item in val:
        new_item, item_changed, item_ok = _transform_field(item, transform, all_or_nothing)
        if not item_ok:
            return val, False, False
        items.append(new_item)
        any_changed = any_changed or item_changed
    rebuilt = items if isinstance(val, list) else tuple(items)
    return rebuilt, any_changed, True
def _transform_field_some(val: object, transform: Strategy) -> tuple[object, bool]:
    '''Helper for `some`: rewrite every Op inside one field value.

    Failures count as identity. Returns (new_val, any_ok), where `any_ok`
    tells whether `transform` succeeded on at least one Op. Lists and tuples
    are walked recursively, with `transform` called on every element in
    order, and rebuilt as a plain list or tuple. Other values pass through
    with `any_ok` False.
    '''
    if isinstance(val, Op):
        rewritten = transform(val)
        return (val, False) if rewritten is None else (rewritten, True)
    if not isinstance(val, (list, tuple)):
        return val, False
    results = [_transform_field_some(item, transform) for item in val]
    items = [new_item for new_item, _ in results]
    succeeded = any(ok for _, ok in results)
    return (items if isinstance(val, list) else tuple(items)), succeeded