Files
hakorune/src/llvm_py/prepass/loops.py

107 lines
3.9 KiB
Python

"""
Loop prepass utilities
Detect simple while-shaped loops in MIR(JSON) and return a lowering plan.
"""
from typing import Dict, List, Any, Optional
from cfg.utils import build_preds_succs
# Ops recognised as block terminators when scanning a block's last instruction.
_TERMINATOR_OPS = ('jump', 'branch', 'ret')


def _block_terminator(blk: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the terminator instruction of a MIR(JSON) block, or None.

    The last entry of ``instructions`` wins when its op is jump/branch/ret;
    otherwise the separate ``terminator`` field is consulted (it may be
    absent or null).
    """
    insts = blk.get('instructions') or []
    if insts and insts[-1].get('op') in _TERMINATOR_OPS:
        return insts[-1]
    term = blk.get('terminator')
    if term and term.get('op') in _TERMINATOR_OPS:
        return term
    return None


def _ends_with_ret(blk: Dict[str, Any]) -> bool:
    """Return True if the block's last instruction or its ``terminator`` is a ret."""
    insts = blk.get('instructions') or []
    if insts and insts[-1].get('op') == 'ret':
        return True
    # `or {}` also covers an explicit JSON null terminator.
    term = blk.get('terminator') or {}
    return term.get('op') == 'ret'


def _jumps_to(inst: Dict[str, Any], target: int) -> bool:
    """Return True if ``inst`` is a jump whose target equals ``target``.

    A jump with a missing/null target is treated as "not to target" instead
    of raising.
    """
    if inst.get('op') != 'jump':
        return False
    dest = inst.get('target')
    return dest is not None and int(dest) == target


def _find_latch(
    block_by_id: Dict[int, Dict[str, Any]],
    succs: Dict[int, List[int]],
    header: int,
    start: int,
    exit_bid: int,
) -> Optional[int]:
    """DFS from ``start`` (never stepping into ``exit_bid``) for the first block
    whose ``instructions`` contain a jump back to ``header``; return its id or None.
    """
    # NOTE: only `instructions` are scanned here, not a separate `terminator`
    # field, because the body collection below also copies only `instructions`.
    visited = set()
    stack = [start]
    while stack:
        cur = stack.pop()
        if cur in visited:
            continue
        visited.add(cur)
        blk = block_by_id.get(cur)
        if blk is None:
            continue
        if any(_jumps_to(inst, header) for inst in blk.get('instructions') or []):
            return cur
        for nx in succs.get(cur) or []:
            if nx not in visited and nx != exit_bid:
                stack.append(nx)
    return None


def _collect_body_blocks(
    succs: Dict[int, List[int]],
    header: int,
    start: int,
    latch: int,
    exit_bid: int,
) -> List[int]:
    """Return the loop-body block ids in DFS pre-order, starting at ``start``.

    The traversal excludes the header and the exit block, and it does not
    expand successors past the latch.
    """
    order: List[int] = []
    visited = set()
    stack = [start]
    while stack:
        cur = stack.pop()
        if cur in visited or cur == header or cur == exit_bid:
            continue
        visited.add(cur)
        order.append(cur)
        if cur == latch:
            continue
        for nx in succs.get(cur) or []:
            if nx not in visited and nx != exit_bid:
                stack.append(nx)
    return order


def detect_simple_while(
    block_by_id: Dict[int, Dict[str, Any]],
    *,
    succs: Optional[Dict[int, List[int]]] = None,
) -> Optional[Dict[str, Any]]:
    """Detect a simple while-shaped loop and return a lowering plan.

    The recognised shape is:
    - a header whose terminator is ``branch cond -> then / else``;
    - an else block ending in ``ret`` (the loop exit);
    - a latch, reachable from ``then`` without passing through ``else``,
      that jumps back to the header.

    Args:
        block_by_id: MIR(JSON) blocks keyed by block id.
        succs: Optional precomputed successor map (block id -> successor ids),
            presumably the same shape as the second result of
            ``build_preds_succs``. When omitted it is computed from
            ``block_by_id``.

    Returns:
        A plan dict for the first matching header in ``block_by_id``
        iteration order, or None if no header matches. The plan has these
        keys: ``header``, ``then``, ``else``, ``latch``, ``exit`` (same as
        ``else``), ``cond`` (condition value id), ``body_insts`` and
        ``skip_blocks``. ``body_insts`` holds the body blocks' instructions
        minus jumps back to the header, which avoids emitting a duplicate
        back edge. ``skip_blocks`` holds the body blocks plus the header.
    """
    if succs is None:
        _, succs = build_preds_succs(block_by_id)
    for blk in block_by_id.values():
        bid = int(blk.get('id', 0))
        term = _block_terminator(blk)
        if term is None or term.get('op') != 'branch':
            continue
        then_raw = term.get('then')
        else_raw = term.get('else')
        cond_raw = term.get('cond')
        # Malformed branches are skipped rather than crashing on int(None).
        if then_raw is None or else_raw is None or cond_raw is None:
            continue
        then_bid = int(then_raw)
        else_bid = int(else_raw)
        cond_vid = int(cond_raw)
        else_blk = block_by_id.get(else_bid)
        if else_blk is None or not _ends_with_ret(else_blk):
            continue
        latch = _find_latch(block_by_id, succs, bid, then_bid, else_bid)
        if latch is None:
            continue
        collect_order = _collect_body_blocks(succs, bid, then_bid, latch, else_bid)
        body_insts: List[Dict[str, Any]] = []
        for bbid in collect_order:
            body_blk = block_by_id.get(bbid)
            if body_blk is None:
                continue
            for inst in body_blk.get('instructions') or []:
                if not _jumps_to(inst, bid):
                    body_insts.append(inst)
        skip_blocks = set(collect_order)
        skip_blocks.add(bid)
        return {
            'header': bid,
            'then': then_bid,
            'else': else_bid,
            'latch': latch,
            'exit': else_bid,
            'cond': cond_vid,
            'body_insts': body_insts,
            'skip_blocks': list(skip_blocks),
        }
    return None