Source code for srdatalog.ir.codegen.cuda.render

'''CUDA renderer registry.

Per docs/stage3a_execution_plan.md §7 task S3A.3, this package replaces
codegen/cuda/emit.py's hardcoded 41-case match with a per-op-class
registry that each dialect's render module opts into. Adding a new IIR
dialect now requires adding a `codegen/cuda/render/<dialect>.py` with
`@register_render` decorators — no edits to this file or to any
existing per-dialect render module. P1 fix.

Public API (preserved for external callers — was previously in
codegen/cuda/emit.py):

    EmitCtx          mutable per-emission state (indent, tile var, ...)
    emit(op, ctx)    statement-shaped op → C++ source line(s)
    emit_expr(op, ctx) expression-shaped op → C++ source fragment

Each IIR-contributing dialect imports from this module:

    from srdatalog.ir.codegen.cuda.render import (
        EmitCtx, register_render, emit, emit_expr,
    )

    @register_render(Block, mode='stmt')
    def _render_block(op: Block, ctx: EmitCtx) -> str:
        return ''.join(emit(s, ctx) for s in op.stmts)

The `mode` argument distinguishes statement vs expression handlers.
A single op class CAN appear in both registries (e.g. `RawString`
is renderable in both modes with different output forms).

Dispatch (`emit` / `emit_expr`) raises `KeyError` with a useful
message on an unregistered op — that's the "did this codegen forget
to register a handler" signal. Per S3A.8 (per-Codegen registry +
verify_completeness), startup will eventually catch this earlier.
'''

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

from srdatalog.ir.core import Op

RenderFn = Callable[[Op, 'EmitCtx'], str]


@dataclass
class EmitCtx:
    '''Mutable emission state threaded through the C++ walk.

    Attributes:
        indent_level: current indent, counted in 2-space units. The legacy
            emitter starts at indent=2 (the operator() body); M1 callers
            pass that in to match, hence the default of 2.
        tile_var: name of the tile/thread-group variable. "tile" by default.
        segment_depth: how many `D2lSegmentLoop`s wrap the current point.
            Lets `IntersectIter` reproduce the legacy
            `_nested_column_join_multi` indent quirk where segment loops
            bump the structural indent (alias_binds, intersect at +segs)
            but the for-iter body lines (auto value, positions,
            child_binds, body_op) anchor against the *outer* indent
            (+1, +1, +1, +0 respectively).
    '''

    indent_level: int = 2
    tile_var: str = 'tile'
    segment_depth: int = 0

    def ind(self) -> str:
        '''Return the leading whitespace for the current `indent_level`.'''
        # One indent unit is two spaces, as documented on `indent_level`;
        # a single-space unit would halve every emitted indent.
        return '  ' * self.indent_level
# Renderer lookup tables, one per render mode, each mapping an op class to
# its render function. They stay module-level globals for now (S3A.3) so the
# diff remains focused on the registry-vs-match split; S3A.8 will move them
# onto a per-`Codegen` instance with `verify_completeness()` at
# `Compiler.register_codegen` construction time.
_EXPR_HANDLERS: dict[type[Op], RenderFn] = {}
_STMT_HANDLERS: dict[type[Op], RenderFn] = {}
def register_render(
    op_class: type[Op], *, mode: str = 'stmt'
) -> Callable[[RenderFn], RenderFn]:
    '''Build a decorator that installs a renderer for `op_class`.

    Args:
        op_class: the op class the decorated function renders.
        mode: 'stmt' to register a statement handler (the default) or
            'expr' to register an expression handler.

    Returns:
        A decorator that stores the function in the registry for `mode`
        and hands the same function back unchanged.

    Raises:
        ValueError: immediately if `mode` is neither 'stmt' nor 'expr';
            at decoration time if `op_class` already has a handler in
            that mode.
    '''
    # Validate the mode first so a typo fails at the decorator call site.
    if mode == 'stmt':
        table = _STMT_HANDLERS
    elif mode == 'expr':
        table = _EXPR_HANDLERS
    else:
        raise ValueError(f"register_render mode must be 'stmt' or 'expr', got {mode!r}")

    def _bind(fn: RenderFn) -> RenderFn:
        # A second registration for the same class and mode is an error
        # rather than a silent overwrite.
        if op_class in table:
            raise ValueError(
                f"register_render: {op_class.__name__} already registered in mode={mode!r}"
            )
        table[op_class] = fn
        return fn

    return _bind
def emit(op: Op, ctx: EmitCtx) -> str:
    '''Render `op` in statement mode.

    Looks up the handler registered for the exact class of `op` and
    returns whatever it produces: C++ source text ending in `\\n`.

    Raises:
        KeyError: if no statement-mode renderer is registered for
            `type(op)`.
    '''
    op_type = type(op)
    render = _STMT_HANDLERS.get(op_type)
    if render is not None:
        return render(op, ctx)
    # Name the op class and its module so the missing dialect render
    # module is easy to identify.
    raise KeyError(
        f'codegen.cuda: no statement-mode renderer registered for {op_type.__name__} '
        f'(module={op_type.__module__!r}). '
        'Did the dialect ship its codegen/cuda/render/<dialect>.py?'
    )
def emit_expr(op: Op, ctx: EmitCtx) -> str:
    '''Render `op` in expression mode.

    Looks up the handler registered for the exact class of `op` and
    returns whatever it produces: a C++ source fragment with no leading
    indent or trailing newline.

    Raises:
        KeyError: if no expression-mode renderer is registered for
            `type(op)`.
    '''
    op_type = type(op)
    render = _EXPR_HANDLERS.get(op_type)
    if render is not None:
        return render(op, ctx)
    # Name the op class and its module so the missing dialect render
    # module is easy to identify.
    raise KeyError(
        f'codegen.cuda: no expression-mode renderer registered for {op_type.__name__} '
        f'(module={op_type.__module__!r}). '
        'Did the dialect ship its codegen/cuda/render/<dialect>.py?'
    )
# Guard flag for `_eager_register_all`; set to True on its first call.
_registered_all = False


def _eager_register_all() -> None:
    '''Import each dialect's render module so its @register_render
    decorators run.

    The imports live inside this function rather than at the top of the
    module to avoid a load-time cycle: the dialect render modules import
    EmitCtx + register_render from this module. The function is called
    once at the bottom of this module.

    Repeat calls are a no-op because of the `_registered_all` flag. The
    flag is set before the imports start, so a call made while they are
    still in progress also returns immediately.
    '''
    global _registered_all
    if not _registered_all:
        _registered_all = True
        # Imported purely for side effect (the @register_render decorators).
        from srdatalog.ir.codegen.cuda.render import (
            d2l,
            iir_cf,
            parallel_data,
            sorted_array,
        )


# Register every dialect's handlers as soon as this package is imported.
# This preserves the property "import codegen.cuda.render and
# emit(op, ctx) just works": callers never need to know about the
# per-dialect registration step.
_eager_register_all()

__all__ = [
    'EmitCtx',
    'RenderFn',
    'register_render',
    'emit',
    'emit_expr',
]