//! Loop Update Expression Analyzer //! //! Phase 197: Extracts update expressions from loop body to generate //! correct carrier update semantics. //! //! # Purpose //! //! The Pattern 4 lowerer needs to know how each carrier variable is updated //! in the loop body. Instead of hardcoding "count uses +1, sum uses +i", //! we extract the actual update expressions from the AST. //! //! # Example //! //! ```nyash //! loop(i < 10) { //! i = i + 1 // UpdateExpr::BinOp { lhs: "i", op: Add, rhs: Const(1) } //! sum = sum + i // UpdateExpr::BinOp { lhs: "sum", op: Add, rhs: "i" } //! count = count + 1 // UpdateExpr::BinOp { lhs: "count", op: Add, rhs: Const(1) } //! } //! ``` use crate::ast::{ASTNode, BinaryOperator, LiteralValue}; use crate::mir::join_ir::lowering::carrier_info::CarrierVar; use crate::mir::join_ir::BinOpKind; use std::collections::HashMap; /// Update expression for a carrier variable #[derive(Debug, Clone)] pub enum UpdateExpr { /// Constant increment: carrier = carrier + N Const(i64), /// Binary operation: carrier = carrier op rhs BinOp { lhs: String, op: BinOpKind, rhs: UpdateRhs, }, } /// Right-hand side of update expression #[derive(Debug, Clone)] pub enum UpdateRhs { Const(i64), Variable(String), } pub struct LoopUpdateAnalyzer; impl LoopUpdateAnalyzer { /// Analyze carrier update expressions from loop body /// /// Extracts update patterns like: /// - `sum = sum + i` → BinOp { lhs: "sum", op: Add, rhs: Variable("i") } /// - `count = count + 1` → BinOp { lhs: "count", op: Add, rhs: Const(1) } /// /// # Parameters /// - `body`: Loop body AST nodes /// - `carriers`: Carrier variables to analyze /// /// # Returns /// Map from carrier name to UpdateExpr pub fn analyze_carrier_updates( body: &[ASTNode], carriers: &[CarrierVar], ) -> HashMap { let mut updates = HashMap::new(); // Extract carrier names for quick lookup let carrier_names: Vec<&str> = carriers.iter().map(|c| c.name.as_str()).collect(); // Scan all statements in the loop body for node in body { if let ASTNode::Assignment { target, value, .. } = node { // Check if this is a carrier update (e.g., sum = sum + i) if let Some(target_name) = Self::extract_variable_name(target) { if carrier_names.contains(&target_name.as_str()) { // This is a carrier update, analyze the RHS if let Some(update_expr) = Self::analyze_update_value(&target_name, value) { updates.insert(target_name, update_expr); } } } } } updates } /// Extract variable name from AST node (for assignment target) fn extract_variable_name(node: &ASTNode) -> Option { match node { ASTNode::Variable { name, .. } => Some(name.clone()), _ => None, } } /// Analyze update value expression /// /// Recognizes patterns like: /// - `sum + i` → BinOp { lhs: "sum", op: Add, rhs: Variable("i") } /// - `count + 1` → BinOp { lhs: "count", op: Add, rhs: Const(1) } fn analyze_update_value(carrier_name: &str, value: &ASTNode) -> Option { match value { ASTNode::BinaryOp { operator, left, right, .. } => { // Check if LHS is the carrier itself (e.g., sum in "sum + i") if let Some(lhs_name) = Self::extract_variable_name(left) { if lhs_name == carrier_name { // Convert operator let op = Self::convert_operator(operator)?; // Analyze RHS let rhs = Self::analyze_rhs(right)?; return Some(UpdateExpr::BinOp { lhs: lhs_name, op, rhs, }); } } None } _ => None, } } /// Analyze right-hand side of update expression fn analyze_rhs(node: &ASTNode) -> Option { match node { // Constant: count + 1 ASTNode::Literal { value: LiteralValue::Integer(n), .. } => Some(UpdateRhs::Const(*n)), // Variable: sum + i ASTNode::Variable { name, .. } => { Some(UpdateRhs::Variable(name.clone())) } _ => None, } } /// Convert AST operator to MIR BinOpKind fn convert_operator(op: &BinaryOperator) -> Option { match op { BinaryOperator::Add => Some(BinOpKind::Add), BinaryOperator::Subtract => Some(BinOpKind::Sub), BinaryOperator::Multiply => Some(BinOpKind::Mul), BinaryOperator::Divide => Some(BinOpKind::Div), _ => None, // Only support arithmetic operators for now } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_analyze_simple_increment() { // Test case: count = count + 1 use crate::ast::Span; let body = vec![ASTNode::Assignment { target: Box::new(ASTNode::Variable { name: "count".to_string(), span: Span::default(), }), value: Box::new(ASTNode::BinaryOp { operator: BinaryOperator::Add, left: Box::new(ASTNode::Variable { name: "count".to_string(), span: Span::default(), }), right: Box::new(ASTNode::Literal { value: LiteralValue::Integer(1), span: Span::default(), }), span: Span::default(), }), span: Span::default(), }]; let carriers = vec![CarrierVar { name: "count".to_string(), host_id: crate::mir::ValueId(0), }]; let updates = LoopUpdateAnalyzer::analyze_carrier_updates(&body, &carriers); assert_eq!(updates.len(), 1); assert!(updates.contains_key("count")); if let Some(UpdateExpr::BinOp { lhs, op, rhs }) = updates.get("count") { assert_eq!(lhs, "count"); assert_eq!(*op, BinOpKind::Add); if let UpdateRhs::Const(n) = rhs { assert_eq!(*n, 1); } else { panic!("Expected Const(1), got {:?}", rhs); } } else { panic!("Expected BinOp, got {:?}", updates.get("count")); } } }