Add pattern router tests and tidy pattern4 lowering
This commit is contained in:
@ -219,7 +219,7 @@ pub(crate) fn extract_features(loop_form: &LoopForm, scope: Option<&LoopScopeSha
|
||||
// Note: carriers is BTreeSet<String>, so each item is already a String
|
||||
let update_summary = scope.map(|s| {
|
||||
let carrier_names: Vec<String> = s.carriers.iter().cloned().collect();
|
||||
crate::mir::join_ir::lowering::loop_update_summary::analyze_loop_updates(&carrier_names)
|
||||
crate::mir::join_ir::lowering::loop_update_summary::analyze_loop_updates_by_name(&carrier_names)
|
||||
});
|
||||
|
||||
LoopFeatures {
|
||||
@ -664,107 +664,178 @@ fn has_simple_condition(_loop_form: &LoopForm) -> bool {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::ast::{ASTNode, BinaryOperator, LiteralValue, Span};
|
||||
use crate::mir::loop_pattern_detection::LoopFeatures;
|
||||
|
||||
// ========================================================================
|
||||
// Pattern 1: Simple While Loop Tests
|
||||
// ========================================================================
|
||||
fn span() -> Span {
|
||||
Span::unknown()
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern1_simple_while_detection() {
|
||||
// TODO: Add unit test for simple while pattern detection
|
||||
// Step 1: Create mock LoopForm with:
|
||||
// - Empty break_targets
|
||||
// - Empty continue_targets
|
||||
// - Single latch
|
||||
// Step 2: Call is_simple_while_pattern()
|
||||
// Step 3: Assert returns true
|
||||
fn var(name: &str) -> ASTNode {
|
||||
ASTNode::Variable {
|
||||
name: name.to_string(),
|
||||
span: span(),
|
||||
}
|
||||
}
|
||||
|
||||
fn lit_i(n: i64) -> ASTNode {
|
||||
ASTNode::Literal {
|
||||
value: LiteralValue::Integer(n),
|
||||
span: span(),
|
||||
}
|
||||
}
|
||||
|
||||
fn bin(op: BinaryOperator, left: ASTNode, right: ASTNode) -> ASTNode {
|
||||
ASTNode::BinaryOp {
|
||||
operator: op,
|
||||
left: Box::new(left),
|
||||
right: Box::new(right),
|
||||
span: span(),
|
||||
}
|
||||
}
|
||||
|
||||
fn assignment(target: ASTNode, value: ASTNode) -> ASTNode {
|
||||
ASTNode::Assignment {
|
||||
target: Box::new(target),
|
||||
value: Box::new(value),
|
||||
span: span(),
|
||||
}
|
||||
}
|
||||
|
||||
fn has_continue(node: &ASTNode) -> bool {
|
||||
match node {
|
||||
ASTNode::Continue { .. } => true,
|
||||
ASTNode::If { then_body, else_body, .. } => {
|
||||
then_body.iter().any(has_continue) || else_body.as_ref().map_or(false, |b| b.iter().any(has_continue))
|
||||
}
|
||||
ASTNode::Loop { body, .. } => body.iter().any(has_continue),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn has_break(node: &ASTNode) -> bool {
|
||||
match node {
|
||||
ASTNode::Break { .. } => true,
|
||||
ASTNode::If { then_body, else_body, .. } => {
|
||||
then_body.iter().any(has_break) || else_body.as_ref().map_or(false, |b| b.iter().any(has_break))
|
||||
}
|
||||
ASTNode::Loop { body, .. } => body.iter().any(has_break),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn has_if(body: &[ASTNode]) -> bool {
|
||||
body.iter().any(|n| matches!(n, ASTNode::If { .. }))
|
||||
}
|
||||
|
||||
fn carrier_count(body: &[ASTNode]) -> usize {
|
||||
fn count(nodes: &[ASTNode]) -> usize {
|
||||
let mut c = 0;
|
||||
for n in nodes {
|
||||
match n {
|
||||
ASTNode::Assignment { .. } => c += 1,
|
||||
ASTNode::If { then_body, else_body, .. } => {
|
||||
c += count(then_body);
|
||||
if let Some(else_body) = else_body {
|
||||
c += count(else_body);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
c
|
||||
}
|
||||
if count(body) > 0 { 1 } else { 0 }
|
||||
}
|
||||
|
||||
fn classify_body(body: &[ASTNode]) -> LoopPatternKind {
|
||||
let has_continue_flag = body.iter().any(has_continue);
|
||||
let has_break_flag = body.iter().any(has_break);
|
||||
let features = LoopFeatures {
|
||||
has_break: has_break_flag,
|
||||
has_continue: has_continue_flag,
|
||||
has_if: has_if(body),
|
||||
has_if_else_phi: false,
|
||||
carrier_count: carrier_count(body),
|
||||
break_count: if has_break_flag { 1 } else { 0 },
|
||||
continue_count: if has_continue_flag { 1 } else { 0 },
|
||||
update_summary: None,
|
||||
};
|
||||
classify(&features)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern1_rejects_break() {
|
||||
// TODO: Add test that rejects loop with break
|
||||
// Step 1: Create mock LoopForm with non-empty break_targets
|
||||
// Step 2: Call is_simple_while_pattern()
|
||||
// Step 3: Assert returns false
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Pattern 2: Loop with Break Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern2_break_detection() {
|
||||
// TODO: Add unit test for break pattern detection
|
||||
// Step 1: Create mock LoopForm with:
|
||||
// - Non-empty break_targets (exactly 1)
|
||||
// - Empty continue_targets
|
||||
// - If statement with break
|
||||
// Step 2: Call is_loop_with_break_pattern()
|
||||
// Step 3: Assert returns true
|
||||
fn pattern1_simple_while_is_detected() {
|
||||
// loop(i < len) { i = i + 1 }
|
||||
let body = vec![assignment(var("i"), bin(BinaryOperator::Add, var("i"), lit_i(1)))];
|
||||
let kind = classify_body(&body);
|
||||
assert_eq!(kind, LoopPatternKind::Pattern1SimpleWhile);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern2_rejects_no_break() {
|
||||
// TODO: Add test that rejects loop without break
|
||||
// Step 1: Create mock LoopForm with empty break_targets
|
||||
// Step 2: Call is_loop_with_break_pattern()
|
||||
// Step 3: Assert returns false
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Pattern 3: Loop with If-Else PHI Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern3_if_else_phi_detection() {
|
||||
// TODO: Add unit test for if-else phi pattern detection
|
||||
// Step 1: Create mock LoopForm with:
|
||||
// - Empty break_targets
|
||||
// - Empty continue_targets
|
||||
// - If-else statement in body
|
||||
// - Multiple carrier variables
|
||||
// Step 2: Call is_loop_with_conditional_phi_pattern()
|
||||
// Step 3: Assert returns true
|
||||
fn pattern2_break_loop_is_detected() {
|
||||
// loop(i < len) { if i > 0 { break } i = i + 1 }
|
||||
let cond = bin(BinaryOperator::Greater, var("i"), lit_i(0));
|
||||
let body = vec![
|
||||
ASTNode::If {
|
||||
condition: Box::new(cond),
|
||||
then_body: vec![ASTNode::Break { span: span() }],
|
||||
else_body: None,
|
||||
span: span(),
|
||||
},
|
||||
assignment(var("i"), bin(BinaryOperator::Add, var("i"), lit_i(1))),
|
||||
];
|
||||
let kind = classify_body(&body);
|
||||
assert_eq!(kind, LoopPatternKind::Pattern2Break);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern3_rejects_break() {
|
||||
// TODO: Add test that rejects loop with break
|
||||
// Step 1: Create mock LoopForm with non-empty break_targets
|
||||
// Step 2: Call is_loop_with_conditional_phi_pattern()
|
||||
// Step 3: Assert returns false
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Pattern 4: Loop with Continue Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern4_continue_detection() {
|
||||
// TODO: Add unit test for continue pattern detection
|
||||
// Step 1: Create mock LoopForm with:
|
||||
// - Non-empty continue_targets (at least 1)
|
||||
// - Empty break_targets
|
||||
// - If statement with continue
|
||||
// Step 2: Call is_loop_with_continue_pattern()
|
||||
// Step 3: Assert returns true
|
||||
fn pattern3_if_sum_shape_is_detected() {
|
||||
// loop(i < len) { if i % 2 == 1 { sum = sum + 1 } i = i + 1 }
|
||||
let cond = bin(
|
||||
BinaryOperator::Equal,
|
||||
bin(BinaryOperator::Modulo, var("i"), lit_i(2)),
|
||||
lit_i(1),
|
||||
);
|
||||
let body = vec![
|
||||
ASTNode::If {
|
||||
condition: Box::new(cond),
|
||||
then_body: vec![assignment(
|
||||
var("sum"),
|
||||
bin(BinaryOperator::Add, var("sum"), lit_i(1)),
|
||||
)],
|
||||
else_body: None,
|
||||
span: span(),
|
||||
},
|
||||
assignment(var("i"), bin(BinaryOperator::Add, var("i"), lit_i(1))),
|
||||
];
|
||||
let kind = classify_body(&body);
|
||||
assert_eq!(kind, LoopPatternKind::Pattern3IfPhi);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // TODO: Implement test after detection logic is complete
|
||||
fn test_pattern4_rejects_no_continue() {
|
||||
// TODO: Add test that rejects loop without continue
|
||||
// Step 1: Create mock LoopForm with empty continue_targets
|
||||
// Step 2: Call is_loop_with_continue_pattern()
|
||||
// Step 3: Assert returns false
|
||||
fn pattern4_continue_loop_is_detected() {
|
||||
// loop(i < len) { if (i % 2 == 0) { continue } sum = sum + i; i = i + 1 }
|
||||
let cond = bin(
|
||||
BinaryOperator::Equal,
|
||||
bin(BinaryOperator::Modulo, var("i"), lit_i(2)),
|
||||
lit_i(0),
|
||||
);
|
||||
let body = vec![
|
||||
ASTNode::If {
|
||||
condition: Box::new(cond),
|
||||
then_body: vec![ASTNode::Continue { span: span() }],
|
||||
else_body: Some(vec![assignment(
|
||||
var("sum"),
|
||||
bin(BinaryOperator::Add, var("sum"), var("i")),
|
||||
)]),
|
||||
span: span(),
|
||||
},
|
||||
assignment(var("i"), bin(BinaryOperator::Add, var("i"), lit_i(1))),
|
||||
];
|
||||
let kind = classify_body(&body);
|
||||
assert_eq!(kind, LoopPatternKind::Pattern4Continue);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user