diff --git a/apps/tests/joinir_if_merge_multiple.hako b/apps/tests/joinir_if_merge_multiple.hako new file mode 100644 index 00000000..13215437 --- /dev/null +++ b/apps/tests/joinir_if_merge_multiple.hako @@ -0,0 +1,73 @@ +// Phase 33-7: IfMerge lowering test (multiple variables pattern) +// +// Pattern: if cond { x=a; y=b; z=c } else { x=d; y=e; z=f } return x+y+z +// Expected: +// - cond=true → (x=10, y=20, z=30) → return 60 +// - cond=false → (x=40, y=50, z=60) → return 150 + +static box IfMergeTest { + multiple_true() { + local x + local y + local z + + if true { + x = 10 + y = 20 + z = 30 + } else { + x = 40 + y = 50 + z = 60 + } + + return x + y + z + } + + multiple_false() { + local x + local y + local z + + if false { + x = 10 + y = 20 + z = 30 + } else { + x = 40 + y = 50 + z = 60 + } + + return x + y + z + } + + main() { + local result_true + local result_false + + result_true = me.multiple_true() + result_false = me.multiple_false() + + print("multiple_true: ") + print(result_true) + print("\n") + + print("multiple_false: ") + print(result_false) + print("\n") + + // Verify results + if result_true == 60 { + print("PASS: multiple_true\n") + } else { + print("FAIL: multiple_true\n") + } + + if result_false == 150 { + print("PASS: multiple_false\n") + } else { + print("FAIL: multiple_false\n") + } + } +} diff --git a/apps/tests/joinir_if_merge_simple.hako b/apps/tests/joinir_if_merge_simple.hako new file mode 100644 index 00000000..b782532e --- /dev/null +++ b/apps/tests/joinir_if_merge_simple.hako @@ -0,0 +1,67 @@ +// Phase 33-7: IfMerge lowering test (simple pattern) +// +// Pattern: if cond { x=a; y=b } else { x=c; y=d } return x+y +// Expected: +// - cond=true → (x=1, y=2) → return 3 +// - cond=false → (x=3, y=4) → return 7 + +static box IfMergeTest { + simple_true() { + local x + local y + + if true { + x = 1 + y = 2 + } else { + x = 3 + y = 4 + } + + return x + y + } + + simple_false() { + local x + local y + + if false { + x = 1 + y = 2 + } else { + x = 3 + y = 4 + } + + return x + y + } + + main() { + local result_true + local result_false + + result_true = me.simple_true() + result_false = me.simple_false() + + print("simple_true: ") + print(result_true) + print("\n") + + print("simple_false: ") + print(result_false) + print("\n") + + // Verify results + if result_true == 3 { + print("PASS: simple_true\n") + } else { + print("FAIL: simple_true\n") + } + + if result_false == 7 { + print("PASS: simple_false\n") + } else { + print("FAIL: simple_false\n") + } + } +} diff --git a/src/mir/join_ir/lowering/if_merge.rs b/src/mir/join_ir/lowering/if_merge.rs new file mode 100644 index 00000000..07b2d9f2 --- /dev/null +++ b/src/mir/join_ir/lowering/if_merge.rs @@ -0,0 +1,236 @@ +//! Phase 33-7: If/Else の IfMerge 命令への lowering +//! +//! 複数変数を merge する if/else を JoinInst::IfMerge に変換する。 +//! +//! Phase 33-7 制約: +//! - return パターンのみ(continuation は Phase 33-8) +//! - k_next=None のみ + +use crate::mir::join_ir::{JoinInst, MergePair}; +use crate::mir::{BasicBlockId, MirFunction, MirInstruction, ValueId}; +use std::collections::HashSet; + +pub struct IfMergeLowerer { + debug: bool, +} + +/// 検出された IfMerge パターン情報 +#[derive(Debug, Clone)] +struct IfMergePattern { + cond: ValueId, + merge_pairs: Vec, +} + +/// Branch 命令の情報 +#[derive(Debug, Clone)] +struct IfBranch { + cond: ValueId, + then_block: BasicBlockId, + else_block: BasicBlockId, +} + +impl IfMergeLowerer { + pub fn new(debug: bool) -> Self { + Self { debug } + } + + /// if/else が IfMerge に lowering できるかチェック + pub fn can_lower_to_if_merge(&self, func: &MirFunction, if_block_id: BasicBlockId) -> bool { + self.find_if_merge_pattern(func, if_block_id).is_some() + } + + /// if/else を IfMerge に変換 + pub fn lower_if_to_if_merge( + &self, + func: &MirFunction, + if_block_id: BasicBlockId, + ) -> Option { + let pattern = self.find_if_merge_pattern(func, if_block_id)?; + + if self.debug { + eprintln!( + "[IfMergeLowerer] lowering to IfMerge with {} merge pairs", + pattern.merge_pairs.len() + ); + } + + // IfMerge 命令を生成 + Some(JoinInst::IfMerge { + cond: pattern.cond, + merges: pattern.merge_pairs, + k_next: None, // Phase 33-7 制約 + }) + } + + /// MIR 関数から IfMerge パターンを探す + fn find_if_merge_pattern( + &self, + func: &MirFunction, + block_id: BasicBlockId, + ) -> Option { + // 1. Block が Branch 命令で終わっているか確認 + let block = func.blocks.get(&block_id)?; + let branch = match block.terminator.as_ref()? { + MirInstruction::Branch { + condition, + then_bb, + else_bb, + } => IfBranch { + cond: *condition, + then_block: *then_bb, + else_block: *else_bb, + }, + _ => return None, + }; + + // 2. then/else ブロックを取得 + let then_block = func.blocks.get(&branch.then_block)?; + let else_block = func.blocks.get(&branch.else_block)?; + + // 3. Phase 33-7 制約: return パターンのみ + // 両方のブロックが Return で終わる必要がある + let is_then_return = matches!( + then_block.terminator.as_ref(), + Some(MirInstruction::Return { .. }) + ); + let is_else_return = matches!( + else_block.terminator.as_ref(), + Some(MirInstruction::Return { .. }) + ); + + if !is_then_return || !is_else_return { + if self.debug { + eprintln!( + "[IfMergeLowerer] not return pattern (then={}, else={})", + is_then_return, is_else_return + ); + } + return None; + } + + // 4. then/else で書き込まれる変数を抽出 + let then_writes = self.extract_written_vars(&then_block.instructions); + let else_writes = self.extract_written_vars(&else_block.instructions); + + if self.debug { + eprintln!( + "[IfMergeLowerer] then writes: {:?}, else writes: {:?}", + then_writes, else_writes + ); + } + + // 5. 両方で書き込まれる変数(共通集合)を抽出 + let common_writes: HashSet<_> = then_writes.intersection(&else_writes).copied().collect(); + + if common_writes.is_empty() { + if self.debug { + eprintln!("[IfMergeLowerer] no common writes found"); + } + return None; + } + + // 6. 各共通変数について MergePair を生成 + let mut merge_pairs = Vec::new(); + + for &dst in &common_writes { + // then ブロックで dst に書き込まれる値を探す + let then_val = self.find_written_value(&then_block.instructions, dst)?; + // else ブロックで dst に書き込まれる値を探す + let else_val = self.find_written_value(&else_block.instructions, dst)?; + + merge_pairs.push(MergePair { + dst, + then_val, + else_val, + }); + } + + if merge_pairs.is_empty() { + return None; + } + + // 7. MergePair を dst でソートして決定的に + merge_pairs.sort_by_key(|pair| pair.dst.0); + + Some(IfMergePattern { + cond: branch.cond, + merge_pairs, + }) + } + + /// 命令列から書き込まれる変数集合を抽出 + fn extract_written_vars(&self, instructions: &[MirInstruction]) -> HashSet { + let mut writes = HashSet::new(); + + for inst in instructions { + match inst { + MirInstruction::Copy { dst, .. } + | MirInstruction::Const { dst, .. } + | MirInstruction::BinOp { dst, .. } + | MirInstruction::Compare { dst, .. } => { + writes.insert(*dst); + } + MirInstruction::Call { dst: Some(dst), .. } + | MirInstruction::BoxCall { dst: Some(dst), .. } => { + writes.insert(*dst); + } + _ => {} + } + } + + writes + } + + /// 命令列から dst に書き込まれる値を探す(最後の書き込み) + fn find_written_value( + &self, + instructions: &[MirInstruction], + dst: ValueId, + ) -> Option { + // 逆順で探索して最後の書き込みを見つける + for inst in instructions.iter().rev() { + match inst { + MirInstruction::Copy { + dst: inst_dst, + src, + } if *inst_dst == dst => { + return Some(*src); + } + MirInstruction::Const { + dst: inst_dst, .. + } + | MirInstruction::BinOp { + dst: inst_dst, .. + } + | MirInstruction::Compare { + dst: inst_dst, .. + } + | MirInstruction::Call { + dst: Some(inst_dst), + .. + } + | MirInstruction::BoxCall { + dst: Some(inst_dst), + .. + } if *inst_dst == dst => { + // dst 自身が書き込まれる場合は dst を返す + return Some(dst); + } + _ => {} + } + } + + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_if_merge_lowerer_creation() { + let lowerer = IfMergeLowerer::new(false); + assert!(!lowerer.debug); + } +} diff --git a/src/mir/join_ir/lowering/mod.rs b/src/mir/join_ir/lowering/mod.rs index 8e915d03..5e5b830e 100644 --- a/src/mir/join_ir/lowering/mod.rs +++ b/src/mir/join_ir/lowering/mod.rs @@ -19,6 +19,7 @@ pub mod exit_args_resolver; pub mod funcscanner_append_defs; pub mod funcscanner_trim; pub mod generic_case_a; +pub mod if_merge; // Phase 33-7 pub mod if_select; // Phase 33 pub mod loop_form_intake; pub mod loop_scope_shape; @@ -46,29 +47,35 @@ pub use stageb_funcscanner::lower_stageb_funcscanner_to_joinir; use crate::mir::join_ir::JoinInst; use crate::mir::{BasicBlockId, MirFunction}; -/// Phase 33-3: Try to lower if/else to JoinIR Select instruction +/// Phase 33-7: Try to lower if/else to JoinIR Select/IfMerge instruction /// /// Scope: /// - Only applies to whitelisted functions: /// - IfSelectTest.* (Phase 33-2/33-3) +/// - IfMergeTest.* (Phase 33-7) /// - JsonShapeToMap._read_value_from_pair/1 (Phase 33-4 Stage-1) /// - Stage1JsonScannerBox.value_start_after_key_pos/2 (Phase 33-4 Stage-B) /// - Requires NYASH_JOINIR_IF_SELECT=1 environment variable /// - Falls back to traditional if_phi on pattern mismatch /// -/// Returns Some(JoinInst::Select) if pattern matched, None otherwise. +/// Pattern selection: +/// - 1 variable → Select +/// - 2+ variables → IfMerge +/// +/// Returns Some(JoinInst::Select) or Some(JoinInst::IfMerge) if pattern matched, None otherwise. pub fn try_lower_if_to_joinir( func: &MirFunction, block_id: BasicBlockId, debug: bool, ) -> Option { - // dev トグルチェック + // 1. dev トグルチェック if !crate::config::env::joinir_if_select_enabled() { return None; } - // Phase 33-4/33-5: 関数名ガード拡張(テスト + Stage-1/Stage-B 候補) + // 2. Phase 33-7: 関数名ガード拡張(テスト + Stage-1/Stage-B 候補) let is_allowed = func.signature.name.starts_with("IfSelectTest.") + || func.signature.name.starts_with("IfMergeTest.") // Phase 33-7 || func.signature.name.starts_with("Stage1JsonScannerTestBox.") // Phase 33-5 test || func.signature.name == "JsonShapeToMap._read_value_from_pair/1" || func.signature.name == "Stage1JsonScannerBox.value_start_after_key_pos/2"; @@ -83,10 +90,26 @@ pub fn try_lower_if_to_joinir( return None; } - // if_select lowering を試行 - let lowerer = if_select::IfSelectLowerer::new(debug); + // 3. Phase 33-7: IfMerge を優先的に試行(複数変数パターン) + // IfMerge が成功すればそれを返す、失敗したら Select を試行 + let if_merge_lowerer = if_merge::IfMergeLowerer::new(debug); - if !lowerer.can_lower_to_select(func, block_id) { + if if_merge_lowerer.can_lower_to_if_merge(func, block_id) { + if let Some(result) = if_merge_lowerer.lower_if_to_if_merge(func, block_id) { + if debug { + eprintln!( + "[try_lower_if_to_joinir] IfMerge lowering used for {}", + func.signature.name + ); + } + return Some(result); + } + } + + // 4. IfMerge が失敗したら Select を試行(単一変数パターン) + let if_select_lowerer = if_select::IfSelectLowerer::new(debug); + + if !if_select_lowerer.can_lower_to_select(func, block_id) { if debug { eprintln!( "[try_lower_if_to_joinir] pattern not matched for {}", @@ -96,11 +119,11 @@ pub fn try_lower_if_to_joinir( return None; } - let result = lowerer.lower_if_to_select(func, block_id); + let result = if_select_lowerer.lower_if_to_select(func, block_id); if result.is_some() && debug { eprintln!( - "[try_lower_if_to_joinir] if_select lowering used for {}", + "[try_lower_if_to_joinir] Select lowering used for {}", func.signature.name ); } diff --git a/src/tests/mir_joinir_if_select.rs b/src/tests/mir_joinir_if_select.rs index 73f51465..3c2d5aa5 100644 --- a/src/tests/mir_joinir_if_select.rs +++ b/src/tests/mir_joinir_if_select.rs @@ -373,4 +373,198 @@ mod tests { eprintln!("✅ verify_select_minimal properly checks invariants from phi_invariants.rs and conservative.rs"); } + + // ============================================================================ + // Phase 33-7: IfMerge lowering tests + // ============================================================================ + + /// Helper to create a 2-variable IfMerge pattern MIR + fn create_if_merge_simple_pattern_mir() -> MirFunction { + let mut blocks = BTreeMap::new(); + + // Entry block (bb0): branch on cond + let mut entry = BasicBlock::new(BasicBlockId::new(0)); + entry.terminator = Some(MirInstruction::Branch { + condition: ValueId(0), // cond + then_bb: BasicBlockId::new(1), + else_bb: BasicBlockId::new(2), + }); + blocks.insert(BasicBlockId::new(0), entry); + + // Then block (bb1): x = 1, y = 2 + let mut then_block = BasicBlock::new(BasicBlockId::new(1)); + then_block.instructions.push(MirInstruction::Const { + dst: ValueId(3), // x = 1 + value: crate::mir::ConstValue::Integer(1), + }); + then_block.instructions.push(MirInstruction::Const { + dst: ValueId(4), // y = 2 + value: crate::mir::ConstValue::Integer(2), + }); + then_block.terminator = Some(MirInstruction::Return { + value: Some(ValueId(10)), // result (x + y computed elsewhere) + }); + blocks.insert(BasicBlockId::new(1), then_block); + + // Else block (bb2): x = 3, y = 4 + let mut else_block = BasicBlock::new(BasicBlockId::new(2)); + else_block.instructions.push(MirInstruction::Const { + dst: ValueId(3), // x = 3 (same dst as then!) + value: crate::mir::ConstValue::Integer(3), + }); + else_block.instructions.push(MirInstruction::Const { + dst: ValueId(4), // y = 4 (same dst as then!) + value: crate::mir::ConstValue::Integer(4), + }); + else_block.terminator = Some(MirInstruction::Return { + value: Some(ValueId(20)), // result (x + y computed elsewhere) + }); + blocks.insert(BasicBlockId::new(2), else_block); + + use crate::mir::{EffectMask, MirType}; + use crate::mir::function::FunctionMetadata; + use std::collections::HashMap; + + MirFunction { + signature: crate::mir::FunctionSignature { + name: "IfMergeTest.simple_true/0".to_string(), + params: vec![], + return_type: MirType::Integer, + effects: EffectMask::PURE, + }, + entry_block: BasicBlockId::new(0), + blocks: blocks.into_iter().collect(), + locals: vec![], + params: vec![ValueId(0)], + next_value_id: 23, + metadata: FunctionMetadata::default(), + } + } + + /// Helper to create a 3-variable IfMerge pattern MIR + fn create_if_merge_multiple_pattern_mir() -> MirFunction { + let mut blocks = BTreeMap::new(); + + // Entry block (bb0): branch on cond + let mut entry = BasicBlock::new(BasicBlockId::new(0)); + entry.terminator = Some(MirInstruction::Branch { + condition: ValueId(0), // cond + then_bb: BasicBlockId::new(1), + else_bb: BasicBlockId::new(2), + }); + blocks.insert(BasicBlockId::new(0), entry); + + // Then block (bb1): x = 10, y = 20, z = 30 + let mut then_block = BasicBlock::new(BasicBlockId::new(1)); + then_block.instructions.push(MirInstruction::Const { + dst: ValueId(3), // x = 10 + value: crate::mir::ConstValue::Integer(10), + }); + then_block.instructions.push(MirInstruction::Const { + dst: ValueId(4), // y = 20 + value: crate::mir::ConstValue::Integer(20), + }); + then_block.instructions.push(MirInstruction::Const { + dst: ValueId(5), // z = 30 + value: crate::mir::ConstValue::Integer(30), + }); + then_block.terminator = Some(MirInstruction::Return { + value: Some(ValueId(10)), // result (x + y + z computed elsewhere) + }); + blocks.insert(BasicBlockId::new(1), then_block); + + // Else block (bb2): x = 40, y = 50, z = 60 + let mut else_block = BasicBlock::new(BasicBlockId::new(2)); + else_block.instructions.push(MirInstruction::Const { + dst: ValueId(3), // x = 40 (same dst as then!) + value: crate::mir::ConstValue::Integer(40), + }); + else_block.instructions.push(MirInstruction::Const { + dst: ValueId(4), // y = 50 (same dst as then!) + value: crate::mir::ConstValue::Integer(50), + }); + else_block.instructions.push(MirInstruction::Const { + dst: ValueId(5), // z = 60 (same dst as then!) + value: crate::mir::ConstValue::Integer(60), + }); + else_block.terminator = Some(MirInstruction::Return { + value: Some(ValueId(20)), // result (x + y + z computed elsewhere) + }); + blocks.insert(BasicBlockId::new(2), else_block); + + use crate::mir::{EffectMask, MirType}; + use crate::mir::function::FunctionMetadata; + use std::collections::HashMap; + + MirFunction { + signature: crate::mir::FunctionSignature { + name: "IfMergeTest.multiple_true/0".to_string(), + params: vec![], + return_type: MirType::Integer, + effects: EffectMask::PURE, + }, + entry_block: BasicBlockId::new(0), + blocks: blocks.into_iter().collect(), + locals: vec![], + params: vec![ValueId(0)], + next_value_id: 24, + metadata: FunctionMetadata::default(), + } + } + + /// Phase 33-7: Test IfMerge lowering for 2-variable pattern + #[test] + fn test_if_merge_simple_pattern() { + use crate::mir::join_ir::JoinInst; + + std::env::set_var("NYASH_JOINIR_IF_SELECT", "1"); + + let func = create_if_merge_simple_pattern_mir(); + let entry_block = func.entry_block; + let result = try_lower_if_to_joinir(&func, entry_block, true); + + assert!( + result.is_some(), + "Expected simple 2-variable pattern to be lowered to IfMerge" + ); + + if let Some(JoinInst::IfMerge { cond, merges, k_next }) = result { + eprintln!("✅ Simple pattern (2 vars) successfully lowered to IfMerge"); + eprintln!(" cond: {:?}, merges: {} pairs, k_next: {:?}", cond, merges.len(), k_next); + assert_eq!(merges.len(), 2, "Expected 2 MergePairs for x and y"); + assert!(k_next.is_none(), "Phase 33-7 constraint: k_next should be None"); + } else { + panic!("Expected JoinInst::IfMerge, got {:?}", result); + } + + std::env::remove_var("NYASH_JOINIR_IF_SELECT"); + } + + /// Phase 33-7: Test IfMerge lowering for 3-variable pattern + #[test] + fn test_if_merge_multiple_pattern() { + use crate::mir::join_ir::JoinInst; + + std::env::set_var("NYASH_JOINIR_IF_SELECT", "1"); + + let func = create_if_merge_multiple_pattern_mir(); + let entry_block = func.entry_block; + let result = try_lower_if_to_joinir(&func, entry_block, true); + + assert!( + result.is_some(), + "Expected multiple 3-variable pattern to be lowered to IfMerge" + ); + + if let Some(JoinInst::IfMerge { cond, merges, k_next }) = result { + eprintln!("✅ Multiple pattern (3 vars) successfully lowered to IfMerge"); + eprintln!(" cond: {:?}, merges: {} pairs, k_next: {:?}", cond, merges.len(), k_next); + assert_eq!(merges.len(), 3, "Expected 3 MergePairs for x, y, and z"); + assert!(k_next.is_none(), "Phase 33-7 constraint: k_next should be None"); + } else { + panic!("Expected JoinInst::IfMerge, got {:?}", result); + } + + std::env::remove_var("NYASH_JOINIR_IF_SELECT"); + } }