Source code for srdatalog.ir.codegen.cuda.schema
'''Schema types for C++ codegen.
Describes the per-program relation declarations that drive the emitted
C++ `using` aliases (RelationSchema, Database, SemiNaiveDatabase). Ported
from mhk's `python-api-notemplate` branch `nt_schema.py`; independent of
MIR types, so nothing here couples back to `mir_types`.
Two emission formats exist:
- `str(schema)` — plain alias list, used by the orchestrator today.
- `schema.get_batch_prelude(name)` — same aliases plus the Database /
SemiNaiveDatabase typedefs, used at the top of each JIT batch file.
'''
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
[docs]
class Pragma(Enum):
    '''Keys of the per-relation pragma mapping attached to a fact.

    INPUT carries the CSV filename to load from (required for
    `load_data()` in the FFI wrapper). PRINT_SIZE and SEMIRING are
    documented as bool-valued.
    '''

    # Value: CSV filename the relation is loaded from.
    INPUT = "input"

    # Value: bool flag.
    PRINT_SIZE = "print_size"

    # NOTE(review): documented as bool, but the schema emitters in this
    # module paste the SEMIRING value into the C++ alias as the semiring
    # type (falling back to "BooleanSR") — confirm the intended value type.
    SEMIRING = "semiring"
[docs]
@dataclass
class FactDefinition:
    '''One relation declared in the program schema.

    `name` has to be identical to the relation name referenced by MIR
    nodes (InsertInto.rel_name, ColumnSource.rel_name, ...). `params`
    lists the column types in order, e.g. `[int, int]` for an arity-2
    int-valued relation. `pragmas` maps `Pragma` keys to their values and
    defaults to a fresh empty dict per instance.
    '''

    name: str
    params: list[type]
    pragmas: dict[Pragma, Any] = field(default_factory=dict)

    def __str__(self) -> str:
        '''Render the relation as a single C++ `using` alias.

        The alias targets `AST::RelationSchema` and has no trailing
        newline. The semiring template argument is the SEMIRING pragma
        value when present, otherwise "BooleanSR"; the tuple arguments are
        the `__name__` of each entry in `params`.
        '''
        semiring_type = "BooleanSR"
        if Pragma.SEMIRING in self.pragmas:
            semiring_type = self.pragmas[Pragma.SEMIRING]

        column_types = [column.__name__ for column in self.params]
        tuple_args = ", ".join(column_types)

        # Build the right-hand side first, then wrap it in the alias.
        schema_type = (
            f'AST::RelationSchema<decltype("{self.name}"_s), '
            f'{semiring_type}, std::tuple<{tuple_args}>>'
        )
        return f'using {self.name} = {schema_type};'
[docs]
@dataclass
class SchemaDefinition:
    '''The full set of relations a program uses.

    The order of `facts` is significant: it is the order of the template
    arguments in the emitted `AST::Database<...>`.
    '''

    facts: list[FactDefinition]

    def __str__(self) -> str:
        '''Return the plain alias list: `str(fact)` plus a newline per fact.'''
        pieces: list[str] = []
        for fact in self.facts:
            pieces.append(str(fact))
            pieces.append("\n")
        return "".join(pieces)

    @staticmethod
    def _qualified_alias(fact: FactDefinition) -> str:
        '''Render one relation alias under the `SRDatalog::AST` namespace.'''
        semiring_type = fact.pragmas.get(Pragma.SEMIRING, "BooleanSR")
        tuple_args = ", ".join(column.__name__ for column in fact.params)
        schema_type = (
            f'SRDatalog::AST::RelationSchema<decltype("{fact.name}"_s), '
            f'{semiring_type}, std::tuple<{tuple_args}>>'
        )
        return f'using {fact.name} = {schema_type};'

    def get_batch_prelude(self, name: str) -> str:
        '''Build the header block placed at the top of a JIT batch file.

        The result holds one fully qualified alias per relation, followed
        by the `{name}Fixpoint_DB_Blueprint` and `{name}Fixpoint_DB_DeviceDB`
        typedefs. The statements are concatenated with no separator between
        them.
        '''
        statements = [self._qualified_alias(fact) for fact in self.facts]

        database_args = ", ".join(fact.name for fact in self.facts)
        blueprint = f'{name}Fixpoint_DB_Blueprint'
        statements.append(
            f'using {blueprint} = SRDatalog::AST::Database<{database_args}>;'
        )
        statements.append(
            f'using {name}Fixpoint_DB_DeviceDB = SRDatalog::AST::SemiNaiveDatabase<'
            f'{blueprint}, SRDatalog::GPU::DeviceRelationType>;'
        )
        return "".join(statements)