Files
hakorune/src/mir/join_ir/lowering/loop_update_analyzer.rs

214 lines
6.8 KiB
Rust
Raw Normal View History

//! Loop Update Expression Analyzer
//!
//! Phase 197: Extracts update expressions from loop body to generate
//! correct carrier update semantics.
//!
//! # Purpose
//!
//! The Pattern 4 lowerer needs to know how each carrier variable is updated
//! in the loop body. Instead of hardcoding "count uses +1, sum uses +i",
//! we extract the actual update expressions from the AST.
//!
//! # Example
//!
//! ```nyash
//! loop(i < 10) {
//! i = i + 1 // UpdateExpr::BinOp { lhs: "i", op: Add, rhs: Const(1) }
//! sum = sum + i // UpdateExpr::BinOp { lhs: "sum", op: Add, rhs: "i" }
//! count = count + 1 // UpdateExpr::BinOp { lhs: "count", op: Add, rhs: Const(1) }
//! }
//! ```
use crate::ast::{ASTNode, BinaryOperator, LiteralValue};
use crate::mir::join_ir::lowering::carrier_info::CarrierVar;
use crate::mir::join_ir::BinOpKind;
use std::collections::HashMap;
/// Update expression for a carrier variable
///
/// Describes how a carrier's new value is computed by an assignment found
/// in the loop body.
#[derive(Debug, Clone)]
pub enum UpdateExpr {
/// Constant increment: carrier = carrier + N
///
/// NOTE(review): nothing in this file constructs this variant; the analyzer
/// reports `count = count + 1` as `BinOp` with `UpdateRhs::Const(1)`.
/// Confirm whether other modules produce or consume `Const`.
Const(i64),
/// Binary operation: carrier = carrier op rhs
BinOp {
// Name of the left operand. The analyzer in this file only emits `BinOp`
// when the left operand is the carrier being assigned.
lhs: String,
// Operator converted from the AST operator (Add / Sub / Mul / Div).
op: BinOpKind,
// Right operand: an integer constant or a variable name.
rhs: UpdateRhs,
},
}
/// Right-hand side of update expression
///
/// Only the two operand shapes the analyzer recognizes are representable:
/// an integer literal or a plain variable reference.
///
/// `PartialEq`/`Eq` are derived so operands can be compared directly
/// (e.g. `rhs == UpdateRhs::Const(1)`) instead of being destructured.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpdateRhs {
    /// Integer literal operand, e.g. the `1` in `count = count + 1`.
    Const(i64),
    /// Variable operand identified by name, e.g. the `i` in `sum = sum + i`.
    Variable(String),
}
/// Namespace type for the loop update analysis.
///
/// Carries no data; all functionality is exposed as associated functions,
/// with `analyze_carrier_updates` as the public entry point.
pub struct LoopUpdateAnalyzer;
impl LoopUpdateAnalyzer {
    /// Collect the update expression of every carrier assigned in a loop body.
    ///
    /// Only direct statements of `body` are inspected. An assignment is
    /// recorded when its target is a plain variable named like one of
    /// `carriers` and its value has a recognized shape, for example:
    /// - `sum = sum + i` -> BinOp { lhs: "sum", op: Add, rhs: Variable("i") }
    /// - `count = count + 1` -> BinOp { lhs: "count", op: Add, rhs: Const(1) }
    ///
    /// # Parameters
    /// - `body`: Loop body AST nodes
    /// - `carriers`: Carrier variables to analyze
    ///
    /// # Returns
    /// Map from carrier name to UpdateExpr. Carriers without a recognized
    /// update are absent; when a carrier has several recognized updates, the
    /// last one in `body` is the one kept.
    pub fn analyze_carrier_updates(
        body: &[ASTNode],
        carriers: &[CarrierVar],
    ) -> HashMap<String, UpdateExpr> {
        let is_carrier =
            |candidate: &str| carriers.iter().any(|c| c.name.as_str() == candidate);

        let recognized = body.iter().filter_map(|stmt| match stmt {
            ASTNode::Assignment { target, value, .. } => {
                let assigned = Self::extract_variable_name(target)?;
                if !is_carrier(&assigned) {
                    return None;
                }
                let expr = Self::analyze_update_value(&assigned, value)?;
                Some((assigned, expr))
            }
            _ => None,
        });

        // `extend` inserts pairs in iteration order and replaces the value of
        // an existing key, so a later assignment overrides an earlier one.
        let mut updates = HashMap::new();
        updates.extend(recognized);
        updates
    }

    /// Return the name held by a `Variable` node, or `None` for any other node.
    fn extract_variable_name(node: &ASTNode) -> Option<String> {
        if let ASTNode::Variable { name, .. } = node {
            Some(name.clone())
        } else {
            None
        }
    }

    /// Interpret the assigned value of a carrier as an update expression.
    ///
    /// Succeeds only for a binary operation whose left operand is the carrier
    /// itself, whose operator is a supported arithmetic operator, and whose
    /// right operand is an integer literal or a variable:
    /// - `sum + i` -> BinOp { lhs: "sum", op: Add, rhs: Variable("i") }
    /// - `count + 1` -> BinOp { lhs: "count", op: Add, rhs: Const(1) }
    fn analyze_update_value(carrier_name: &str, value: &ASTNode) -> Option<UpdateExpr> {
        let (operator, left, right) = match value {
            ASTNode::BinaryOp {
                operator,
                left,
                right,
                ..
            } => (operator, left, right),
            _ => return None,
        };

        // The left operand must be the carrier (e.g. `sum` in `sum + i`).
        let lhs = Self::extract_variable_name(left)?;
        if lhs != carrier_name {
            return None;
        }

        let op = Self::convert_operator(operator)?;
        let rhs = Self::analyze_rhs(right)?;
        Some(UpdateExpr::BinOp { lhs, op, rhs })
    }

    /// Classify the right operand of an update as a constant or a variable.
    fn analyze_rhs(node: &ASTNode) -> Option<UpdateRhs> {
        let rhs = match node {
            // Integer literal, as in `count + 1`
            ASTNode::Literal {
                value: LiteralValue::Integer(n),
                ..
            } => UpdateRhs::Const(*n),
            // Variable reference, as in `sum + i`
            ASTNode::Variable { name, .. } => UpdateRhs::Variable(name.clone()),
            _ => return None,
        };
        Some(rhs)
    }

    /// Map an AST binary operator onto the matching MIR `BinOpKind`.
    ///
    /// Only the four arithmetic operators are supported for now; every other
    /// operator yields `None`.
    fn convert_operator(op: &BinaryOperator) -> Option<BinOpKind> {
        let kind = match op {
            BinaryOperator::Add => BinOpKind::Add,
            BinaryOperator::Subtract => BinOpKind::Sub,
            BinaryOperator::Multiply => BinOpKind::Mul,
            BinaryOperator::Divide => BinOpKind::Div,
            _ => return None,
        };
        Some(kind)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Span;

    /// Builds a variable reference node.
    fn var(name: &str) -> ASTNode {
        ASTNode::Variable {
            name: name.to_string(),
            span: Span::default(),
        }
    }

    /// Builds an integer literal node.
    fn int(n: i64) -> ASTNode {
        ASTNode::Literal {
            value: LiteralValue::Integer(n),
            span: Span::default(),
        }
    }

    /// Builds a binary operation node `left <operator> right`.
    fn binop(operator: BinaryOperator, left: ASTNode, right: ASTNode) -> ASTNode {
        ASTNode::BinaryOp {
            operator,
            left: Box::new(left),
            right: Box::new(right),
            span: Span::default(),
        }
    }

    /// Builds the statement `target = value`.
    fn assign(target: &str, value: ASTNode) -> ASTNode {
        ASTNode::Assignment {
            target: Box::new(var(target)),
            value: Box::new(value),
            span: Span::default(),
        }
    }

    /// Builds a carrier with a placeholder host id (the analyzer only reads the name).
    fn carrier(name: &str) -> CarrierVar {
        CarrierVar {
            name: name.to_string(),
            host_id: crate::mir::ValueId(0),
        }
    }

    #[test]
    fn test_analyze_simple_increment() {
        // Test case: count = count + 1
        let body = vec![assign(
            "count",
            binop(BinaryOperator::Add, var("count"), int(1)),
        )];
        let carriers = vec![carrier("count")];

        let updates = LoopUpdateAnalyzer::analyze_carrier_updates(&body, &carriers);

        assert_eq!(updates.len(), 1);
        assert!(updates.contains_key("count"));
        match updates.get("count") {
            Some(UpdateExpr::BinOp { lhs, op, rhs }) => {
                assert_eq!(lhs, "count");
                assert_eq!(*op, BinOpKind::Add);
                match rhs {
                    UpdateRhs::Const(n) => assert_eq!(*n, 1),
                    other => panic!("Expected Const(1), got {:?}", other),
                }
            }
            other => panic!("Expected BinOp, got {:?}", other),
        }
    }

    #[test]
    fn test_analyze_variable_rhs() {
        // Test case: sum = sum + i
        let body = vec![assign(
            "sum",
            binop(BinaryOperator::Add, var("sum"), var("i")),
        )];
        let carriers = vec![carrier("sum")];

        let updates = LoopUpdateAnalyzer::analyze_carrier_updates(&body, &carriers);

        assert_eq!(updates.len(), 1);
        match updates.get("sum") {
            Some(UpdateExpr::BinOp { lhs, op, rhs }) => {
                assert_eq!(lhs, "sum");
                assert_eq!(*op, BinOpKind::Add);
                match rhs {
                    UpdateRhs::Variable(name) => assert_eq!(name, "i"),
                    other => panic!("Expected Variable(\"i\"), got {:?}", other),
                }
            }
            other => panic!("Expected BinOp, got {:?}", other),
        }
    }

    #[test]
    fn test_analyze_subtract_operator() {
        // Test case: count = count - 2
        let body = vec![assign(
            "count",
            binop(BinaryOperator::Subtract, var("count"), int(2)),
        )];
        let carriers = vec![carrier("count")];

        let updates = LoopUpdateAnalyzer::analyze_carrier_updates(&body, &carriers);

        match updates.get("count") {
            Some(UpdateExpr::BinOp { lhs, op, rhs }) => {
                assert_eq!(lhs, "count");
                assert_eq!(*op, BinOpKind::Sub);
                match rhs {
                    UpdateRhs::Const(n) => assert_eq!(*n, 2),
                    other => panic!("Expected Const(2), got {:?}", other),
                }
            }
            other => panic!("Expected BinOp, got {:?}", other),
        }
    }

    #[test]
    fn test_non_carrier_assignment_is_ignored() {
        // Test case: x = x + 1, where only `count` is a carrier
        let body = vec![assign("x", binop(BinaryOperator::Add, var("x"), int(1)))];
        let carriers = vec![carrier("count")];

        let updates = LoopUpdateAnalyzer::analyze_carrier_updates(&body, &carriers);

        assert!(updates.is_empty());
    }

    #[test]
    fn test_non_binop_value_is_skipped() {
        // Test case: count = 0 (not an update of the form `carrier op rhs`)
        let body = vec![assign("count", int(0))];
        let carriers = vec![carrier("count")];

        let updates = LoopUpdateAnalyzer::analyze_carrier_updates(&body, &carriers);

        assert!(updates.is_empty());
    }
}