diff --git a/src/mir/loop_canonicalizer/mod.rs b/src/mir/loop_canonicalizer/mod.rs new file mode 100644 index 00000000..804d5c08 --- /dev/null +++ b/src/mir/loop_canonicalizer/mod.rs @@ -0,0 +1,416 @@ +//! Loop Canonicalizer - AST Level Loop Preprocessing +//! +//! Phase 1: Type Definitions Only +//! +//! ## Purpose +//! +//! Decomposes AST-level loops into a normalized "skeleton" representation +//! to prevent combinatorial explosion in pattern detection and lowering. +//! +//! ## Design Principle +//! +//! - **Input**: AST (LoopExpr) +//! - **Output**: LoopSkeleton only (no JoinIR generation) +//! - **Boundary**: No JoinIR-specific information (BlockId, ValueId, etc.) +//! +//! ## Architecture +//! +//! ``` +//! AST → LoopSkeleton → Capability Guard → RoutingDecision → Pattern Lowerer +//! ``` +//! +//! ## References +//! +//! - Design SSOT: `docs/development/current/main/design/loop-canonicalizer.md` +//! - JoinIR Architecture: `docs/development/current/main/joinir-architecture-overview.md` +//! - Pattern Space: `docs/development/current/main/loop_pattern_space.md` + +use crate::ast::{ASTNode, Span}; +use crate::mir::loop_pattern_detection::LoopPatternKind; + +// ============================================================================ +// Core Skeleton Types +// ============================================================================ + +/// Loop skeleton - The canonical representation of a loop structure +/// +/// This is the single output type of the Canonicalizer. +/// It represents the essential structure of a loop without any +/// JoinIR-specific information. +#[derive(Debug, Clone)] +pub struct LoopSkeleton { + /// Sequence of steps (HeaderCond, BodyInit, BreakCheck, Updates, Tail) + pub steps: Vec, + + /// Carriers (loop variables with update rules and boundary crossing contracts) + pub carriers: Vec, + + /// Exit contract (presence and payload of break/continue/return) + pub exits: ExitContract, + + /// Captured variables from outer scope (optional) + pub captured: Option>, + + /// Source location for debugging + pub span: Span, +} + +/// Skeleton step - Minimal step kinds for loop structure +/// +/// Each step represents a fundamental operation in the loop lifecycle. +#[derive(Debug, Clone)] +pub enum SkeletonStep { + /// Loop continuation condition (the `cond` in `loop(cond)`) + HeaderCond { + expr: Box, + }, + + /// Early exit check (`if cond { break }`) + BreakCheck { + cond: Box, + has_value: bool, + }, + + /// Skip check (`if cond { continue }`) + ContinueCheck { + cond: Box, + }, + + /// Carrier update (`i = i + 1`, etc.) + Update { + carrier_name: String, + update_kind: UpdateKind, + }, + + /// Loop body (all other statements) + Body { + stmts: Vec, + }, +} + +/// Update kind - How a carrier variable is updated +/// +/// This categorization helps determine which pattern can handle the loop. +#[derive(Debug, Clone)] +pub enum UpdateKind { + /// Constant step (`i = i + const`) + ConstStep { + delta: i64, + }, + + /// Conditional update (`if cond { x = a } else { x = b }`) + Conditional { + then_value: Box, + else_value: Box, + }, + + /// Arbitrary update (everything else) + Arbitrary, +} + +/// Exit contract - What kinds of exits the loop has +/// +/// This determines the exit line architecture needed. +#[derive(Debug, Clone)] +pub struct ExitContract { + pub has_break: bool, + pub has_continue: bool, + pub has_return: bool, + pub break_has_value: bool, +} + +/// Carrier slot - A loop variable with its role and update rule +/// +/// Carriers are variables that are updated in each iteration +/// and need to cross loop boundaries (via PHI nodes in MIR). +#[derive(Debug, Clone)] +pub struct CarrierSlot { + pub name: String, + pub role: CarrierRole, + pub update_kind: UpdateKind, +} + +/// Carrier role - The semantic role of a carrier variable +/// +/// This helps determine the appropriate pattern and PHI structure. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CarrierRole { + /// Loop counter (the `i` in `i < n`) + Counter, + + /// Accumulator (the `sum` in `sum += x`) + Accumulator, + + /// Condition variable (the `is_valid` in `while(is_valid)`) + ConditionVar, + + /// Derived value (e.g., `digit_pos` computed from other carriers) + Derived, +} + +/// Captured slot - An outer variable used within the loop +/// +/// These are read-only references to variables defined outside the loop. +/// (Write access would make them carriers instead.) +#[derive(Debug, Clone)] +pub struct CapturedSlot { + pub name: String, + pub is_mutable: bool, +} + +// ============================================================================ +// Capability Guard - Fail-Fast Reasons +// ============================================================================ + +/// Routing decision - The result of pattern selection +/// +/// This contains both the chosen pattern (if any) and detailed +/// diagnostic information about why other patterns were rejected. +#[derive(Debug, Clone)] +pub struct RoutingDecision { + /// Selected pattern (None = Fail-Fast) + pub chosen: Option, + + /// Missing capabilities that prevented other patterns + pub missing_caps: Vec<&'static str>, + + /// Selection reasoning (for debugging) + pub notes: Vec, + + /// Error tags for contract_checks integration + pub error_tags: Vec, +} + +/// Capability tags - Standardized vocabulary for Fail-Fast reasons +/// +/// These constants define the capabilities required by different patterns. +/// When a loop lacks a required capability, it uses the corresponding tag +/// to explain why it cannot be lowered by that pattern. +pub mod capability_tags { + /// Requires: Carrier update is constant step (`i = i + const`) + pub const CAP_MISSING_CONST_STEP: &str = "CAP_MISSING_CONST_STEP"; + + /// Requires: Single break point only + pub const CAP_MISSING_SINGLE_BREAK: &str = "CAP_MISSING_SINGLE_BREAK"; + + /// Requires: Single continue point only + pub const CAP_MISSING_SINGLE_CONTINUE: &str = "CAP_MISSING_SINGLE_CONTINUE"; + + /// Requires: Loop header condition has no side effects + pub const CAP_MISSING_PURE_HEADER: &str = "CAP_MISSING_PURE_HEADER"; + + /// Requires: Condition variable defined in outer local scope + pub const CAP_MISSING_OUTER_LOCAL_COND: &str = "CAP_MISSING_OUTER_LOCAL_COND"; + + /// Requires: All exit bindings are complete (no missing values) + pub const CAP_MISSING_EXIT_BINDINGS: &str = "CAP_MISSING_EXIT_BINDINGS"; + + /// Requires: LoopBodyLocal can be promoted to carrier + pub const CAP_MISSING_CARRIER_PROMOTION: &str = "CAP_MISSING_CARRIER_PROMOTION"; + + /// Requires: Break value types are consistent across all break points + pub const CAP_MISSING_BREAK_VALUE_TYPE: &str = "CAP_MISSING_BREAK_VALUE_TYPE"; +} + +// ============================================================================ +// Implementation Helpers +// ============================================================================ + +impl LoopSkeleton { + /// Create a new empty skeleton + pub fn new(span: Span) -> Self { + Self { + steps: Vec::new(), + carriers: Vec::new(), + exits: ExitContract::default(), + captured: None, + span, + } + } + + /// Count the number of break checks in this skeleton + pub fn count_break_checks(&self) -> usize { + self.steps + .iter() + .filter(|s| matches!(s, SkeletonStep::BreakCheck { .. })) + .count() + } + + /// Count the number of continue checks in this skeleton + pub fn count_continue_checks(&self) -> usize { + self.steps + .iter() + .filter(|s| matches!(s, SkeletonStep::ContinueCheck { .. })) + .count() + } + + /// Get all carrier names + pub fn carrier_names(&self) -> Vec<&str> { + self.carriers.iter().map(|c| c.name.as_str()).collect() + } +} + +impl ExitContract { + /// Create a contract with no exits + pub fn none() -> Self { + Self { + has_break: false, + has_continue: false, + has_return: false, + break_has_value: false, + } + } + + /// Check if any exit exists + pub fn has_any_exit(&self) -> bool { + self.has_break || self.has_continue || self.has_return + } +} + +impl Default for ExitContract { + fn default() -> Self { + Self::none() + } +} + +impl RoutingDecision { + /// Create a successful routing decision + pub fn success(pattern: LoopPatternKind) -> Self { + Self { + chosen: Some(pattern), + missing_caps: Vec::new(), + notes: Vec::new(), + error_tags: Vec::new(), + } + } + + /// Create a failed routing decision (Fail-Fast) + pub fn fail_fast(missing_caps: Vec<&'static str>, reason: String) -> Self { + Self { + chosen: None, + missing_caps, + notes: vec![reason.clone()], + error_tags: vec![format!("[loop_canonicalizer/fail_fast] {}", reason)], + } + } + + /// Add a diagnostic note + pub fn add_note(&mut self, note: String) { + self.notes.push(note); + } + + /// Check if routing succeeded + pub fn is_success(&self) -> bool { + self.chosen.is_some() + } + + /// Check if routing failed + pub fn is_fail_fast(&self) -> bool { + self.chosen.is_none() + } +} + +// ============================================================================ +// Display Implementations +// ============================================================================ + +impl std::fmt::Display for CarrierRole { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CarrierRole::Counter => write!(f, "Counter"), + CarrierRole::Accumulator => write!(f, "Accumulator"), + CarrierRole::ConditionVar => write!(f, "ConditionVar"), + CarrierRole::Derived => write!(f, "Derived"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_skeleton_creation() { + let skeleton = LoopSkeleton::new(Span::unknown()); + assert_eq!(skeleton.steps.len(), 0); + assert_eq!(skeleton.carriers.len(), 0); + assert!(!skeleton.exits.has_any_exit()); + } + + #[test] + fn test_exit_contract() { + let mut contract = ExitContract::none(); + assert!(!contract.has_any_exit()); + + contract.has_break = true; + assert!(contract.has_any_exit()); + } + + #[test] + fn test_routing_decision() { + let success = RoutingDecision::success(LoopPatternKind::Pattern1SimpleWhile); + assert!(success.is_success()); + assert!(!success.is_fail_fast()); + + let fail = RoutingDecision::fail_fast( + vec![capability_tags::CAP_MISSING_CONST_STEP], + "Test failure".to_string(), + ); + assert!(!fail.is_success()); + assert!(fail.is_fail_fast()); + assert_eq!(fail.missing_caps.len(), 1); + } + + #[test] + fn test_carrier_role_display() { + assert_eq!(CarrierRole::Counter.to_string(), "Counter"); + assert_eq!(CarrierRole::Accumulator.to_string(), "Accumulator"); + assert_eq!(CarrierRole::ConditionVar.to_string(), "ConditionVar"); + assert_eq!(CarrierRole::Derived.to_string(), "Derived"); + } + + #[test] + fn test_skeleton_count_helpers() { + use crate::ast::LiteralValue; + + let mut skeleton = LoopSkeleton::new(Span::unknown()); + + skeleton.steps.push(SkeletonStep::BreakCheck { + cond: Box::new(ASTNode::Literal { + value: LiteralValue::Bool(true), + span: Span::unknown(), + }), + has_value: false, + }); + + skeleton.steps.push(SkeletonStep::ContinueCheck { + cond: Box::new(ASTNode::Literal { + value: LiteralValue::Bool(true), + span: Span::unknown(), + }), + }); + + assert_eq!(skeleton.count_break_checks(), 1); + assert_eq!(skeleton.count_continue_checks(), 1); + } + + #[test] + fn test_skeleton_carrier_names() { + let mut skeleton = LoopSkeleton::new(Span::unknown()); + + skeleton.carriers.push(CarrierSlot { + name: "i".to_string(), + role: CarrierRole::Counter, + update_kind: UpdateKind::ConstStep { delta: 1 }, + }); + + skeleton.carriers.push(CarrierSlot { + name: "sum".to_string(), + role: CarrierRole::Accumulator, + update_kind: UpdateKind::Arbitrary, + }); + + let names = skeleton.carrier_names(); + assert_eq!(names, vec!["i", "sum"]); + } +} diff --git a/src/mir/mod.rs b/src/mir/mod.rs index d5c449a0..3ca35b2f 100644 --- a/src/mir/mod.rs +++ b/src/mir/mod.rs @@ -18,6 +18,7 @@ pub mod instruction; pub mod instruction_introspection; // Introspection helpers for tests (instruction names) pub mod instruction_kinds; // small kind-specific metadata (Const/BinOp) pub mod loop_api; // Minimal LoopBuilder facade (adapter-ready) +pub mod loop_canonicalizer; // Phase 1: Loop skeleton canonicalization (AST preprocessing) pub mod naming; // Static box / entry naming rules(NamingBox) pub mod optimizer; pub mod ssot; // Shared helpers (SSOT) for instruction lowering