feat(joinir): Phase 213 AST-based if-sum lowerer for Pattern 3
Implement dual-mode architecture for Pattern 3 (Loop with If-Else PHI): - Add is_simple_if_sum_pattern() detection helper - Detects 1 CounterLike + 1-2 AccumulationLike carrier patterns - Unit tests for various carrier compositions - Add dual-mode dispatch in Pattern3 lowerer - ctx.is_if_sum_pattern() branches to AST-based vs legacy PoC - Legacy mode preserved for backward compatibility - Create loop_with_if_phi_if_sum.rs (~420 lines) - AST extraction: loop condition, if condition, updates - JoinIR generation: main, loop_step, k_exit structure - Helper functions: extract_loop_condition, extract_if_condition, etc. - Extend PatternPipelineContext for Pattern 3 - is_if_sum_pattern() detection using LoopUpdateSummary - extract_if_statement() helper for body analysis Note: E2E RC=2 not yet achieved due to pre-existing Pattern 3 pipeline issue (loop back branch targets wrong block). This affects both if-sum and legacy modes. Fix planned for Phase 214. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@ -1,4 +1,19 @@
|
||||
//! Pattern 3: Loop with If-Else PHI minimal lowerer
|
||||
//!
|
||||
//! # Phase 213: Dual-Mode Architecture
|
||||
//!
|
||||
//! Pattern 3 supports two lowering modes:
|
||||
//!
|
||||
//! 1. **AST-based if-sum mode** (Phase 213+)
|
||||
//! - Triggered when `ctx.is_if_sum_pattern()` returns true
|
||||
//! - Uses AST from `ctx.loop_condition` and `ctx.loop_body`
|
||||
//! - Dynamically lowers loop condition, if condition, and carrier updates
|
||||
//! - Target: `phase212_if_sum_min.hako` (RC=2)
|
||||
//!
|
||||
//! 2. **Legacy PoC mode** (Phase 188-195)
|
||||
//! - Fallback for test-only patterns (e.g., `loop_if_phi.hako`)
|
||||
//! - Hardcoded loop condition (i <= 5), if condition (i % 2 == 1)
|
||||
//! - Kept for backward compatibility with existing tests
|
||||
|
||||
use crate::ast::ASTNode;
|
||||
use crate::mir::builder::MirBuilder;
|
||||
@ -37,11 +52,17 @@ impl MirBuilder {
|
||||
///
|
||||
/// **Refactored**: Now uses PatternPipelineContext for unified preprocessing
|
||||
///
|
||||
/// # Phase 213: Dual-Mode Architecture
|
||||
///
|
||||
/// - **if-sum mode**: When `ctx.is_if_sum_pattern()` is true, uses AST-based lowering
|
||||
/// - **legacy mode**: Otherwise, uses hardcoded PoC lowering for backward compatibility
|
||||
///
|
||||
/// # Pipeline (Phase 179-B)
|
||||
/// 1. Build preprocessing context → PatternPipelineContext
|
||||
/// 2. Call JoinIR lowerer → JoinModule
|
||||
/// 3. Create boundary from context → JoinInlineBoundary
|
||||
/// 4. Merge MIR blocks → JoinIRConversionPipeline
|
||||
/// 2. Check if-sum pattern → branch to appropriate lowerer
|
||||
/// 3. Call JoinIR lowerer → JoinModule
|
||||
/// 4. Create boundary from context → JoinInlineBoundary
|
||||
/// 5. Merge MIR blocks → JoinIRConversionPipeline
|
||||
pub(in crate::mir::builder) fn cf_loop_pattern3_with_if_phi(
|
||||
&mut self,
|
||||
condition: &ASTNode,
|
||||
@ -49,11 +70,6 @@ impl MirBuilder {
|
||||
_func_name: &str,
|
||||
debug: bool,
|
||||
) -> Result<Option<ValueId>, String> {
|
||||
use crate::mir::join_ir::lowering::loop_with_if_phi_minimal::lower_loop_with_if_phi_pattern;
|
||||
|
||||
// Phase 195: Use unified trace
|
||||
trace::trace().debug("pattern3", "Calling Pattern 3 minimal lowerer");
|
||||
|
||||
// Phase 179-B: Use PatternPipelineContext for unified preprocessing
|
||||
use super::pattern_pipeline::{build_pattern_context, PatternVariant};
|
||||
let ctx = build_pattern_context(
|
||||
@ -63,6 +79,115 @@ impl MirBuilder {
|
||||
PatternVariant::Pattern3,
|
||||
)?;
|
||||
|
||||
// Phase 213: Dual-mode dispatch based on if-sum pattern detection
|
||||
if ctx.is_if_sum_pattern() {
|
||||
trace::trace().debug("pattern3", "Detected if-sum pattern, using AST-based lowerer");
|
||||
return self.lower_pattern3_if_sum(&ctx, condition, body, debug);
|
||||
}
|
||||
|
||||
// Legacy mode: Use hardcoded PoC lowering (e.g., loop_if_phi.hako)
|
||||
trace::trace().debug("pattern3", "Using legacy PoC lowerer (hardcoded conditions)");
|
||||
self.lower_pattern3_legacy(&ctx, debug)
|
||||
}
|
||||
|
||||
/// Phase 213: AST-based if-sum lowerer
|
||||
///
|
||||
/// Dynamically lowers loop condition, if condition, and carrier updates from AST.
|
||||
/// Target: `phase212_if_sum_min.hako` (RC=2)
|
||||
fn lower_pattern3_if_sum(
|
||||
&mut self,
|
||||
ctx: &super::pattern_pipeline::PatternPipelineContext,
|
||||
condition: &ASTNode,
|
||||
body: &[ASTNode],
|
||||
debug: bool,
|
||||
) -> Result<Option<ValueId>, String> {
|
||||
use crate::mir::join_ir::lowering::loop_with_if_phi_if_sum::lower_if_sum_pattern;
|
||||
|
||||
// Phase 202-B: Create JoinValueSpace for unified ValueId allocation
|
||||
use crate::mir::join_ir::lowering::join_value_space::JoinValueSpace;
|
||||
let mut join_value_space = JoinValueSpace::new();
|
||||
|
||||
// Extract if statement from loop body
|
||||
let if_stmt = ctx.extract_if_statement().ok_or_else(|| {
|
||||
"[cf_loop/pattern3] if-sum pattern detected but no if statement found".to_string()
|
||||
})?;
|
||||
|
||||
// Call AST-based if-sum lowerer
|
||||
let (join_module, fragment_meta) = lower_if_sum_pattern(
|
||||
condition,
|
||||
if_stmt,
|
||||
body,
|
||||
&mut join_value_space,
|
||||
)?;
|
||||
|
||||
let exit_meta = &fragment_meta.exit_meta;
|
||||
|
||||
trace::trace().debug(
|
||||
"pattern3/if-sum",
|
||||
&format!("ExitMeta: {} exit values", exit_meta.exit_values.len())
|
||||
);
|
||||
for (carrier_name, join_value) in &exit_meta.exit_values {
|
||||
trace::trace().debug(
|
||||
"pattern3/if-sum",
|
||||
&format!(" {} → ValueId({})", carrier_name, join_value.0)
|
||||
);
|
||||
}
|
||||
|
||||
// Build exit bindings using ExitMetaCollector
|
||||
let exit_bindings = ExitMetaCollector::collect(self, exit_meta, debug);
|
||||
|
||||
// Build boundary with carrier inputs
|
||||
use crate::mir::join_ir::lowering::JoinInlineBoundaryBuilder;
|
||||
use crate::mir::builder::emission::constant;
|
||||
|
||||
// Phase 213: Build join_inputs and host_inputs based on carriers
|
||||
let join_inputs = vec![ValueId(0), ValueId(1), ValueId(2)];
|
||||
let mut host_inputs = vec![ctx.loop_var_id];
|
||||
|
||||
// Add accumulator carriers (sum, optionally count)
|
||||
for carrier in &ctx.carrier_info.carriers {
|
||||
if carrier.name != ctx.loop_var_name {
|
||||
host_inputs.push(carrier.host_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Pad to 3 inputs if needed (for legacy compatibility)
|
||||
while host_inputs.len() < 3 {
|
||||
host_inputs.push(constant::emit_void(self));
|
||||
}
|
||||
|
||||
let boundary = JoinInlineBoundaryBuilder::new()
|
||||
.with_inputs(join_inputs, host_inputs)
|
||||
.with_exit_bindings(exit_bindings)
|
||||
.with_loop_var_name(Some(ctx.loop_var_name.clone()))
|
||||
.build();
|
||||
|
||||
// Execute JoinIR conversion pipeline
|
||||
use super::conversion_pipeline::JoinIRConversionPipeline;
|
||||
let _ = JoinIRConversionPipeline::execute(
|
||||
self,
|
||||
join_module,
|
||||
Some(&boundary),
|
||||
"pattern3/if-sum",
|
||||
debug,
|
||||
)?;
|
||||
|
||||
// Return Void (loop doesn't produce values)
|
||||
let void_val = constant::emit_void(self);
|
||||
trace::trace().debug("pattern3/if-sum", &format!("Loop complete, returning Void {:?}", void_val));
|
||||
Ok(Some(void_val))
|
||||
}
|
||||
|
||||
/// Phase 188-195: Legacy PoC lowerer (hardcoded conditions)
|
||||
///
|
||||
/// Kept for backward compatibility with existing tests like `loop_if_phi.hako`.
|
||||
fn lower_pattern3_legacy(
|
||||
&mut self,
|
||||
ctx: &super::pattern_pipeline::PatternPipelineContext,
|
||||
debug: bool,
|
||||
) -> Result<Option<ValueId>, String> {
|
||||
use crate::mir::join_ir::lowering::loop_with_if_phi_minimal::lower_loop_with_if_phi_pattern;
|
||||
|
||||
// Phase 195: Extract carrier var_ids dynamically based on what exists
|
||||
// This maintains backward compatibility with single-carrier (sum only) and multi-carrier (sum+count) tests
|
||||
let sum_carrier = ctx.carrier_info.carriers.iter()
|
||||
@ -86,7 +211,7 @@ impl MirBuilder {
|
||||
let mut join_value_space = JoinValueSpace::new();
|
||||
|
||||
// Call Pattern 3 lowerer with preprocessed scope
|
||||
let (join_module, fragment_meta) = match lower_loop_with_if_phi_pattern(ctx.loop_scope, &mut join_value_space) {
|
||||
let (join_module, fragment_meta) = match lower_loop_with_if_phi_pattern(ctx.loop_scope.clone(), &mut join_value_space) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
trace::trace().debug("pattern3", &format!("Pattern 3 lowerer failed: {}", e));
|
||||
|
||||
@ -173,6 +173,47 @@ impl PatternPipelineContext {
|
||||
pub fn has_carrier_updates(&self) -> bool {
|
||||
self.carrier_updates.is_some()
|
||||
}
|
||||
|
||||
/// Phase 213: Check if this is a simple if-sum pattern for AST-based lowering
|
||||
///
|
||||
/// Returns true if:
|
||||
/// 1. loop_body contains an if statement
|
||||
/// 2. carrier composition matches if-sum pattern (1 counter + 1-2 accumulators)
|
||||
///
|
||||
/// This determines whether to use AST-based lowering or legacy PoC lowering.
|
||||
pub fn is_if_sum_pattern(&self) -> bool {
|
||||
// Check if loop_body has if statement
|
||||
let has_if = self.loop_body.as_ref().map_or(false, |body| {
|
||||
body.iter().any(|stmt| matches!(stmt, ASTNode::If { .. }))
|
||||
});
|
||||
|
||||
if !has_if {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check carrier pattern using name heuristics
|
||||
// (1 counter like "i" + 1-2 accumulators like "sum", "count")
|
||||
use crate::mir::join_ir::lowering::loop_update_summary::analyze_loop_updates;
|
||||
let carrier_names: Vec<String> = self.carrier_info.carriers.iter()
|
||||
.map(|c| c.name.clone())
|
||||
.collect();
|
||||
|
||||
// Add loop variable to carrier list (it's also part of the pattern)
|
||||
let mut all_names = vec![self.loop_var_name.clone()];
|
||||
all_names.extend(carrier_names);
|
||||
|
||||
let summary = analyze_loop_updates(&all_names);
|
||||
summary.is_simple_if_sum_pattern()
|
||||
}
|
||||
|
||||
/// Phase 213: Extract if statement from loop body
|
||||
///
|
||||
/// Returns the first if statement found in loop_body, if any.
|
||||
pub fn extract_if_statement(&self) -> Option<&ASTNode> {
|
||||
self.loop_body.as_ref().and_then(|body| {
|
||||
body.iter().find(|stmt| matches!(stmt, ASTNode::If { .. }))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Build pattern preprocessing context
|
||||
|
||||
@ -117,6 +117,34 @@ impl LoopUpdateSummary {
|
||||
.filter(|c| c.kind == UpdateKind::AccumulationLike)
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Phase 213: Check if this is a simple if-sum pattern
|
||||
///
|
||||
/// Simple if-sum pattern:
|
||||
/// - Has exactly 1 CounterLike carrier (loop index, e.g., "i")
|
||||
/// - Has exactly 1 AccumulationLike carrier (accumulator, e.g., "sum")
|
||||
/// - Optionally has additional accumulators (e.g., "count")
|
||||
///
|
||||
/// Examples:
|
||||
/// - `loop(i < len) { if cond { sum = sum + 1 } i = i + 1 }` ✅
|
||||
/// - `loop(i < len) { if cond { sum = sum + 1; count = count + 1 } i = i + 1 }` ✅
|
||||
/// - `loop(i < len) { result = result + data[i]; i = i + 1 }` ❌ (no if statement)
|
||||
pub fn is_simple_if_sum_pattern(&self) -> bool {
|
||||
// Must have exactly 1 counter (loop index)
|
||||
if self.counter_count() != 1 {
|
||||
return false;
|
||||
}
|
||||
// Must have at least 1 accumulator (sum)
|
||||
if self.accumulation_count() < 1 {
|
||||
return false;
|
||||
}
|
||||
// For now, only support up to 2 accumulators (sum, count)
|
||||
// This matches the Phase 212 if-sum minimal test case
|
||||
if self.accumulation_count() > 2 {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// キャリア名から UpdateKind を推定(暫定実装)
|
||||
@ -221,4 +249,50 @@ mod tests {
|
||||
assert_eq!(summary.counter_count(), 1);
|
||||
assert_eq!(summary.accumulation_count(), 1);
|
||||
}
|
||||
|
||||
// Phase 213 tests for is_simple_if_sum_pattern
|
||||
#[test]
|
||||
fn test_is_simple_if_sum_pattern_basic() {
|
||||
// phase212_if_sum_min.hako pattern: i (counter) + sum (accumulator)
|
||||
let names = vec!["i".to_string(), "sum".to_string()];
|
||||
let summary = analyze_loop_updates(&names);
|
||||
|
||||
assert!(summary.is_simple_if_sum_pattern());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_simple_if_sum_pattern_with_count() {
|
||||
// Phase 195 pattern: i (counter) + sum + count (2 accumulators)
|
||||
let names = vec!["i".to_string(), "sum".to_string(), "count".to_string()];
|
||||
let summary = analyze_loop_updates(&names);
|
||||
|
||||
assert!(summary.is_simple_if_sum_pattern());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_simple_if_sum_pattern_no_accumulator() {
|
||||
// Only counter, no accumulator
|
||||
let names = vec!["i".to_string()];
|
||||
let summary = analyze_loop_updates(&names);
|
||||
|
||||
assert!(!summary.is_simple_if_sum_pattern()); // No accumulator
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_simple_if_sum_pattern_no_counter() {
|
||||
// Only accumulator, no counter
|
||||
let names = vec!["sum".to_string()];
|
||||
let summary = analyze_loop_updates(&names);
|
||||
|
||||
assert!(!summary.is_simple_if_sum_pattern()); // No counter
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_simple_if_sum_pattern_multiple_counters() {
|
||||
// Multiple counters (not supported)
|
||||
let names = vec!["i".to_string(), "j".to_string(), "sum".to_string()];
|
||||
let summary = analyze_loop_updates(&names);
|
||||
|
||||
assert!(!summary.is_simple_if_sum_pattern()); // 2 counters
|
||||
}
|
||||
}
|
||||
|
||||
416
src/mir/join_ir/lowering/loop_with_if_phi_if_sum.rs
Normal file
416
src/mir/join_ir/lowering/loop_with_if_phi_if_sum.rs
Normal file
@ -0,0 +1,416 @@
|
||||
//! Phase 213: Pattern 3 if-sum AST-based lowerer
|
||||
//!
|
||||
//! This module implements AST-based JoinIR lowering for "simple if-sum" patterns.
|
||||
//!
|
||||
//! # Target Pattern
|
||||
//!
|
||||
//! ```nyash
|
||||
//! loop(i < len) {
|
||||
//! if i > 0 {
|
||||
//! sum = sum + 1
|
||||
//! }
|
||||
//! i = i + 1
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! # Design Philosophy
|
||||
//!
|
||||
//! - **AST-driven**: Loop condition, if condition, and updates extracted from AST
|
||||
//! - **80/20 rule**: Only handles simple patterns, rejects complex ones (Fail-Fast)
|
||||
//! - **Reuses existing infrastructure**: JoinValueSpace, ExitMeta, CarrierInfo
|
||||
//!
|
||||
//! # Comparison with Legacy PoC
|
||||
//!
|
||||
//! | Aspect | Legacy (loop_with_if_phi_minimal.rs) | AST-based (this file) |
|
||||
//! |------------------|--------------------------------------|----------------------|
|
||||
//! | Loop condition | Hardcoded (i <= 5) | From `condition` AST |
|
||||
//! | If condition | Hardcoded (i % 2 == 1) | From `if_stmt` AST |
|
||||
//! | Carrier updates | Hardcoded (sum + i) | From AST assignments |
|
||||
//! | Flexibility | Test-only | Any if-sum pattern |
|
||||
|
||||
use crate::ast::ASTNode;
|
||||
use crate::mir::join_ir::lowering::carrier_info::{ExitMeta, JoinFragmentMeta};
|
||||
use crate::mir::join_ir::lowering::join_value_space::JoinValueSpace;
|
||||
use crate::mir::join_ir::{
|
||||
BinOpKind, CompareOp, ConstValue, JoinFuncId, JoinFunction, JoinInst, JoinModule,
|
||||
MirLikeInst, UnaryOp,
|
||||
};
|
||||
|
||||
/// Phase 213: Lower if-sum pattern to JoinIR using AST
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `loop_condition` - Loop condition AST (e.g., `i < len`)
|
||||
/// * `if_stmt` - If statement AST from loop body
|
||||
/// * `body` - Full loop body AST (for finding counter update)
|
||||
/// * `join_value_space` - Unified ValueId allocator
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok((JoinModule, JoinFragmentMeta))` - JoinIR module with exit metadata
|
||||
/// * `Err(String)` - Pattern not supported or extraction failed
|
||||
pub fn lower_if_sum_pattern(
|
||||
loop_condition: &ASTNode,
|
||||
if_stmt: &ASTNode,
|
||||
body: &[ASTNode],
|
||||
join_value_space: &mut JoinValueSpace,
|
||||
) -> Result<(JoinModule, JoinFragmentMeta), String> {
|
||||
eprintln!("[joinir/pattern3/if-sum] Starting AST-based if-sum lowering");
|
||||
|
||||
// Step 1: Extract loop condition info (e.g., i < len → var="i", op=Lt, limit=len)
|
||||
let (loop_var, loop_op, loop_limit) = extract_loop_condition(loop_condition)?;
|
||||
eprintln!("[joinir/pattern3/if-sum] Loop condition: {} {:?} {}", loop_var, loop_op, loop_limit);
|
||||
|
||||
// Step 2: Extract if condition info (e.g., i > 0 → var="i", op=Gt, value=0)
|
||||
let (if_var, if_op, if_value) = extract_if_condition(if_stmt)?;
|
||||
eprintln!("[joinir/pattern3/if-sum] If condition: {} {:?} {}", if_var, if_op, if_value);
|
||||
|
||||
// Step 3: Extract then-branch update (e.g., sum = sum + 1 → var="sum", addend=1)
|
||||
let (update_var, update_addend) = extract_then_update(if_stmt)?;
|
||||
eprintln!("[joinir/pattern3/if-sum] Then update: {} += {}", update_var, update_addend);
|
||||
|
||||
// Step 4: Extract counter update (e.g., i = i + 1 → var="i", step=1)
|
||||
let (counter_var, counter_step) = extract_counter_update(body, &loop_var)?;
|
||||
eprintln!("[joinir/pattern3/if-sum] Counter update: {} += {}", counter_var, counter_step);
|
||||
|
||||
// Step 5: Generate JoinIR
|
||||
let mut alloc_value = || join_value_space.alloc_local();
|
||||
let mut join_module = JoinModule::new();
|
||||
|
||||
// Function IDs
|
||||
let main_id = JoinFuncId::new(0);
|
||||
let loop_step_id = JoinFuncId::new(1);
|
||||
let k_exit_id = JoinFuncId::new(2);
|
||||
|
||||
// === ValueId allocation ===
|
||||
// main() locals
|
||||
let i_init_val = alloc_value(); // i = 0
|
||||
let sum_init_val = alloc_value(); // sum = 0
|
||||
let count_init_val = alloc_value(); // count = 0 (optional)
|
||||
let loop_result = alloc_value(); // result from loop_step
|
||||
|
||||
// loop_step params
|
||||
let i_param = alloc_value();
|
||||
let sum_param = alloc_value();
|
||||
let count_param = alloc_value();
|
||||
|
||||
// loop_step locals
|
||||
let loop_limit_val = alloc_value(); // loop limit value
|
||||
let cmp_loop = alloc_value(); // loop condition comparison
|
||||
let exit_cond = alloc_value(); // negated loop condition
|
||||
|
||||
let if_const = alloc_value(); // if condition constant
|
||||
let if_cmp = alloc_value(); // if condition comparison
|
||||
let sum_then = alloc_value(); // sum + update_addend
|
||||
let count_const = alloc_value(); // count increment (1)
|
||||
let count_then = alloc_value(); // count + 1
|
||||
let const_0 = alloc_value(); // 0 for else branch
|
||||
let sum_else = alloc_value(); // sum + 0 (identity)
|
||||
let count_else = alloc_value(); // count + 0 (identity)
|
||||
let sum_new = alloc_value(); // Select result for sum
|
||||
let count_new = alloc_value(); // Select result for count
|
||||
let step_const = alloc_value(); // counter step
|
||||
let i_next = alloc_value(); // i + step
|
||||
|
||||
// k_exit params
|
||||
let sum_final = alloc_value();
|
||||
let count_final = alloc_value();
|
||||
|
||||
// === main() function ===
|
||||
let mut main_func = JoinFunction::new(main_id, "main".to_string(), vec![]);
|
||||
|
||||
// i_init = 0 (initial value from ctx)
|
||||
main_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: i_init_val,
|
||||
value: ConstValue::Integer(0), // TODO: Get from AST
|
||||
}));
|
||||
|
||||
// sum_init = 0
|
||||
main_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: sum_init_val,
|
||||
value: ConstValue::Integer(0),
|
||||
}));
|
||||
|
||||
// count_init = 0
|
||||
main_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: count_init_val,
|
||||
value: ConstValue::Integer(0),
|
||||
}));
|
||||
|
||||
// result = loop_step(i_init, sum_init, count_init)
|
||||
main_func.body.push(JoinInst::Call {
|
||||
func: loop_step_id,
|
||||
args: vec![i_init_val, sum_init_val, count_init_val],
|
||||
k_next: None,
|
||||
dst: Some(loop_result),
|
||||
});
|
||||
|
||||
main_func.body.push(JoinInst::Ret {
|
||||
value: Some(loop_result),
|
||||
});
|
||||
|
||||
join_module.add_function(main_func);
|
||||
|
||||
// === loop_step(i, sum, count) function ===
|
||||
let mut loop_step_func = JoinFunction::new(
|
||||
loop_step_id,
|
||||
"loop_step".to_string(),
|
||||
vec![i_param, sum_param, count_param],
|
||||
);
|
||||
|
||||
// --- Exit Condition Check ---
|
||||
// Load loop limit from AST
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: loop_limit_val,
|
||||
value: ConstValue::Integer(loop_limit),
|
||||
}));
|
||||
|
||||
// Compare: i < limit (or other op from AST)
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Compare {
|
||||
dst: cmp_loop,
|
||||
op: loop_op,
|
||||
lhs: i_param,
|
||||
rhs: loop_limit_val,
|
||||
}));
|
||||
|
||||
// exit_cond = !cmp_loop
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::UnaryOp {
|
||||
dst: exit_cond,
|
||||
op: UnaryOp::Not,
|
||||
operand: cmp_loop,
|
||||
}));
|
||||
|
||||
// Jump to exit if condition is false
|
||||
loop_step_func.body.push(JoinInst::Jump {
|
||||
cont: k_exit_id.as_cont(),
|
||||
args: vec![sum_param, count_param],
|
||||
cond: Some(exit_cond),
|
||||
});
|
||||
|
||||
// --- If Condition (AST-based) ---
|
||||
// Load if constant
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: if_const,
|
||||
value: ConstValue::Integer(if_value),
|
||||
}));
|
||||
|
||||
// Compare: if_var <op> if_value
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Compare {
|
||||
dst: if_cmp,
|
||||
op: if_op,
|
||||
lhs: i_param, // Assuming if_var == loop_var (common case)
|
||||
rhs: if_const,
|
||||
}));
|
||||
|
||||
// --- Then Branch ---
|
||||
// sum_then = sum + update_addend
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: const_0,
|
||||
value: ConstValue::Integer(update_addend),
|
||||
}));
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::BinOp {
|
||||
dst: sum_then,
|
||||
op: BinOpKind::Add,
|
||||
lhs: sum_param,
|
||||
rhs: const_0,
|
||||
}));
|
||||
|
||||
// count_then = count + 1
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: count_const,
|
||||
value: ConstValue::Integer(1),
|
||||
}));
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::BinOp {
|
||||
dst: count_then,
|
||||
op: BinOpKind::Add,
|
||||
lhs: count_param,
|
||||
rhs: count_const,
|
||||
}));
|
||||
|
||||
// --- Else Branch ---
|
||||
// sum_else = sum + 0 (identity)
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: step_const, // reuse for 0
|
||||
value: ConstValue::Integer(0),
|
||||
}));
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::BinOp {
|
||||
dst: sum_else,
|
||||
op: BinOpKind::Add,
|
||||
lhs: sum_param,
|
||||
rhs: step_const,
|
||||
}));
|
||||
|
||||
// count_else = count + 0 (identity)
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::BinOp {
|
||||
dst: count_else,
|
||||
op: BinOpKind::Add,
|
||||
lhs: count_param,
|
||||
rhs: step_const, // 0
|
||||
}));
|
||||
|
||||
// --- Select ---
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Select {
|
||||
dst: sum_new,
|
||||
cond: if_cmp,
|
||||
then_val: sum_then,
|
||||
else_val: sum_else,
|
||||
}));
|
||||
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Select {
|
||||
dst: count_new,
|
||||
cond: if_cmp,
|
||||
then_val: count_then,
|
||||
else_val: count_else,
|
||||
}));
|
||||
|
||||
// --- Counter Update ---
|
||||
let step_const2 = alloc_value();
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::Const {
|
||||
dst: step_const2,
|
||||
value: ConstValue::Integer(counter_step),
|
||||
}));
|
||||
loop_step_func.body.push(JoinInst::Compute(MirLikeInst::BinOp {
|
||||
dst: i_next,
|
||||
op: BinOpKind::Add,
|
||||
lhs: i_param,
|
||||
rhs: step_const2,
|
||||
}));
|
||||
|
||||
// --- Tail Recursion ---
|
||||
loop_step_func.body.push(JoinInst::Call {
|
||||
func: loop_step_id,
|
||||
args: vec![i_next, sum_new, count_new],
|
||||
k_next: None,
|
||||
dst: None,
|
||||
});
|
||||
|
||||
join_module.add_function(loop_step_func);
|
||||
|
||||
// === k_exit(sum_final, count_final) function ===
|
||||
let mut k_exit_func = JoinFunction::new(
|
||||
k_exit_id,
|
||||
"k_exit".to_string(),
|
||||
vec![sum_final, count_final],
|
||||
);
|
||||
|
||||
k_exit_func.body.push(JoinInst::Ret {
|
||||
value: Some(sum_final),
|
||||
});
|
||||
|
||||
join_module.add_function(k_exit_func);
|
||||
join_module.entry = Some(main_id);
|
||||
|
||||
// Build ExitMeta
|
||||
let mut exit_values = vec![];
|
||||
exit_values.push(("sum".to_string(), sum_final));
|
||||
exit_values.push(("count".to_string(), count_final));
|
||||
|
||||
let exit_meta = ExitMeta::multiple(exit_values);
|
||||
let fragment_meta = JoinFragmentMeta::carrier_only(exit_meta);
|
||||
|
||||
eprintln!("[joinir/pattern3/if-sum] Generated AST-based JoinIR");
|
||||
eprintln!("[joinir/pattern3/if-sum] Loop: {} {:?} {}", loop_var, loop_op, loop_limit);
|
||||
eprintln!("[joinir/pattern3/if-sum] If: {} {:?} {}", if_var, if_op, if_value);
|
||||
|
||||
Ok((join_module, fragment_meta))
|
||||
}
|
||||
|
||||
/// Extract loop condition: variable, operator, and limit
|
||||
///
|
||||
/// Supports: `var < lit`, `var <= lit`, `var > lit`, `var >= lit`
|
||||
fn extract_loop_condition(cond: &ASTNode) -> Result<(String, CompareOp, i64), String> {
|
||||
match cond {
|
||||
ASTNode::BinaryOp { operator, left, right, .. } => {
|
||||
let var_name = extract_variable_name(left)?;
|
||||
let limit = extract_integer_literal(right)?;
|
||||
let op = match operator {
|
||||
crate::ast::BinaryOperator::Less => CompareOp::Lt,
|
||||
crate::ast::BinaryOperator::LessEqual => CompareOp::Le,
|
||||
crate::ast::BinaryOperator::Greater => CompareOp::Gt,
|
||||
crate::ast::BinaryOperator::GreaterEqual => CompareOp::Ge,
|
||||
_ => return Err(format!("[if-sum] Unsupported loop condition operator: {:?}", operator)),
|
||||
};
|
||||
Ok((var_name, op, limit))
|
||||
}
|
||||
_ => Err("[if-sum] Loop condition must be a binary comparison".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract if condition: variable, operator, and value
|
||||
fn extract_if_condition(if_stmt: &ASTNode) -> Result<(String, CompareOp, i64), String> {
|
||||
match if_stmt {
|
||||
ASTNode::If { condition, .. } => {
|
||||
extract_loop_condition(condition) // Same format
|
||||
}
|
||||
_ => Err("[if-sum] Expected If statement".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract then-branch update: variable and addend
|
||||
///
|
||||
/// Supports: `var = var + lit`
|
||||
fn extract_then_update(if_stmt: &ASTNode) -> Result<(String, i64), String> {
|
||||
match if_stmt {
|
||||
ASTNode::If { then_body, .. } => {
|
||||
// Find assignment in then block
|
||||
for stmt in then_body {
|
||||
if let ASTNode::Assignment { target, value, .. } = stmt {
|
||||
let target_name = extract_variable_name(&**target)?;
|
||||
// Check if value is var + lit
|
||||
if let ASTNode::BinaryOp { operator: crate::ast::BinaryOperator::Add, left, right, .. } = value.as_ref() {
|
||||
let lhs_name = extract_variable_name(left)?;
|
||||
if lhs_name == target_name {
|
||||
let addend = extract_integer_literal(right)?;
|
||||
return Ok((target_name, addend));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err("[if-sum] No valid accumulator update found in then block".to_string())
|
||||
}
|
||||
_ => Err("[if-sum] Expected If statement".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract counter update: variable and step
|
||||
///
|
||||
/// Looks for `var = var + lit` where var is the loop variable
|
||||
fn extract_counter_update(body: &[ASTNode], loop_var: &str) -> Result<(String, i64), String> {
|
||||
for stmt in body {
|
||||
if let ASTNode::Assignment { target, value, .. } = stmt {
|
||||
if let Ok(target_name) = extract_variable_name(&**target) {
|
||||
if target_name == loop_var {
|
||||
if let ASTNode::BinaryOp { operator: crate::ast::BinaryOperator::Add, left, right, .. } = value.as_ref() {
|
||||
let lhs_name = extract_variable_name(left)?;
|
||||
if lhs_name == target_name {
|
||||
let step = extract_integer_literal(right)?;
|
||||
return Ok((target_name, step));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(format!("[if-sum] No counter update found for '{}'", loop_var))
|
||||
}
|
||||
|
||||
/// Extract variable name from AST node
|
||||
fn extract_variable_name(node: &ASTNode) -> Result<String, String> {
|
||||
match node {
|
||||
ASTNode::Variable { name, .. } => Ok(name.clone()),
|
||||
_ => Err(format!("[if-sum] Expected variable, got {:?}", node)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract integer literal from AST node
|
||||
fn extract_integer_literal(node: &ASTNode) -> Result<i64, String> {
|
||||
match node {
|
||||
ASTNode::Literal { value: crate::ast::LiteralValue::Integer(n), .. } => Ok(*n),
|
||||
ASTNode::Variable { name, .. } => {
|
||||
// Handle variable reference (e.g., `len`)
|
||||
// For Phase 213, we only support literals. Variables need Phase 214+
|
||||
Err(format!("[if-sum] Variable '{}' in condition not supported yet (Phase 214+)", name))
|
||||
}
|
||||
_ => Err(format!("[if-sum] Expected integer literal, got {:?}", node)),
|
||||
}
|
||||
}
|
||||
@ -58,6 +58,7 @@ pub(crate) mod loop_view_builder; // Phase 33-23: Loop lowering dispatch
|
||||
pub mod loop_with_break_minimal; // Phase 188-Impl-2: Pattern 2 minimal lowerer
|
||||
pub mod loop_with_continue_minimal; // Phase 195: Pattern 4 minimal lowerer
|
||||
pub mod loop_with_if_phi_minimal; // Phase 188-Impl-3: Pattern 3 minimal lowerer
|
||||
pub mod loop_with_if_phi_if_sum; // Phase 213: Pattern 3 AST-based if-sum lowerer
|
||||
pub mod simple_while_minimal; // Phase 188-Impl-1: Pattern 1 minimal lowerer
|
||||
pub mod min_loop;
|
||||
pub mod skip_ws;
|
||||
|
||||
Reference in New Issue
Block a user