//! Loop Canonicalizer - AST Level Loop Preprocessing
//!
//! Phase 1: Type Definitions Only
//!
//! ## Purpose
//!
//! Decomposes AST-level loops into a normalized "skeleton" representation
//! to prevent combinatorial explosion in pattern detection and lowering.
//!
//! ## Design Principle
//!
//! - **Input**: AST (LoopExpr)
//! - **Output**: LoopSkeleton only (no JoinIR generation)
//! - **Boundary**: No JoinIR-specific information (BlockId, ValueId, etc.)
//!
//! ## Architecture
//!
//! ```text
//! AST → LoopSkeleton → Capability Guard → RoutingDecision → Pattern Lowerer
//! ```
//!
//! ## References
//!
//! - Design SSOT: `docs/development/current/main/design/loop-canonicalizer.md`
//! - JoinIR Architecture: `docs/development/current/main/joinir-architecture-overview.md`
//! - Pattern Space: `docs/development/current/main/loop_pattern_space.md`
|
||
|
|
|
||
|
|
use crate::ast::{ASTNode, Span};
|
||
|
|
use crate::mir::loop_pattern_detection::LoopPatternKind;
|
||
|
|
|
||
|
|
// ============================================================================
// Core Skeleton Types
// ============================================================================
|
||
|
|
|
||
|
|
/// Loop skeleton - The canonical representation of a loop structure
|
||
|
|
///
|
||
|
|
/// This is the single output type of the Canonicalizer.
|
||
|
|
/// It represents the essential structure of a loop without any
|
||
|
|
/// JoinIR-specific information.
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub struct LoopSkeleton {
|
||
|
|
/// Sequence of steps (HeaderCond, BodyInit, BreakCheck, Updates, Tail)
|
||
|
|
pub steps: Vec<SkeletonStep>,
|
||
|
|
|
||
|
|
/// Carriers (loop variables with update rules and boundary crossing contracts)
|
||
|
|
pub carriers: Vec<CarrierSlot>,
|
||
|
|
|
||
|
|
/// Exit contract (presence and payload of break/continue/return)
|
||
|
|
pub exits: ExitContract,
|
||
|
|
|
||
|
|
/// Captured variables from outer scope (optional)
|
||
|
|
pub captured: Option<Vec<CapturedSlot>>,
|
||
|
|
|
||
|
|
/// Source location for debugging
|
||
|
|
pub span: Span,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Skeleton step - Minimal step kinds for loop structure
|
||
|
|
///
|
||
|
|
/// Each step represents a fundamental operation in the loop lifecycle.
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub enum SkeletonStep {
|
||
|
|
/// Loop continuation condition (the `cond` in `loop(cond)`)
|
||
|
|
HeaderCond {
|
||
|
|
expr: Box<ASTNode>,
|
||
|
|
},
|
||
|
|
|
||
|
|
/// Early exit check (`if cond { break }`)
|
||
|
|
BreakCheck {
|
||
|
|
cond: Box<ASTNode>,
|
||
|
|
has_value: bool,
|
||
|
|
},
|
||
|
|
|
||
|
|
/// Skip check (`if cond { continue }`)
|
||
|
|
ContinueCheck {
|
||
|
|
cond: Box<ASTNode>,
|
||
|
|
},
|
||
|
|
|
||
|
|
/// Carrier update (`i = i + 1`, etc.)
|
||
|
|
Update {
|
||
|
|
carrier_name: String,
|
||
|
|
update_kind: UpdateKind,
|
||
|
|
},
|
||
|
|
|
||
|
|
/// Loop body (all other statements)
|
||
|
|
Body {
|
||
|
|
stmts: Vec<ASTNode>,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Update kind - How a carrier variable is updated
|
||
|
|
///
|
||
|
|
/// This categorization helps determine which pattern can handle the loop.
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub enum UpdateKind {
|
||
|
|
/// Constant step (`i = i + const`)
|
||
|
|
ConstStep {
|
||
|
|
delta: i64,
|
||
|
|
},
|
||
|
|
|
||
|
|
/// Conditional update (`if cond { x = a } else { x = b }`)
|
||
|
|
Conditional {
|
||
|
|
then_value: Box<ASTNode>,
|
||
|
|
else_value: Box<ASTNode>,
|
||
|
|
},
|
||
|
|
|
||
|
|
/// Arbitrary update (everything else)
|
||
|
|
Arbitrary,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Exit contract - which kinds of exits a loop has.
///
/// This determines the exit line architecture needed. All fields are plain
/// flags, so contracts can be compared directly with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExitContract {
    /// Whether the loop contains a `break`.
    pub has_break: bool,
    /// Whether the loop contains a `continue`.
    pub has_continue: bool,
    /// Whether the loop contains a `return`.
    pub has_return: bool,
    /// Whether a `break` in the loop carries a value.
    pub break_has_value: bool,
}
|
||
|
|
|
||
|
|
/// Carrier slot - A loop variable with its role and update rule
|
||
|
|
///
|
||
|
|
/// Carriers are variables that are updated in each iteration
|
||
|
|
/// and need to cross loop boundaries (via PHI nodes in MIR).
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub struct CarrierSlot {
|
||
|
|
pub name: String,
|
||
|
|
pub role: CarrierRole,
|
||
|
|
pub update_kind: UpdateKind,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Carrier role - the semantic role of a carrier variable.
///
/// This helps determine the appropriate pattern and PHI structure.
/// `Hash` is derived alongside `Eq` so roles can key sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CarrierRole {
    /// Loop counter (the `i` in `i < n`).
    Counter,

    /// Accumulator (the `sum` in `sum += x`).
    Accumulator,

    /// Condition variable (the `is_valid` in `loop(is_valid)`).
    ConditionVar,

    /// Derived value (e.g., `digit_pos` computed from other carriers).
    Derived,
}
|
||
|
|
|
||
|
|
/// Captured slot - an outer-scope variable used within the loop.
///
/// The loop only reads these variables. A variable the loop writes is
/// modelled as a `CarrierSlot` instead.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CapturedSlot {
    /// Name of the captured variable.
    pub name: String,
    /// Mutability flag for the captured variable.
    ///
    /// NOTE(review): captured slots are read-only inside the loop, so this
    /// presumably reflects how the variable was declared in the outer scope
    /// rather than a write inside the loop. Confirm against the code that
    /// populates it.
    pub is_mutable: bool,
}
|
||
|
|
|
||
|
|
// ============================================================================
// Capability Guard - Routing Decision & Fail-Fast Reasons
// ============================================================================
|
||
|
|
|
||
|
|
/// Routing decision - The result of pattern selection
|
||
|
|
///
|
||
|
|
/// This contains both the chosen pattern (if any) and detailed
|
||
|
|
/// diagnostic information about why other patterns were rejected.
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub struct RoutingDecision {
|
||
|
|
/// Selected pattern (None = Fail-Fast)
|
||
|
|
pub chosen: Option<LoopPatternKind>,
|
||
|
|
|
||
|
|
/// Missing capabilities that prevented other patterns
|
||
|
|
pub missing_caps: Vec<&'static str>,
|
||
|
|
|
||
|
|
/// Selection reasoning (for debugging)
|
||
|
|
pub notes: Vec<String>,
|
||
|
|
|
||
|
|
/// Error tags for contract_checks integration
|
||
|
|
pub error_tags: Vec<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Capability tags - standardized vocabulary for Fail-Fast reasons.
///
/// Each constant names a capability that some pattern depends on. A loop that
/// lacks the capability is reported with the matching tag, which explains why
/// that pattern could not lower it. Every tag's string value is identical to
/// its constant name.
pub mod capability_tags {
    /// Needed: the carrier update is a constant step (`i = i + const`).
    pub const CAP_MISSING_CONST_STEP: &str = "CAP_MISSING_CONST_STEP";

    /// Needed: the loop has at most one break point.
    pub const CAP_MISSING_SINGLE_BREAK: &str = "CAP_MISSING_SINGLE_BREAK";

    /// Needed: the loop has at most one continue point.
    pub const CAP_MISSING_SINGLE_CONTINUE: &str = "CAP_MISSING_SINGLE_CONTINUE";

    /// Needed: the loop header condition is free of side effects.
    pub const CAP_MISSING_PURE_HEADER: &str = "CAP_MISSING_PURE_HEADER";

    /// Needed: the condition variable is declared in the outer local scope.
    pub const CAP_MISSING_OUTER_LOCAL_COND: &str = "CAP_MISSING_OUTER_LOCAL_COND";

    /// Needed: every exit binding is complete (no missing values).
    pub const CAP_MISSING_EXIT_BINDINGS: &str = "CAP_MISSING_EXIT_BINDINGS";

    /// Needed: a LoopBodyLocal can be promoted to a carrier.
    pub const CAP_MISSING_CARRIER_PROMOTION: &str = "CAP_MISSING_CARRIER_PROMOTION";

    /// Needed: break value types agree across all break points.
    pub const CAP_MISSING_BREAK_VALUE_TYPE: &str = "CAP_MISSING_BREAK_VALUE_TYPE";
}
|
||
|
|
|
||
|
|
// ============================================================================
// Implementation Helpers
// ============================================================================
|
||
|
|
|
||
|
|
impl LoopSkeleton {
|
||
|
|
/// Create a new empty skeleton
|
||
|
|
pub fn new(span: Span) -> Self {
|
||
|
|
Self {
|
||
|
|
steps: Vec::new(),
|
||
|
|
carriers: Vec::new(),
|
||
|
|
exits: ExitContract::default(),
|
||
|
|
captured: None,
|
||
|
|
span,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Count the number of break checks in this skeleton
|
||
|
|
pub fn count_break_checks(&self) -> usize {
|
||
|
|
self.steps
|
||
|
|
.iter()
|
||
|
|
.filter(|s| matches!(s, SkeletonStep::BreakCheck { .. }))
|
||
|
|
.count()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Count the number of continue checks in this skeleton
|
||
|
|
pub fn count_continue_checks(&self) -> usize {
|
||
|
|
self.steps
|
||
|
|
.iter()
|
||
|
|
.filter(|s| matches!(s, SkeletonStep::ContinueCheck { .. }))
|
||
|
|
.count()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get all carrier names
|
||
|
|
pub fn carrier_names(&self) -> Vec<&str> {
|
||
|
|
self.carriers.iter().map(|c| c.name.as_str()).collect()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl ExitContract {
|
||
|
|
/// Create a contract with no exits
|
||
|
|
pub fn none() -> Self {
|
||
|
|
Self {
|
||
|
|
has_break: false,
|
||
|
|
has_continue: false,
|
||
|
|
has_return: false,
|
||
|
|
break_has_value: false,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check if any exit exists
|
||
|
|
pub fn has_any_exit(&self) -> bool {
|
||
|
|
self.has_break || self.has_continue || self.has_return
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for ExitContract {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self::none()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl RoutingDecision {
|
||
|
|
/// Create a successful routing decision
|
||
|
|
pub fn success(pattern: LoopPatternKind) -> Self {
|
||
|
|
Self {
|
||
|
|
chosen: Some(pattern),
|
||
|
|
missing_caps: Vec::new(),
|
||
|
|
notes: Vec::new(),
|
||
|
|
error_tags: Vec::new(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Create a failed routing decision (Fail-Fast)
|
||
|
|
pub fn fail_fast(missing_caps: Vec<&'static str>, reason: String) -> Self {
|
||
|
|
Self {
|
||
|
|
chosen: None,
|
||
|
|
missing_caps,
|
||
|
|
notes: vec![reason.clone()],
|
||
|
|
error_tags: vec![format!("[loop_canonicalizer/fail_fast] {}", reason)],
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Add a diagnostic note
|
||
|
|
pub fn add_note(&mut self, note: String) {
|
||
|
|
self.notes.push(note);
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check if routing succeeded
|
||
|
|
pub fn is_success(&self) -> bool {
|
||
|
|
self.chosen.is_some()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check if routing failed
|
||
|
|
pub fn is_fail_fast(&self) -> bool {
|
||
|
|
self.chosen.is_none()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
// Display Implementations
// ============================================================================
|
||
|
|
|
||
|
|
impl std::fmt::Display for CarrierRole {
|
||
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
|
match self {
|
||
|
|
CarrierRole::Counter => write!(f, "Counter"),
|
||
|
|
CarrierRole::Accumulator => write!(f, "Accumulator"),
|
||
|
|
CarrierRole::ConditionVar => write!(f, "ConditionVar"),
|
||
|
|
CarrierRole::Derived => write!(f, "Derived"),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    use crate::ast::LiteralValue;

    /// Builds a trivial `true` literal to use as a guard condition.
    fn true_literal() -> Box<ASTNode> {
        Box::new(ASTNode::Literal {
            value: LiteralValue::Bool(true),
            span: Span::unknown(),
        })
    }

    #[test]
    fn test_skeleton_creation() {
        let skeleton = LoopSkeleton::new(Span::unknown());
        assert_eq!(skeleton.steps.len(), 0);
        assert_eq!(skeleton.carriers.len(), 0);
        assert!(skeleton.captured.is_none());
        assert!(!skeleton.exits.has_any_exit());
    }

    #[test]
    fn test_exit_contract() {
        let mut contract = ExitContract::none();
        assert!(!contract.has_any_exit());

        contract.has_break = true;
        assert!(contract.has_any_exit());
    }

    #[test]
    fn test_exit_contract_each_exit_kind() {
        let mut with_continue = ExitContract::none();
        with_continue.has_continue = true;
        assert!(with_continue.has_any_exit());

        let mut with_return = ExitContract::none();
        with_return.has_return = true;
        assert!(with_return.has_any_exit());

        // The break payload flag on its own is not an exit kind.
        let mut payload_only = ExitContract::none();
        payload_only.break_has_value = true;
        assert!(!payload_only.has_any_exit());
    }

    #[test]
    fn test_exit_contract_default_is_none() {
        let contract = ExitContract::default();
        assert!(!contract.has_break);
        assert!(!contract.has_continue);
        assert!(!contract.has_return);
        assert!(!contract.break_has_value);
    }

    #[test]
    fn test_routing_decision() {
        let success = RoutingDecision::success(LoopPatternKind::Pattern1SimpleWhile);
        assert!(success.is_success());
        assert!(!success.is_fail_fast());
        assert!(matches!(
            success.chosen,
            Some(LoopPatternKind::Pattern1SimpleWhile)
        ));
        assert!(success.missing_caps.is_empty());
        assert!(success.notes.is_empty());
        assert!(success.error_tags.is_empty());

        let fail = RoutingDecision::fail_fast(
            vec![capability_tags::CAP_MISSING_CONST_STEP],
            "Test failure".to_string(),
        );
        assert!(!fail.is_success());
        assert!(fail.is_fail_fast());
        assert_eq!(fail.missing_caps.len(), 1);
        assert_eq!(
            fail.missing_caps,
            vec![capability_tags::CAP_MISSING_CONST_STEP]
        );
        assert_eq!(fail.notes, vec!["Test failure".to_string()]);
        assert_eq!(
            fail.error_tags,
            vec!["[loop_canonicalizer/fail_fast] Test failure".to_string()]
        );
    }

    #[test]
    fn test_routing_decision_add_note() {
        let mut decision = RoutingDecision::success(LoopPatternKind::Pattern1SimpleWhile);
        decision.add_note("first".to_string());
        decision.add_note("second".to_string());
        assert_eq!(
            decision.notes,
            vec!["first".to_string(), "second".to_string()]
        );
    }

    #[test]
    fn test_carrier_role_display() {
        assert_eq!(CarrierRole::Counter.to_string(), "Counter");
        assert_eq!(CarrierRole::Accumulator.to_string(), "Accumulator");
        assert_eq!(CarrierRole::ConditionVar.to_string(), "ConditionVar");
        assert_eq!(CarrierRole::Derived.to_string(), "Derived");
    }

    #[test]
    fn test_skeleton_count_helpers() {
        let mut skeleton = LoopSkeleton::new(Span::unknown());

        skeleton.steps.push(SkeletonStep::BreakCheck {
            cond: true_literal(),
            has_value: false,
        });

        skeleton.steps.push(SkeletonStep::ContinueCheck {
            cond: true_literal(),
        });

        assert_eq!(skeleton.count_break_checks(), 1);
        assert_eq!(skeleton.count_continue_checks(), 1);
    }

    #[test]
    fn test_skeleton_count_helpers_ignore_other_steps() {
        let mut skeleton = LoopSkeleton::new(Span::unknown());

        skeleton.steps.push(SkeletonStep::HeaderCond {
            expr: true_literal(),
        });
        skeleton.steps.push(SkeletonStep::BreakCheck {
            cond: true_literal(),
            has_value: true,
        });
        skeleton.steps.push(SkeletonStep::Update {
            carrier_name: "i".to_string(),
            update_kind: UpdateKind::ConstStep { delta: 1 },
        });
        skeleton.steps.push(SkeletonStep::Body { stmts: Vec::new() });
        skeleton.steps.push(SkeletonStep::BreakCheck {
            cond: true_literal(),
            has_value: false,
        });

        assert_eq!(skeleton.count_break_checks(), 2);
        assert_eq!(skeleton.count_continue_checks(), 0);
    }

    #[test]
    fn test_skeleton_carrier_names() {
        let mut skeleton = LoopSkeleton::new(Span::unknown());
        assert!(skeleton.carrier_names().is_empty());

        skeleton.carriers.push(CarrierSlot {
            name: "i".to_string(),
            role: CarrierRole::Counter,
            update_kind: UpdateKind::ConstStep { delta: 1 },
        });

        skeleton.carriers.push(CarrierSlot {
            name: "sum".to_string(),
            role: CarrierRole::Accumulator,
            update_kind: UpdateKind::Arbitrary,
        });

        let names = skeleton.carrier_names();
        assert_eq!(names, vec!["i", "sum"]);
    }
}
|