'''AST-based rule location extraction.
Walks a Python source file looking for the pattern:
(HEAD_ATOM <= BODY).named("RuleName")[.with_plan(kwargs...)]*
and produces a mapping from rule name → source location, plus, for
each `.with_plan(...)` call, the exact byte-offset range of each
keyword argument's value. The extension uses this to:
1. Jump to a rule from the visualization sidebar
2. Replace a specific `var_order=[...]` / `clause_order=[...]` value
without reformatting the rest of the file
We key rules by the string literal passed to `.named(...)` because
that's the only name the HIR / JIT stages know them by. Anonymous
rules (no `.named` call) are skipped — the extension has no handle
for them anyway.
'''
from __future__ import annotations
import ast
from dataclasses import dataclass, field
@dataclass(frozen=True)
class PlanKwargSpan:
    '''Byte-offset range of a single `with_plan(kwarg=VALUE)` argument.

    Only VALUE is covered; the `kwarg=` prefix is outside the range.

    Attributes:
        kwarg: Keyword name, e.g. "var_order" or "clause_order".
        start: Inclusive, 0-based byte offset into the source.
        end: Exclusive end offset.
    '''

    kwarg: str
    start: int
    end: int
@dataclass(frozen=True)
class PlanCallSpan:
    '''Byte-offset range of one `.with_plan(...)` call on a rule.

    Attributes:
        start: Byte offset where the call expression begins.
        end: Byte offset where the call expression ends.
        kwargs: Value spans for the call's keyword arguments; every
            instance gets its own empty list when none is supplied.
    '''

    start: int
    end: int
    kwargs: list[PlanKwargSpan] = field(default_factory=list)
@dataclass(frozen=True)
class RuleLocation:
    '''Position of a named rule in the source, plus the plans attached to it.

    Attributes:
        name: The rule's name.
        start_line: First line of the rule, 1-based.
        end_line: Last line of the rule, 1-based and inclusive.
        start: Byte offset of the outermost expression.
        end: Byte offset where the rule ends.
        plan_calls: Spans of the rule's `.with_plan(...)` calls; every
            instance gets its own empty list when none is supplied.
    '''

    name: str
    start_line: int
    end_line: int
    start: int
    end: int
    plan_calls: list[PlanCallSpan] = field(default_factory=list)
def find_rule_locations(source: str) -> list[RuleLocation]:
    '''Walk the AST, return one RuleLocation per `.named("X")`-suffixed rule.

    The pattern we recognize:

        ( HEAD <= BODY ).named("NAME").with_plan(kw=VAL, ...) ...

    Any number (including zero) of `.with_plan(...)` calls may follow
    `.named(...)`. Rules without `.named` are skipped.

    Every `.with_plan(...)` call is attached to the exact `.named(...)` AST
    node its chain descends through. Matching on node identity (rather than
    on the rule's line and name) keeps two rules that share a name and a
    starting line from receiving each other's plan calls.

    Args:
        source: Python source text to scan.

    Returns:
        Rule locations sorted by `(start_line, name)`. A location starts at
        its `.named(...)` call expression and ends at whichever is furthest:
        that call or one of the `.with_plan(...)` calls chained onto it.
        `plan_calls` are listed in the order `ast.walk` yields them.

    Raises:
        SyntaxError: Propagated from `ast.parse` when `source` does not parse.
    '''
    tree = ast.parse(source)
    line_offsets = _line_offsets(source)

    def end_of(node: ast.expr) -> tuple[int, int]:
        # (end line, end offset); falls back to the start line / column 0
        # when the node carries no end position.
        end_line = node.end_lineno or node.lineno
        return end_line, _offset(line_offsets, end_line, node.end_col_offset or 0)

    # Materialized once: both passes below look at the same Call nodes, and
    # holding references keeps their id()s valid as dictionary keys.
    calls = [node for node in ast.walk(tree) if isinstance(node, ast.Call)]

    # Pass 1: group each `.with_plan(...)` call under its owning `.named(...)`.
    plans_by_owner: dict[int, list[ast.Call]] = {}
    for call in calls:
        owner = _owning_named_call(call)
        if owner is not None:
            plans_by_owner.setdefault(id(owner), []).append(call)

    # Pass 2: build each location once, already extended over its plan calls,
    # so the frozen dataclass never has to be mutated afterwards.
    locs: list[RuleLocation] = []
    for call in calls:
        name = _named_call_name(call)
        if name is None:
            continue
        end_line, end = end_of(call)
        plan_calls: list[PlanCallSpan] = []
        for plan in plans_by_owner.get(id(call), []):
            plan_calls.append(_plan_call_span(plan, source, line_offsets))
            plan_end_line, plan_end = end_of(plan)
            if plan_end > end:
                end_line, end = plan_end_line, plan_end
        locs.append(
            RuleLocation(
                name=name,
                start_line=call.lineno,
                end_line=end_line,
                start=_offset(line_offsets, call.lineno, call.col_offset),
                end=end,
                plan_calls=plan_calls,
            )
        )

    # Sort by line for deterministic output + stable editor jumps.
    locs.sort(key=lambda loc: (loc.start_line, loc.name))
    return locs


def _owning_named_call(call: ast.Call) -> ast.Call | None:
    '''Return the `.named("...")` Call node that a `.with_plan(...)` call hangs off.

    Returns None when `call` is not a `.with_plan(...)` call, or when its
    receiver chain ends without passing through a `.named("...")` call.
    '''
    func = call.func
    if not isinstance(func, ast.Attribute) or func.attr != "with_plan":
        return None
    cur = func.value
    # Descend through the receiver chain: calls via `.func`, attributes via
    # `.value`; anything else ends the chain.
    while isinstance(cur, (ast.Call, ast.Attribute)):
        if isinstance(cur, ast.Call):
            if _named_call_name(cur) is not None:
                return cur
            cur = cur.func
        else:
            cur = cur.value
    return None
# ---------------------------------------------------------------------------
# AST helpers
# ---------------------------------------------------------------------------
def _named_call_name(call: ast.Call) -> str | None:
    '''Extract NAME from a call shaped like `<expr>.named("NAME")`.

    Returns None unless `call` invokes an attribute called `named` with
    exactly one positional argument that is a string literal. Keyword
    arguments on the call are not inspected.
    '''
    target = call.func
    is_named_attr = isinstance(target, ast.Attribute) and target.attr == "named"
    if not is_named_attr or len(call.args) != 1:
        return None
    (literal,) = call.args
    if not isinstance(literal, ast.Constant):
        return None
    value = literal.value
    return value if isinstance(value, str) else None
def _with_plan_chain_owner(call: ast.Call) -> str | None:
    '''Name of the rule a `.with_plan(...)` call is chained onto, if any.

    `call` has to look like `<receiver>.with_plan(...)`. The receiver is
    followed inward (calls through `.func`, attributes through `.value`)
    and the first `.named("name")` call met on the way supplies the result.
    Returns None when `call` is not a `with_plan` call or when the chain
    ends without reaching a `.named(...)` call.
    '''
    method = call.func
    if not (isinstance(method, ast.Attribute) and method.attr == "with_plan"):
        return None
    link = method.value
    while True:
        if isinstance(link, ast.Call):
            owner = _named_call_name(link)
            if owner is not None:
                return owner
            link = link.func
        elif isinstance(link, ast.Attribute):
            link = link.value
        else:
            # Reached something that is neither a call nor an attribute.
            return None
def _chain_named_line(call: ast.Call) -> int:
    '''Line number of the `.named(...)` call inside a with_plan chain.

    Follows `call.func` inward (calls through `.func`, attributes through
    `.value`) and returns the `lineno` of the first `.named("...")` call
    found. When the chain contains none, `call`'s own line is returned.
    '''
    node = call.func
    while True:
        if isinstance(node, ast.Call):
            if _named_call_name(node) is not None:
                return node.lineno
            node = node.func
        elif isinstance(node, ast.Attribute):
            node = node.value
        else:
            # Chain exhausted without a `.named(...)` call.
            return call.lineno
def _plan_call_span(call: ast.Call, source: str, line_offsets: list[int]) -> PlanCallSpan:
    '''Build the PlanCallSpan for one `.with_plan(...)` call node.

    Args:
        call: The call node to measure.
        source: Not used by this function.
        line_offsets: Per-line starting offsets, indexed by 1-based line.

    Returns:
        A span covering `call`, with one PlanKwargSpan per named keyword
        argument's value. `**mapping` arguments have no keyword name and
        are left out.
    '''

    def bounds(node: ast.expr) -> tuple[int, int]:
        # (start, end) offsets of `node`; a missing end position falls back
        # to the node's start line and column 0.
        last_line = node.end_lineno or node.lineno
        begin = _offset(line_offsets, node.lineno, node.col_offset)
        finish = _offset(line_offsets, last_line, node.end_col_offset or 0)
        return begin, finish

    call_start, call_end = bounds(call)
    value_spans: list[PlanKwargSpan] = []
    for keyword in call.keywords:
        if keyword.arg is not None:
            value_start, value_end = bounds(keyword.value)
            value_spans.append(
                PlanKwargSpan(kwarg=keyword.arg, start=value_start, end=value_end)
            )
    return PlanCallSpan(start=call_start, end=call_end, kwargs=value_spans)
def _line_offsets(source: str) -> list[int]:
    '''1-based line → starting UTF-8 byte offset (0-based). Index 0 unused.

    The table is meant to be combined with `ast` column offsets, which are
    UTF-8 byte offsets within a line, so line starts are measured in encoded
    bytes as well; counting characters would skew every position after the
    first non-ASCII character.

    Only `\\n`, `\\r\\n` and `\\r` end a line. `str.splitlines` additionally
    breaks on form feed, vertical tab, `\\x1c`-`\\x1e`, `\\x85`, `\\u2028`
    and `\\u2029`, which the Python parser does not count as line ends, so
    using it would shift the table on sources containing those characters.

    Args:
        source: The source text.

    Returns:
        A list where entry `n` (n >= 1) is the byte offset at which line `n`
        starts. The final entry is the total encoded length whenever the
        source is non-empty.
    '''
    import re  # local import: only this helper needs it

    # "surrogatepass" keeps this helper from raising on lone surrogates.
    data = source.encode("utf-8", "surrogatepass")
    offsets = [0, 0]  # offsets[1] = 0 (first line starts at 0)
    for newline in re.finditer(rb"\r\n|\r|\n", data):
        offsets.append(newline.end())
    # A last line without a trailing newline still contributes its end.
    if offsets[-1] != len(data):
        offsets.append(len(data))
    return offsets
def _offset(line_offsets: list[int], line: int, col: int) -> int:
    '''Turn a (line, column) position into an absolute offset.

    Looks up the starting offset stored for `line` in `line_offsets` and
    adds `col` to it.
    '''
    line_start = line_offsets[line]
    return line_start + col