Source code for srdatalog.ir.core.dialect
'''Dialect base + Compiler registry.
A `Dialect` is a coherent set of types, ops, lowerings, rewrites, and
a verifier. Dialects register with a `Compiler` at init time, which
indexes them by name and exposes lookups for the pass driver.
The registry has no central enum of dialects (Property P1): adding a
new dialect only requires constructing a `Dialect` and calling
`register_dialect`, with no edits to this module.
See docs/ir_lowering_semantics.md, section 19.
'''
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
[docs]
@dataclass
class Dialect:
    '''A named bundle of IR constructs that is registered with a `Compiler`.

    Attributes:
        name: identifier used as the registry key
            (e.g. "relation.sorted_array").
        types: Type subclasses that belong to this dialect.
        ops: Op subclasses that belong to this dialect.
        lowerings: lowering rules that translate IR *out of* this dialect.
        rewrites: rewrite rules that operate *within* this dialect.
        verifier: optional callable that validates the dialect's IR shape
            and returns a list of VerificationError on failure; ``None``
            when the dialect supplies no verifier.
    '''

    name: str
    # Every collection below gets its own fresh list per instance, so
    # dialects never share mutable state through a default.
    types: list[type] = field(default_factory=list)
    ops: list[type] = field(default_factory=list)
    lowerings: list[Any] = field(default_factory=list)
    rewrites: list[Any] = field(default_factory=list)
    verifier: Callable[[Any], list[Any]] | None = field(default=None)
[docs]
class Compiler:
    '''Holds the registered dialects.

    The registry is the single source of truth for what dialects exist.
    Dialects are stored by name, in registration order, and looked up by
    name. Cross-dialect lowerings are resolved by matching the source op
    kind against each dialect's `lowerings` list. That matching is not
    performed by this class; it only stores dialects and exposes them via
    `get_dialect` and `dialects` (presumably for the pass driver).
    '''

    def __init__(self) -> None:
        # Keyed by Dialect.name. Dict insertion order doubles as the
        # registration order reported by `dialects`.
        self._dialects: dict[str, Dialect] = {}

    def register_dialect(self, d: Dialect) -> None:
        '''Register a dialect under its name.

        Args:
            d: the dialect to add; `d.name` becomes its lookup key.

        Raises:
            ValueError: if a dialect named `d.name` is already registered.
        '''
        if d.name in self._dialects:
            raise ValueError(f'Dialect {d.name!r} already registered')
        self._dialects[d.name] = d

    def get_dialect(self, name: str) -> Dialect:
        '''Look up a registered dialect by name.

        Args:
            name: the dialect identifier it was registered under.

        Returns:
            The `Dialect` registered as `name`.

        Raises:
            KeyError: if no dialect named `name` is registered. The message
                lists the registered names to make typos easy to spot.
        '''
        try:
            return self._dialects[name]
        except KeyError:
            known = ', '.join(repr(n) for n in self._dialects) or '<none>'
            # `from None` hides the bare internal KeyError so the traceback
            # shows only the descriptive error.
            raise KeyError(
                f'Dialect {name!r} not registered (registered: {known})'
            ) from None

    @property
    def dialects(self) -> list[Dialect]:
        '''All registered dialects, in registration order.

        Returns a new list on every access; mutating it does not affect
        the registry.
        '''
        return list(self._dialects.values())