diff --git a/src/mir/builder/control_flow/plan/planner/build.rs b/src/mir/builder/control_flow/plan/planner/build.rs index 7fbf8d76..1e289474 100644 --- a/src/mir/builder/control_flow/plan/planner/build.rs +++ b/src/mir/builder/control_flow/plan/planner/build.rs @@ -14,8 +14,9 @@ use super::outcome::build_plan_with_facts; use super::Freeze; use crate::mir::builder::control_flow::plan::{ DomainPlan, Pattern1SimpleWhilePlan, Pattern2BreakPlan, Pattern2PromotionHint, Pattern3IfPhiPlan, - Pattern4ContinuePlan, Pattern5InfiniteEarlyExitPlan, Pattern8BoolPredicateScanPlan, - Pattern9AccumConstLoopPlan, ScanDirection, ScanWithInitPlan, SplitScanPlan, + Pattern4ContinuePlan, Pattern5ExitKind, Pattern5InfiniteEarlyExitPlan, + Pattern8BoolPredicateScanPlan, Pattern9AccumConstLoopPlan, ScanDirection, ScanWithInitPlan, + SplitScanPlan, }; use crate::mir::loop_pattern_detection::LoopPatternKind; @@ -75,13 +76,43 @@ pub(in crate::mir::builder) fn build_plan_from_facts_ctx( } fn infer_skeleton_kind(facts: &CanonicalLoopFacts) -> Option { - Some(facts.facts.skeleton.kind) + Some(facts.skeleton_kind) } fn infer_exit_usage(facts: &CanonicalLoopFacts) -> Option { - Some(facts.facts.features.exit_usage.clone()) + Some(facts.exit_usage.clone()) } +#[cfg(debug_assertions)] +fn debug_assert_exit_usage_matches_plan(plan: &DomainPlan, exit_usage: &ExitUsageFacts) { + match plan { + DomainPlan::Pattern1SimpleWhile(_) => { + debug_assert!( + !exit_usage.has_break && !exit_usage.has_continue && !exit_usage.has_return, + "pattern1 requires no exit usage" + ); + } + DomainPlan::Pattern2Break(_) => { + debug_assert!(exit_usage.has_break, "pattern2 requires break usage"); + } + DomainPlan::Pattern4Continue(_) => { + debug_assert!(exit_usage.has_continue, "pattern4 requires continue usage"); + } + DomainPlan::Pattern5InfiniteEarlyExit(plan) => match plan.exit_kind { + Pattern5ExitKind::Return => { + debug_assert!(exit_usage.has_return, "pattern5 return requires return usage"); + } + Pattern5ExitKind::Break => { + debug_assert!(exit_usage.has_break, "pattern5 break requires break usage"); + } + }, + _ => {} + } +} + +#[cfg(not(debug_assertions))] +fn debug_assert_exit_usage_matches_plan(_plan: &DomainPlan, _exit_usage: &ExitUsageFacts) {} + fn push_scan_with_init(candidates: &mut CandidateSet, facts: &CanonicalLoopFacts) { let Some(scan) = &facts.facts.scan_with_init else { return; @@ -129,17 +160,19 @@ fn push_pattern2_break(candidates: &mut CandidateSet, facts: &CanonicalLoopFacts .pattern2_loopbodylocal .as_ref() .map(|facts| Pattern2PromotionHint::LoopBodyLocal(facts.clone())); + let plan = DomainPlan::Pattern2Break(Pattern2BreakPlan { + loop_var: pattern2.loop_var.clone(), + carrier_var: pattern2.carrier_var.clone(), + loop_condition: pattern2.loop_condition.clone(), + break_condition: pattern2.break_condition.clone(), + carrier_update_in_break: pattern2.carrier_update_in_break.clone(), + carrier_update_in_body: pattern2.carrier_update_in_body.clone(), + loop_increment: pattern2.loop_increment.clone(), + promotion, + }); + debug_assert_exit_usage_matches_plan(&plan, &facts.exit_usage); candidates.push(PlanCandidate { - plan: DomainPlan::Pattern2Break(Pattern2BreakPlan { - loop_var: pattern2.loop_var.clone(), - carrier_var: pattern2.carrier_var.clone(), - loop_condition: pattern2.loop_condition.clone(), - break_condition: pattern2.break_condition.clone(), - carrier_update_in_break: pattern2.carrier_update_in_break.clone(), - carrier_update_in_body: pattern2.carrier_update_in_body.clone(), - loop_increment: pattern2.loop_increment.clone(), - promotion, - }), + plan, rule: "loop/pattern2_break", }); } @@ -166,15 +199,17 @@ fn push_pattern4_continue(candidates: &mut CandidateSet, facts: &CanonicalLoopFa let Some(pattern4) = &facts.facts.pattern4_continue else { return; }; + let plan = DomainPlan::Pattern4Continue(Pattern4ContinuePlan { + loop_var: pattern4.loop_var.clone(), + carrier_vars: pattern4.carrier_updates.keys().cloned().collect(), + condition: pattern4.condition.clone(), + continue_condition: pattern4.continue_condition.clone(), + carrier_updates: pattern4.carrier_updates.clone(), + loop_increment: pattern4.loop_increment.clone(), + }); + debug_assert_exit_usage_matches_plan(&plan, &facts.exit_usage); candidates.push(PlanCandidate { - plan: DomainPlan::Pattern4Continue(Pattern4ContinuePlan { - loop_var: pattern4.loop_var.clone(), - carrier_vars: pattern4.carrier_updates.keys().cloned().collect(), - condition: pattern4.condition.clone(), - continue_condition: pattern4.continue_condition.clone(), - carrier_updates: pattern4.carrier_updates.clone(), - loop_increment: pattern4.loop_increment.clone(), - }), + plan, rule: "loop/pattern4_continue", }); } @@ -183,16 +218,18 @@ fn push_pattern5_infinite_early_exit(candidates: &mut CandidateSet, facts: &Cano let Some(pattern5) = &facts.facts.pattern5_infinite_early_exit else { return; }; + let plan = DomainPlan::Pattern5InfiniteEarlyExit(Pattern5InfiniteEarlyExitPlan { + loop_var: pattern5.loop_var.clone(), + exit_kind: pattern5.exit_kind, + exit_condition: pattern5.exit_condition.clone(), + exit_value: pattern5.exit_value.clone(), + carrier_var: pattern5.carrier_var.clone(), + carrier_update: pattern5.carrier_update.clone(), + loop_increment: pattern5.loop_increment.clone(), + }); + debug_assert_exit_usage_matches_plan(&plan, &facts.exit_usage); candidates.push(PlanCandidate { - plan: DomainPlan::Pattern5InfiniteEarlyExit(Pattern5InfiniteEarlyExitPlan { - loop_var: pattern5.loop_var.clone(), - exit_kind: pattern5.exit_kind, - exit_condition: pattern5.exit_condition.clone(), - exit_value: pattern5.exit_value.clone(), - carrier_var: pattern5.carrier_var.clone(), - carrier_update: pattern5.carrier_update.clone(), - loop_increment: pattern5.loop_increment.clone(), - }), + plan, rule: "loop/pattern5_infinite_early_exit", }); } @@ -248,12 +285,14 @@ fn push_pattern1_simplewhile( let Some(pattern1) = &facts.facts.pattern1_simplewhile else { return; }; + let plan = DomainPlan::Pattern1SimpleWhile(Pattern1SimpleWhilePlan { + loop_var: pattern1.loop_var.clone(), + condition: pattern1.condition.clone(), + loop_increment: pattern1.loop_increment.clone(), + }); + debug_assert_exit_usage_matches_plan(&plan, &facts.exit_usage); candidates.push(PlanCandidate { - plan: DomainPlan::Pattern1SimpleWhile(Pattern1SimpleWhilePlan { - loop_var: pattern1.loop_var.clone(), - condition: pattern1.condition.clone(), - loop_increment: pattern1.loop_increment.clone(), - }), + plan, rule: "loop/pattern1_simplewhile", }); } @@ -535,6 +574,162 @@ mod tests { } } + #[test] + fn debug_exit_usage_invariant_pattern1_ok() { + let loop_condition = ASTNode::BinaryOp { + operator: BinaryOperator::Less, + left: Box::new(v("i")), + right: Box::new(lit_int(3)), + span: Span::unknown(), + }; + let loop_increment = ASTNode::BinaryOp { + operator: BinaryOperator::Add, + left: Box::new(v("i")), + right: Box::new(lit_int(1)), + span: Span::unknown(), + }; + let facts = LoopFacts { + condition_shape: ConditionShape::Unknown, + step_shape: StepShape::Unknown, + skeleton: SkeletonFacts { + kind: SkeletonKind::Loop, + }, + features: LoopFeatureFacts::default(), + scan_with_init: None, + split_scan: None, + pattern1_simplewhile: Some(Pattern1SimpleWhileFacts { + loop_var: "i".to_string(), + condition: loop_condition.clone(), + loop_increment: loop_increment.clone(), + }), + pattern3_ifphi: None, + pattern4_continue: None, + pattern5_infinite_early_exit: None, + pattern8_bool_predicate_scan: None, + pattern9_accum_const_loop: None, + pattern2_break: None, + pattern2_loopbodylocal: None, + }; + let canonical = canonicalize_loop_facts(facts); + let plan = build_plan_from_facts(canonical).expect("Ok"); + assert!(matches!(plan, Some(DomainPlan::Pattern1SimpleWhile(_)))); + } + + #[test] + fn debug_exit_usage_invariant_pattern2_ok() { + let loop_condition = ASTNode::BinaryOp { + operator: BinaryOperator::Less, + left: Box::new(v("i")), + right: Box::new(lit_int(3)), + span: Span::unknown(), + }; + let break_condition = ASTNode::BinaryOp { + operator: BinaryOperator::Equal, + left: Box::new(v("i")), + right: Box::new(lit_int(2)), + span: Span::unknown(), + }; + let carrier_update = ASTNode::BinaryOp { + operator: BinaryOperator::Add, + left: Box::new(v("sum")), + right: Box::new(lit_int(1)), + span: Span::unknown(), + }; + let loop_increment = ASTNode::BinaryOp { + operator: BinaryOperator::Add, + left: Box::new(v("i")), + right: Box::new(lit_int(1)), + span: Span::unknown(), + }; + let facts = LoopFacts { + condition_shape: ConditionShape::Unknown, + step_shape: StepShape::Unknown, + skeleton: SkeletonFacts { + kind: SkeletonKind::Loop, + }, + features: LoopFeatureFacts { + exit_usage: ExitUsageFacts { + has_break: true, + has_continue: false, + has_return: false, + }, + value_join: None, + cleanup: None, + }, + scan_with_init: None, + split_scan: None, + pattern1_simplewhile: None, + pattern3_ifphi: None, + pattern4_continue: None, + pattern5_infinite_early_exit: None, + pattern8_bool_predicate_scan: None, + pattern9_accum_const_loop: None, + pattern2_break: Some(Pattern2BreakFacts { + loop_var: "i".to_string(), + carrier_var: "sum".to_string(), + loop_condition, + break_condition, + carrier_update_in_break: None, + carrier_update_in_body: carrier_update, + loop_increment, + }), + pattern2_loopbodylocal: None, + }; + let canonical = canonicalize_loop_facts(facts); + let plan = build_plan_from_facts(canonical).expect("Ok"); + assert!(matches!(plan, Some(DomainPlan::Pattern2Break(_)))); + } + + #[cfg(debug_assertions)] + #[test] + #[should_panic] + fn debug_exit_usage_invariant_pattern1_panics_on_break_usage() { + let loop_condition = ASTNode::BinaryOp { + operator: BinaryOperator::Less, + left: Box::new(v("i")), + right: Box::new(lit_int(3)), + span: Span::unknown(), + }; + let loop_increment = ASTNode::BinaryOp { + operator: BinaryOperator::Add, + left: Box::new(v("i")), + right: Box::new(lit_int(1)), + span: Span::unknown(), + }; + let facts = LoopFacts { + condition_shape: ConditionShape::Unknown, + step_shape: StepShape::Unknown, + skeleton: SkeletonFacts { + kind: SkeletonKind::Loop, + }, + features: LoopFeatureFacts { + exit_usage: ExitUsageFacts { + has_break: true, + has_continue: false, + has_return: false, + }, + value_join: None, + cleanup: None, + }, + scan_with_init: None, + split_scan: None, + pattern1_simplewhile: Some(Pattern1SimpleWhileFacts { + loop_var: "i".to_string(), + condition: loop_condition.clone(), + loop_increment: loop_increment.clone(), + }), + pattern3_ifphi: None, + pattern4_continue: None, + pattern5_infinite_early_exit: None, + pattern8_bool_predicate_scan: None, + pattern9_accum_const_loop: None, + pattern2_break: None, + pattern2_loopbodylocal: None, + }; + let canonical = canonicalize_loop_facts(facts); + let _ = build_plan_from_facts(canonical).expect("Ok"); + } + #[test] fn planner_builds_pattern3_ifphi_plan_from_facts() { let loop_condition = ASTNode::BinaryOp { @@ -639,7 +834,15 @@ mod tests { skeleton: SkeletonFacts { kind: SkeletonKind::Loop, }, - features: LoopFeatureFacts::default(), + features: LoopFeatureFacts { + exit_usage: ExitUsageFacts { + has_break: false, + has_continue: true, + has_return: false, + }, + value_join: None, + cleanup: None, + }, scan_with_init: None, split_scan: None, pattern1_simplewhile: None, @@ -690,7 +893,15 @@ mod tests { skeleton: SkeletonFacts { kind: SkeletonKind::Loop, }, - features: LoopFeatureFacts::default(), + features: LoopFeatureFacts { + exit_usage: ExitUsageFacts { + has_break: false, + has_continue: false, + has_return: true, + }, + value_join: None, + cleanup: None, + }, scan_with_init: None, split_scan: None, pattern1_simplewhile: None, @@ -742,7 +953,15 @@ mod tests { skeleton: SkeletonFacts { kind: SkeletonKind::Loop, }, - features: LoopFeatureFacts::default(), + features: LoopFeatureFacts { + exit_usage: ExitUsageFacts { + has_break: true, + has_continue: false, + has_return: false, + }, + value_join: None, + cleanup: None, + }, scan_with_init: None, split_scan: None, pattern1_simplewhile: None,