Phase 61: structural if-sum+break lowering (dev-only)
This commit is contained in:
@ -25,10 +25,17 @@ use super::common::{
|
||||
process_local_inits,
|
||||
};
|
||||
use super::param_guess::{build_param_order, compute_param_guess};
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
use super::if_sum_break_pattern;
|
||||
use super::{AstToJoinIrLowerer, JoinModule, LoweringError};
|
||||
use crate::mir::join_ir::{JoinFunction, JoinInst};
|
||||
use crate::mir::ValueId;
|
||||
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
use crate::mir::join_ir::ownership::{
|
||||
plan_to_p2_inputs_with_relay, OwnershipAnalyzer,
|
||||
};
|
||||
|
||||
/// Break パターンを JoinModule に変換
|
||||
///
|
||||
/// # Arguments
|
||||
@ -37,6 +44,26 @@ use crate::mir::ValueId;
|
||||
pub fn lower(
|
||||
lowerer: &mut AstToJoinIrLowerer,
|
||||
program_json: &serde_json::Value,
|
||||
) -> Result<JoinModule, LoweringError> {
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
{
|
||||
if let Some(module) = if_sum_break_pattern::try_lower_if_sum_break(lowerer, program_json)? {
|
||||
return Ok(module);
|
||||
}
|
||||
if let Ok(module) = lower_with_ownership_relay(lowerer, program_json) {
|
||||
return Ok(module);
|
||||
}
|
||||
}
|
||||
|
||||
lower_legacy_param_guess(lowerer, program_json)
|
||||
}
|
||||
|
||||
/// Legacy Break lowering (Phase P4) using param_guess heuristics.
|
||||
///
|
||||
/// This remains as a fallback and is also used for Phase 60 comparison tests.
|
||||
fn lower_legacy_param_guess(
|
||||
lowerer: &mut AstToJoinIrLowerer,
|
||||
program_json: &serde_json::Value,
|
||||
) -> Result<JoinModule, LoweringError> {
|
||||
// 1. Program(JSON) をパース
|
||||
let parsed = parse_program_json(program_json);
|
||||
@ -101,6 +128,136 @@ pub fn lower(
|
||||
Ok(build_join_module(entry_func, loop_step_func, k_exit_func))
|
||||
}
|
||||
|
||||
/// Phase 60 dev-only Break lowering using OwnershipAnalyzer + relay threading.
|
||||
///
|
||||
/// This function is only compiled with `normalized_dev` and is fail-fast on
|
||||
/// multi-hop relay (relay_path.len()>1).
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
fn lower_with_ownership_relay(
|
||||
lowerer: &mut AstToJoinIrLowerer,
|
||||
program_json: &serde_json::Value,
|
||||
) -> Result<JoinModule, LoweringError> {
|
||||
// Parse and build contexts similarly to legacy path.
|
||||
let parsed = parse_program_json(program_json);
|
||||
let (ctx, mut entry_ctx) = create_loop_context(lowerer, &parsed);
|
||||
let init_insts = process_local_inits(lowerer, &parsed, &mut entry_ctx);
|
||||
|
||||
let loop_node = &parsed.stmts[parsed.loop_node_idx];
|
||||
let loop_body = loop_node["body"]
|
||||
.as_array()
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "Loop must have 'body' array".to_string(),
|
||||
})?;
|
||||
|
||||
let (break_if_idx, break_if_stmt) = loop_body
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, stmt)| {
|
||||
stmt["type"].as_str() == Some("If")
|
||||
&& stmt["then"].as_array().map_or(false, |then| {
|
||||
then.iter().any(|s| s["type"].as_str() == Some("Break"))
|
||||
})
|
||||
})
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "Break pattern must have If + Break".to_string(),
|
||||
})?;
|
||||
|
||||
let break_cond_expr = &break_if_stmt["cond"];
|
||||
let loop_cond_expr = &loop_node["cond"];
|
||||
|
||||
// Use legacy guess only to stabilize loop_var/acc names.
|
||||
let legacy_guess = compute_param_guess(&entry_ctx);
|
||||
let loop_var_name = legacy_guess.loop_var.0.clone();
|
||||
let acc_name = legacy_guess.acc.0.clone();
|
||||
|
||||
let param_order = compute_param_order_from_ownership(program_json, &entry_ctx, &loop_var_name)
|
||||
.unwrap_or_else(|| build_param_order(&legacy_guess, &entry_ctx));
|
||||
|
||||
// Ensure accumulator is present in param list (avoid missing carrier ordering).
|
||||
let mut param_order = param_order;
|
||||
if !param_order.iter().any(|(n, _)| n == &acc_name) {
|
||||
if let Some(id) = entry_ctx.get_var(&acc_name) {
|
||||
param_order.push((acc_name.clone(), id));
|
||||
}
|
||||
}
|
||||
|
||||
let entry_func =
|
||||
create_entry_function_break(&ctx, &parsed, init_insts, &mut entry_ctx, ¶m_order);
|
||||
|
||||
let loop_step_func = create_loop_step_function_break(
|
||||
lowerer,
|
||||
&ctx,
|
||||
&parsed.func_name,
|
||||
loop_cond_expr,
|
||||
break_cond_expr,
|
||||
loop_body,
|
||||
¶m_order,
|
||||
&loop_var_name,
|
||||
&acc_name,
|
||||
break_if_idx,
|
||||
)?;
|
||||
|
||||
let k_exit_func = create_k_exit_function(&ctx, &parsed.func_name);
|
||||
Ok(build_join_module(entry_func, loop_step_func, k_exit_func))
|
||||
}
|
||||
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
fn compute_param_order_from_ownership(
|
||||
program_json: &serde_json::Value,
|
||||
entry_ctx: &super::super::context::ExtractCtx,
|
||||
loop_var_name: &str,
|
||||
) -> Option<Vec<(String, ValueId)>> {
|
||||
let mut analyzer = OwnershipAnalyzer::new();
|
||||
let plans = analyzer.analyze_json(program_json).ok()?;
|
||||
|
||||
let loop_plan = plans
|
||||
.iter()
|
||||
// Prefer the actual loop scope (loop var is re-bound inside loop => relay_writes)
|
||||
.find(|p| p.relay_writes.iter().any(|r| r.name == loop_var_name))
|
||||
// Fallback: any loop plan with relay_writes
|
||||
.or_else(|| plans.iter().find(|p| !p.relay_writes.is_empty()))
|
||||
// Last resort: any plan that owns loop_var_name (loop-local case)
|
||||
.or_else(|| plans.iter().find(|p| p.owned_vars.iter().any(|v| v.name == loop_var_name)))?;
|
||||
|
||||
let inputs = plan_to_p2_inputs_with_relay(loop_plan, loop_var_name).ok()?;
|
||||
|
||||
let mut order: Vec<(String, ValueId)> = Vec::new();
|
||||
let mut seen = std::collections::BTreeSet::<String>::new();
|
||||
|
||||
if let Some(id) = entry_ctx.get_var(loop_var_name) {
|
||||
order.push((loop_var_name.to_string(), id));
|
||||
seen.insert(loop_var_name.to_string());
|
||||
}
|
||||
|
||||
for carrier in inputs.carriers {
|
||||
if seen.contains(&carrier.name) {
|
||||
continue;
|
||||
}
|
||||
if let Some(id) = entry_ctx.get_var(&carrier.name) {
|
||||
order.push((carrier.name.clone(), id));
|
||||
seen.insert(carrier.name);
|
||||
}
|
||||
}
|
||||
|
||||
for (name, var_id) in &entry_ctx.var_map {
|
||||
if !seen.contains(name) {
|
||||
order.push((name.clone(), *var_id));
|
||||
seen.insert(name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Some(order)
|
||||
}
|
||||
|
||||
/// Expose legacy Break lowering for Phase 60 comparison tests (dev-only).
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
pub fn lower_break_legacy_for_comparison(
|
||||
lowerer: &mut AstToJoinIrLowerer,
|
||||
program_json: &serde_json::Value,
|
||||
) -> Result<JoinModule, LoweringError> {
|
||||
lower_legacy_param_guess(lowerer, program_json)
|
||||
}
|
||||
|
||||
/// Break パターン用 entry 関数を生成
|
||||
fn create_entry_function_break(
|
||||
ctx: &super::common::LoopContext,
|
||||
|
||||
@ -0,0 +1,503 @@
|
||||
//! Phase 61: If-Sum + Break pattern (dev-only)
|
||||
//!
|
||||
//! ## Responsibility
|
||||
//! Break 付き if-sum ループを、sum/count の複数キャリアで k_exit に渡し、
|
||||
//! k_exit で `sum + count` を返す。
|
||||
//!
|
||||
//! ## Fail-Fast Boundary
|
||||
//! - Return が `Var + Var` 以外 → not matched
|
||||
//! - ループ末尾の counter update が `i = i + 1` 形で検出できない → Err
|
||||
//! - Ownership relay が single-hop 以外 → Err
|
||||
//! - loop-carried carriers が Return の 2 変数と一致しない → Err
|
||||
|
||||
#![cfg(feature = "normalized_dev")]
|
||||
|
||||
use super::common::{
|
||||
build_join_module, create_k_exit_function, create_loop_context, parse_program_json,
|
||||
process_local_inits,
|
||||
};
|
||||
use super::{AstToJoinIrLowerer, JoinModule, LoweringError};
|
||||
use crate::mir::join_ir::{BinOpKind, JoinFunction, JoinInst, MirLikeInst};
|
||||
use crate::mir::ValueId;
|
||||
use crate::mir::join_ir::ownership::{plan_to_p3_inputs_with_relay, OwnershipAnalyzer};
|
||||
|
||||
pub fn try_lower_if_sum_break(
|
||||
lowerer: &mut AstToJoinIrLowerer,
|
||||
program_json: &serde_json::Value,
|
||||
) -> Result<Option<JoinModule>, LoweringError> {
|
||||
let parsed = parse_program_json(program_json);
|
||||
|
||||
let return_expr = parsed.stmts.last().and_then(|s| {
|
||||
if s["type"].as_str() == Some("Return") {
|
||||
s.get("expr")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let Some((ret_lhs, ret_rhs)) = parse_return_var_plus_var(return_expr) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let loop_node = &parsed.stmts[parsed.loop_node_idx];
|
||||
let loop_body = loop_node["body"]
|
||||
.as_array()
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "Loop must have 'body' array".to_string(),
|
||||
})?;
|
||||
|
||||
let (break_if_idx, break_if_stmt) = match find_break_if(loop_body) {
|
||||
Some(v) => v,
|
||||
None => return Ok(None),
|
||||
};
|
||||
if break_if_idx != 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
let break_cond_expr = &break_if_stmt["cond"];
|
||||
|
||||
// Limit scope (Phase 61 dev-only): [break-if, update-if, counter-update] only.
|
||||
if loop_body.len() != 3 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let update_if_stmt = match find_single_update_if(loop_body, break_if_idx) {
|
||||
Some(v) => v,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let counter_update_stmt = loop_body
|
||||
.last()
|
||||
.expect("loop_body len checked")
|
||||
.clone();
|
||||
|
||||
let loop_var_name = detect_counter_update_loop_var(loop_body).ok_or_else(|| {
|
||||
LoweringError::InvalidLoopBody {
|
||||
message: "if-sum-break requires trailing counter update like i = i + 1".to_string(),
|
||||
}
|
||||
})?;
|
||||
|
||||
if ret_lhs == loop_var_name || ret_rhs == loop_var_name || ret_lhs == ret_rhs {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// === Ownership SSOT: param order = [loop_var] + carriers + captures ===
|
||||
let (ctx, mut entry_ctx) = create_loop_context(lowerer, &parsed);
|
||||
let init_insts = process_local_inits(lowerer, &parsed, &mut entry_ctx);
|
||||
|
||||
let mut analyzer = OwnershipAnalyzer::new();
|
||||
let plans = analyzer
|
||||
.analyze_json(program_json)
|
||||
.map_err(|e| LoweringError::JsonParseError { message: e })?;
|
||||
|
||||
let loop_plan = plans
|
||||
.iter()
|
||||
.find(|p| p.relay_writes.iter().any(|r| r.name == loop_var_name))
|
||||
.or_else(|| plans.iter().find(|p| p.owned_vars.iter().any(|v| v.name == loop_var_name)))
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "if-sum-break: failed to find loop ownership plan".to_string(),
|
||||
})?;
|
||||
|
||||
let inputs = plan_to_p3_inputs_with_relay(loop_plan, &loop_var_name).map_err(|e| {
|
||||
LoweringError::JsonParseError { message: e }
|
||||
})?;
|
||||
|
||||
// Ensure carriers are exactly the return vars (fail-fast mixing protection).
|
||||
let carrier_names: std::collections::BTreeSet<String> =
|
||||
inputs.carriers.iter().map(|c| c.name.clone()).collect();
|
||||
let expected: std::collections::BTreeSet<String> =
|
||||
[ret_lhs.clone(), ret_rhs.clone()].into_iter().collect();
|
||||
|
||||
if carrier_names != expected {
|
||||
return Err(LoweringError::InvalidLoopBody {
|
||||
message: format!(
|
||||
"if-sum-break: carriers {:?} must equal return vars {:?}",
|
||||
carrier_names, expected
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let mut param_order: Vec<(String, ValueId)> = Vec::new();
|
||||
let mut seen = std::collections::BTreeSet::<String>::new();
|
||||
|
||||
let loop_var_id = entry_ctx
|
||||
.get_var(&loop_var_name)
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: format!("loop var '{}' must be initialized before loop", loop_var_name),
|
||||
})?;
|
||||
param_order.push((loop_var_name.clone(), loop_var_id));
|
||||
seen.insert(loop_var_name.clone());
|
||||
|
||||
for carrier in &inputs.carriers {
|
||||
let id = entry_ctx
|
||||
.get_var(&carrier.name)
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: format!("carrier '{}' must be initialized before loop", carrier.name),
|
||||
})?;
|
||||
param_order.push((carrier.name.clone(), id));
|
||||
seen.insert(carrier.name.clone());
|
||||
}
|
||||
|
||||
for cap_name in &inputs.captures {
|
||||
if seen.contains(cap_name) {
|
||||
continue;
|
||||
}
|
||||
if let Some(id) = entry_ctx.get_var(cap_name) {
|
||||
param_order.push((cap_name.clone(), id));
|
||||
seen.insert(cap_name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Include remaining params/vars deterministically
|
||||
for (name, var_id) in &entry_ctx.var_map {
|
||||
if !seen.contains(name) {
|
||||
param_order.push((name.clone(), *var_id));
|
||||
seen.insert(name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let entry_func = create_entry_function_if_sum_break(
|
||||
&ctx,
|
||||
&parsed,
|
||||
init_insts,
|
||||
&mut entry_ctx,
|
||||
¶m_order,
|
||||
);
|
||||
|
||||
let loop_step_func = create_loop_step_function_if_sum_break(
|
||||
lowerer,
|
||||
&ctx,
|
||||
&parsed.func_name,
|
||||
&loop_node["cond"],
|
||||
break_cond_expr,
|
||||
update_if_stmt,
|
||||
&counter_update_stmt,
|
||||
¶m_order,
|
||||
&loop_var_name,
|
||||
&ret_lhs,
|
||||
&ret_rhs,
|
||||
)?;
|
||||
|
||||
let k_exit_func = create_k_exit_function(&ctx, &parsed.func_name);
|
||||
|
||||
Ok(Some(build_join_module(entry_func, loop_step_func, k_exit_func)))
|
||||
}
|
||||
|
||||
fn parse_return_var_plus_var(
|
||||
expr: Option<&serde_json::Value>,
|
||||
) -> Option<(String, String)> {
|
||||
let expr = expr?;
|
||||
if expr["type"].as_str()? != "Binary" {
|
||||
return None;
|
||||
}
|
||||
if expr["op"].as_str()? != "+" {
|
||||
return None;
|
||||
}
|
||||
let lhs = expr["lhs"].as_object()?;
|
||||
let rhs = expr["rhs"].as_object()?;
|
||||
if lhs.get("type")?.as_str()? != "Var" || rhs.get("type")?.as_str()? != "Var" {
|
||||
return None;
|
||||
}
|
||||
Some((
|
||||
lhs.get("name")?.as_str()?.to_string(),
|
||||
rhs.get("name")?.as_str()?.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn find_break_if(loop_body: &[serde_json::Value]) -> Option<(usize, &serde_json::Value)> {
|
||||
loop_body.iter().enumerate().find(|(_, stmt)| {
|
||||
stmt["type"].as_str() == Some("If")
|
||||
&& stmt["then"].as_array().map_or(false, |then| {
|
||||
then.iter().any(|s| s["type"].as_str() == Some("Break"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn detect_counter_update_loop_var(loop_body: &[serde_json::Value]) -> Option<String> {
|
||||
let last = loop_body.last()?;
|
||||
if last["type"].as_str()? != "Local" {
|
||||
return None;
|
||||
}
|
||||
let name = last["name"].as_str()?.to_string();
|
||||
let expr = last.get("expr")?;
|
||||
if expr["type"].as_str()? != "Binary" || expr["op"].as_str()? != "+" {
|
||||
return None;
|
||||
}
|
||||
let lhs = &expr["lhs"];
|
||||
let rhs = &expr["rhs"];
|
||||
if lhs["type"].as_str()? != "Var" || lhs["name"].as_str()? != name {
|
||||
return None;
|
||||
}
|
||||
if rhs["type"].as_str()? != "Int" || rhs["value"].as_i64()? != 1 {
|
||||
return None;
|
||||
}
|
||||
Some(name)
|
||||
}
|
||||
|
||||
fn create_entry_function_if_sum_break(
|
||||
ctx: &super::common::LoopContext,
|
||||
parsed: &super::common::ParsedProgram,
|
||||
init_insts: Vec<JoinInst>,
|
||||
entry_ctx: &mut super::super::context::ExtractCtx,
|
||||
param_order: &[(String, ValueId)],
|
||||
) -> JoinFunction {
|
||||
let loop_args: Vec<ValueId> = param_order.iter().map(|(_, id)| *id).collect();
|
||||
let loop_result = entry_ctx.alloc_var();
|
||||
let mut body = init_insts;
|
||||
body.push(JoinInst::Call {
|
||||
func: ctx.loop_step_id,
|
||||
args: loop_args,
|
||||
k_next: None,
|
||||
dst: Some(loop_result),
|
||||
});
|
||||
body.push(JoinInst::Ret {
|
||||
value: Some(loop_result),
|
||||
});
|
||||
JoinFunction {
|
||||
id: ctx.entry_id,
|
||||
name: parsed.func_name.clone(),
|
||||
params: (0..parsed.param_names.len())
|
||||
.map(|i| ValueId(i as u32))
|
||||
.collect(),
|
||||
body,
|
||||
exit_cont: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_loop_step_function_if_sum_break(
|
||||
lowerer: &mut AstToJoinIrLowerer,
|
||||
ctx: &super::common::LoopContext,
|
||||
func_name: &str,
|
||||
loop_cond_expr: &serde_json::Value,
|
||||
break_cond_expr: &serde_json::Value,
|
||||
update_if_stmt: &serde_json::Value,
|
||||
counter_update_stmt: &serde_json::Value,
|
||||
param_order: &[(String, ValueId)],
|
||||
loop_var_name: &str,
|
||||
sum_name: &str,
|
||||
count_name: &str,
|
||||
) -> Result<JoinFunction, LoweringError> {
|
||||
use super::super::context::ExtractCtx;
|
||||
|
||||
let param_names: Vec<String> = param_order.iter().map(|(name, _)| name.clone()).collect();
|
||||
let mut step_ctx = ExtractCtx::new(param_names.len() as u32);
|
||||
for (idx, name) in param_names.iter().enumerate() {
|
||||
step_ctx.register_param(name.clone(), ValueId(idx as u32));
|
||||
}
|
||||
|
||||
let mut body = Vec::new();
|
||||
|
||||
// Header condition: if !loop_cond -> exit with (sum, count)
|
||||
let (loop_cond_var, loop_cond_insts) = lowerer.extract_value(loop_cond_expr, &mut step_ctx);
|
||||
body.extend(loop_cond_insts);
|
||||
let header_exit_flag = step_ctx.alloc_var();
|
||||
body.push(JoinInst::Compute(MirLikeInst::UnaryOp {
|
||||
dst: header_exit_flag,
|
||||
op: crate::mir::join_ir::UnaryOp::Not,
|
||||
operand: loop_cond_var,
|
||||
}));
|
||||
let sum_before = step_ctx
|
||||
.get_var(sum_name)
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: format!("{} must exist", sum_name),
|
||||
})?;
|
||||
let count_before = step_ctx
|
||||
.get_var(count_name)
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: format!("{} must exist", count_name),
|
||||
})?;
|
||||
let acc_before = step_ctx.alloc_var();
|
||||
body.push(JoinInst::Compute(MirLikeInst::BinOp {
|
||||
dst: acc_before,
|
||||
op: BinOpKind::Add,
|
||||
lhs: sum_before,
|
||||
rhs: count_before,
|
||||
}));
|
||||
body.push(JoinInst::Jump {
|
||||
cont: ctx.k_exit_id.as_cont(),
|
||||
args: vec![acc_before],
|
||||
cond: Some(header_exit_flag),
|
||||
});
|
||||
|
||||
// Break condition: if break_cond -> exit with (sum, count)
|
||||
let (break_cond_var, break_cond_insts) = lowerer.extract_value(break_cond_expr, &mut step_ctx);
|
||||
body.extend(break_cond_insts);
|
||||
body.push(JoinInst::Jump {
|
||||
cont: ctx.k_exit_id.as_cont(),
|
||||
args: vec![acc_before],
|
||||
cond: Some(break_cond_var),
|
||||
});
|
||||
|
||||
// Update-if: cond ? then_update : else_update (Select-based, no if-in-loop lowering)
|
||||
let update_cond_expr = &update_if_stmt["cond"];
|
||||
let (update_cond_var, update_cond_insts) =
|
||||
lowerer.extract_value(update_cond_expr, &mut step_ctx);
|
||||
body.extend(update_cond_insts);
|
||||
|
||||
let (sum_then_expr, sum_else_expr) =
|
||||
extract_if_branch_assignment(update_if_stmt, sum_name)?;
|
||||
let (count_then_expr, count_else_expr) =
|
||||
extract_if_branch_assignment(update_if_stmt, count_name)?;
|
||||
|
||||
let (sum_then_val, sum_then_insts) =
|
||||
lowerer.extract_value(&sum_then_expr, &mut step_ctx);
|
||||
let (sum_else_val, sum_else_insts) =
|
||||
lowerer.extract_value(&sum_else_expr, &mut step_ctx);
|
||||
let (count_then_val, count_then_insts) =
|
||||
lowerer.extract_value(&count_then_expr, &mut step_ctx);
|
||||
let (count_else_val, count_else_insts) =
|
||||
lowerer.extract_value(&count_else_expr, &mut step_ctx);
|
||||
body.extend(sum_then_insts);
|
||||
body.extend(sum_else_insts);
|
||||
body.extend(count_then_insts);
|
||||
body.extend(count_else_insts);
|
||||
|
||||
let sum_next = step_ctx.alloc_var();
|
||||
body.push(JoinInst::Compute(MirLikeInst::Select {
|
||||
dst: sum_next,
|
||||
cond: update_cond_var,
|
||||
then_val: sum_then_val,
|
||||
else_val: sum_else_val,
|
||||
}));
|
||||
step_ctx.register_param(sum_name.to_string(), sum_next);
|
||||
|
||||
let count_next = step_ctx.alloc_var();
|
||||
body.push(JoinInst::Compute(MirLikeInst::Select {
|
||||
dst: count_next,
|
||||
cond: update_cond_var,
|
||||
then_val: count_then_val,
|
||||
else_val: count_else_val,
|
||||
}));
|
||||
step_ctx.register_param(count_name.to_string(), count_next);
|
||||
|
||||
// Counter update (must update loop var)
|
||||
let counter_expr = counter_update_stmt.get("expr").ok_or_else(|| {
|
||||
LoweringError::InvalidLoopBody {
|
||||
message: "counter update must have 'expr'".to_string(),
|
||||
}
|
||||
})?;
|
||||
let (i_next, i_insts) = lowerer.extract_value(counter_expr, &mut step_ctx);
|
||||
body.extend(i_insts);
|
||||
step_ctx.register_param(loop_var_name.to_string(), i_next);
|
||||
|
||||
// Recurse with updated params.
|
||||
let recurse_result = step_ctx.alloc_var();
|
||||
let mut recurse_args = Vec::new();
|
||||
for name in ¶m_names {
|
||||
let arg = step_ctx
|
||||
.get_var(name)
|
||||
.unwrap_or_else(|| panic!("param {} must exist", name));
|
||||
recurse_args.push(arg);
|
||||
}
|
||||
body.push(JoinInst::Call {
|
||||
func: ctx.loop_step_id,
|
||||
args: recurse_args,
|
||||
k_next: None,
|
||||
dst: Some(recurse_result),
|
||||
});
|
||||
body.push(JoinInst::Ret {
|
||||
value: Some(recurse_result),
|
||||
});
|
||||
|
||||
Ok(JoinFunction {
|
||||
id: ctx.loop_step_id,
|
||||
name: format!("{}_loop_step", func_name),
|
||||
params: (0..param_names.len())
|
||||
.map(|i| ValueId(i as u32))
|
||||
.collect(),
|
||||
body,
|
||||
exit_cont: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn find_single_update_if<'a>(
|
||||
loop_body: &'a [serde_json::Value],
|
||||
break_if_idx: usize,
|
||||
) -> Option<&'a serde_json::Value> {
|
||||
let mut found: Option<&serde_json::Value> = None;
|
||||
for (idx, stmt) in loop_body.iter().enumerate() {
|
||||
if idx == break_if_idx {
|
||||
continue;
|
||||
}
|
||||
if stmt["type"].as_str() == Some("If") {
|
||||
if found.is_some() {
|
||||
return None;
|
||||
}
|
||||
found = Some(stmt);
|
||||
}
|
||||
}
|
||||
found
|
||||
}
|
||||
|
||||
fn extract_if_branch_assignment(
|
||||
if_stmt: &serde_json::Value,
|
||||
target: &str,
|
||||
) -> Result<(serde_json::Value, serde_json::Value), LoweringError> {
|
||||
fn find_assignment_expr(branch: &[serde_json::Value], target: &str) -> Result<Option<serde_json::Value>, LoweringError> {
|
||||
let mut found: Option<serde_json::Value> = None;
|
||||
for stmt in branch {
|
||||
match stmt["type"].as_str() {
|
||||
Some("Local") => {
|
||||
let name = stmt["name"].as_str().ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "Local must have 'name'".to_string(),
|
||||
})?;
|
||||
if name != target {
|
||||
continue;
|
||||
}
|
||||
if found.is_some() {
|
||||
return Err(LoweringError::InvalidLoopBody {
|
||||
message: format!("if-sum-break: multiple assignments to '{}'", target),
|
||||
});
|
||||
}
|
||||
let expr = stmt.get("expr").ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "Local must have 'expr'".to_string(),
|
||||
})?;
|
||||
found = Some(expr.clone());
|
||||
}
|
||||
Some("Assignment") | Some("Assign") => {
|
||||
let name = stmt["target"].as_str().ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "Assignment must have 'target'".to_string(),
|
||||
})?;
|
||||
if name != target {
|
||||
continue;
|
||||
}
|
||||
if found.is_some() {
|
||||
return Err(LoweringError::InvalidLoopBody {
|
||||
message: format!("if-sum-break: multiple assignments to '{}'", target),
|
||||
});
|
||||
}
|
||||
let expr = stmt.get("expr").or_else(|| stmt.get("value")).ok_or_else(|| {
|
||||
LoweringError::InvalidLoopBody {
|
||||
message: "Assignment must have 'expr' or 'value'".to_string(),
|
||||
}
|
||||
})?;
|
||||
found = Some(expr.clone());
|
||||
}
|
||||
_ => {
|
||||
return Err(LoweringError::InvalidLoopBody {
|
||||
message: "if-sum-break: unsupported statement in update if".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(found)
|
||||
}
|
||||
|
||||
let then_branch = if_stmt["then"]
|
||||
.as_array()
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "If must have 'then' array".to_string(),
|
||||
})?;
|
||||
let else_branch = if_stmt["else"]
|
||||
.as_array()
|
||||
.ok_or_else(|| LoweringError::InvalidLoopBody {
|
||||
message: "If must have 'else' array".to_string(),
|
||||
})?;
|
||||
|
||||
let then_expr = find_assignment_expr(then_branch, target)?.unwrap_or_else(|| {
|
||||
serde_json::json!({"type":"Var","name":target})
|
||||
});
|
||||
let else_expr = find_assignment_expr(else_branch, target)?.unwrap_or_else(|| {
|
||||
serde_json::json!({"type":"Var","name":target})
|
||||
});
|
||||
Ok((then_expr, else_expr))
|
||||
}
|
||||
@ -19,6 +19,8 @@ pub mod break_pattern;
|
||||
pub mod common;
|
||||
pub mod continue_pattern;
|
||||
pub mod filter;
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
pub mod if_sum_break_pattern;
|
||||
pub mod param_guess;
|
||||
pub mod print_tokens;
|
||||
pub mod simple;
|
||||
|
||||
@ -189,6 +189,18 @@ impl AstToJoinIrLowerer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Phase 60 dev-only helper: legacy Break(P2) lowering for comparison tests.
|
||||
///
|
||||
/// `loop_patterns` is private, so this wrapper is exposed at the ast_lowerer boundary.
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
pub fn lower_break_legacy_for_comparison(
|
||||
lowerer: &mut AstToJoinIrLowerer,
|
||||
program_json: &serde_json::Value,
|
||||
) -> JoinModule {
|
||||
loop_patterns::break_pattern::lower_break_legacy_for_comparison(lowerer, program_json)
|
||||
.unwrap_or_else(|e| panic!("legacy break lowering failed: {:?}", e))
|
||||
}
|
||||
|
||||
impl Default for AstToJoinIrLowerer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
|
||||
@ -351,6 +351,10 @@ pub fn normalize_pattern2_minimal(structured: &JoinModule) -> NormalizedModule {
|
||||
NormalizedDevShape::Pattern3IfSumMinimal
|
||||
| NormalizedDevShape::Pattern3IfSumMulti
|
||||
| NormalizedDevShape::Pattern3IfSumJson
|
||||
| NormalizedDevShape::SelfhostIfSumP3
|
||||
| NormalizedDevShape::SelfhostIfSumP3Ext
|
||||
| NormalizedDevShape::SelfhostStmtCountP3
|
||||
| NormalizedDevShape::SelfhostDetectFormatP3
|
||||
)
|
||||
}) {
|
||||
max = max.max(6);
|
||||
|
||||
@ -312,7 +312,10 @@ pub fn detect_shapes(module: &JoinModule) -> Vec<NormalizedDevShape> {
|
||||
|| shapes.contains(&NormalizedDevShape::SelfhostArgsParseP2)
|
||||
|| shapes.contains(&NormalizedDevShape::SelfhostVerifySchemaP2)
|
||||
{
|
||||
shapes.retain(|s| *s != NormalizedDevShape::Pattern2Mini);
|
||||
shapes.retain(|s| {
|
||||
*s != NormalizedDevShape::Pattern2Mini
|
||||
&& *s != NormalizedDevShape::Pattern4ContinueMinimal
|
||||
});
|
||||
}
|
||||
if shapes.contains(&NormalizedDevShape::SelfhostIfSumP3)
|
||||
|| shapes.contains(&NormalizedDevShape::SelfhostIfSumP3Ext)
|
||||
@ -325,6 +328,7 @@ pub fn detect_shapes(module: &JoinModule) -> Vec<NormalizedDevShape> {
|
||||
NormalizedDevShape::Pattern3IfSumMinimal
|
||||
| NormalizedDevShape::Pattern3IfSumMulti
|
||||
| NormalizedDevShape::Pattern3IfSumJson
|
||||
| NormalizedDevShape::Pattern4ContinueMinimal
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
@ -51,11 +51,33 @@ impl OwnershipAnalyzer {
|
||||
self.scopes.clear();
|
||||
self.next_scope_id = 0;
|
||||
|
||||
// Find functions and analyze each
|
||||
// Find functions and analyze each.
|
||||
//
|
||||
// Supported inputs:
|
||||
// - "functions": [...], with statement nodes using "kind" (test schema)
|
||||
// - "defs": [...], where FunctionDef nodes contain "params"/"body" (Program(JSON v0))
|
||||
if let Some(functions) = json.get("functions").and_then(|f| f.as_array()) {
|
||||
for func in functions {
|
||||
self.analyze_function(func, None)?;
|
||||
}
|
||||
} else if let Some(defs) = json.get("defs").and_then(|d| d.as_array()) {
|
||||
let mut found = false;
|
||||
for def in defs {
|
||||
let def_kind = def
|
||||
.get("type")
|
||||
.or_else(|| def.get("kind"))
|
||||
.and_then(|k| k.as_str())
|
||||
.unwrap_or("");
|
||||
if def_kind == "FunctionDef" {
|
||||
found = true;
|
||||
self.analyze_function(def, None)?;
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return Err("OwnershipAnalyzer: no FunctionDef found in 'defs'".to_string());
|
||||
}
|
||||
} else {
|
||||
return Err("OwnershipAnalyzer: expected top-level 'functions' or 'defs' array".to_string());
|
||||
}
|
||||
|
||||
// Convert ScopeInfo to OwnershipPlan
|
||||
@ -98,26 +120,45 @@ impl OwnershipAnalyzer {
|
||||
}
|
||||
|
||||
fn analyze_statement(&mut self, stmt: &Value, current_scope: ScopeId) -> Result<(), String> {
|
||||
let kind = stmt.get("kind").and_then(|k| k.as_str()).unwrap_or("");
|
||||
let kind = stmt
|
||||
.get("kind")
|
||||
.or_else(|| stmt.get("type"))
|
||||
.and_then(|k| k.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match kind {
|
||||
"Local" => {
|
||||
// Variable definition
|
||||
// NOTE: Program(JSON v0) historically uses Local for both "new binding"
|
||||
// and "rebind/update". This analyzer treats Local as:
|
||||
// - definition when the name is not yet defined in scope chain
|
||||
// - write when the name is already defined (rebind)
|
||||
if let Some(name) = stmt.get("name").and_then(|n| n.as_str()) {
|
||||
// Find enclosing loop (or function) for ownership
|
||||
let owner_scope = self.find_enclosing_loop_or_function(current_scope);
|
||||
self.scopes.get_mut(&owner_scope).unwrap().defined.insert(name.to_string());
|
||||
if self.is_defined_in_scope_chain(current_scope, name) {
|
||||
self.scopes
|
||||
.get_mut(¤t_scope)
|
||||
.unwrap()
|
||||
.writes
|
||||
.insert(name.to_string());
|
||||
} else {
|
||||
// Find enclosing loop (or function) for ownership
|
||||
let owner_scope = self.find_enclosing_loop_or_function(current_scope);
|
||||
self.scopes
|
||||
.get_mut(&owner_scope)
|
||||
.unwrap()
|
||||
.defined
|
||||
.insert(name.to_string());
|
||||
}
|
||||
}
|
||||
// Analyze initializer if present
|
||||
if let Some(init) = stmt.get("init") {
|
||||
if let Some(init) = stmt.get("init").or_else(|| stmt.get("expr")) {
|
||||
self.analyze_expression(init, current_scope, false)?;
|
||||
}
|
||||
}
|
||||
"Assign" => {
|
||||
"Assign" | "Assignment" => {
|
||||
if let Some(target) = stmt.get("target").and_then(|t| t.as_str()) {
|
||||
self.scopes.get_mut(¤t_scope).unwrap().writes.insert(target.to_string());
|
||||
}
|
||||
if let Some(value) = stmt.get("value") {
|
||||
if let Some(value) = stmt.get("value").or_else(|| stmt.get("expr")) {
|
||||
self.analyze_expression(value, current_scope, false)?;
|
||||
}
|
||||
}
|
||||
@ -125,7 +166,7 @@ impl OwnershipAnalyzer {
|
||||
let loop_scope = self.alloc_scope(ScopeKind::Loop, Some(current_scope));
|
||||
|
||||
// Analyze condition (mark as condition_reads)
|
||||
if let Some(cond) = stmt.get("condition") {
|
||||
if let Some(cond) = stmt.get("condition").or_else(|| stmt.get("cond")) {
|
||||
self.analyze_expression(cond, loop_scope, true)?;
|
||||
}
|
||||
|
||||
@ -141,7 +182,7 @@ impl OwnershipAnalyzer {
|
||||
let if_scope = self.alloc_scope(ScopeKind::If, Some(current_scope));
|
||||
|
||||
// Analyze condition
|
||||
if let Some(cond) = stmt.get("condition") {
|
||||
if let Some(cond) = stmt.get("condition").or_else(|| stmt.get("cond")) {
|
||||
self.analyze_expression(cond, if_scope, true)?;
|
||||
}
|
||||
|
||||
@ -158,7 +199,11 @@ impl OwnershipAnalyzer {
|
||||
"Block" => {
|
||||
let block_scope = self.alloc_scope(ScopeKind::Block, Some(current_scope));
|
||||
|
||||
if let Some(stmts) = stmt.get("statements").and_then(|s| s.as_array()) {
|
||||
let stmts = stmt
|
||||
.get("statements")
|
||||
.or_else(|| stmt.get("body"))
|
||||
.and_then(|s| s.as_array());
|
||||
if let Some(stmts) = stmts {
|
||||
for s in stmts {
|
||||
self.analyze_statement(s, block_scope)?;
|
||||
}
|
||||
@ -167,7 +212,7 @@ impl OwnershipAnalyzer {
|
||||
self.propagate_to_parent(block_scope);
|
||||
}
|
||||
"Return" | "Break" | "Continue" => {
|
||||
if let Some(value) = stmt.get("value") {
|
||||
if let Some(value) = stmt.get("value").or_else(|| stmt.get("expr")) {
|
||||
self.analyze_expression(value, current_scope, false)?;
|
||||
}
|
||||
}
|
||||
@ -190,7 +235,11 @@ impl OwnershipAnalyzer {
|
||||
}
|
||||
|
||||
fn analyze_expression(&mut self, expr: &Value, current_scope: ScopeId, is_condition: bool) -> Result<(), String> {
|
||||
let kind = expr.get("kind").and_then(|k| k.as_str()).unwrap_or("");
|
||||
let kind = expr
|
||||
.get("kind")
|
||||
.or_else(|| expr.get("type"))
|
||||
.and_then(|k| k.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
match kind {
|
||||
"Var" | "Variable" | "Identifier" => {
|
||||
@ -248,6 +297,21 @@ impl OwnershipAnalyzer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_defined_in_scope_chain(&self, from_scope: ScopeId, name: &str) -> bool {
|
||||
let mut current = Some(from_scope);
|
||||
while let Some(id) = current {
|
||||
let scope = match self.scopes.get(&id) {
|
||||
Some(scope) => scope,
|
||||
None => break,
|
||||
};
|
||||
if scope.defined.contains(name) {
|
||||
return true;
|
||||
}
|
||||
current = scope.parent;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Find enclosing Loop or Function scope (body-local ownership rule)
|
||||
fn find_enclosing_loop_or_function(&self, scope_id: ScopeId) -> ScopeId {
|
||||
let scope = &self.scopes[&scope_id];
|
||||
@ -629,4 +693,30 @@ mod tests {
|
||||
|
||||
assert!(any_relay, "Some loop should relay 'total' to function");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_program_json_v0_break_fixture_relay_and_capture() {
|
||||
let program_json: serde_json::Value = serde_json::from_str(include_str!(
|
||||
"../../../../docs/private/roadmap2/phases/phase-34-joinir-frontend/fixtures/loop_frontend_break.program.json"
|
||||
))
|
||||
.expect("fixture json");
|
||||
|
||||
let mut analyzer = OwnershipAnalyzer::new();
|
||||
let plans = analyzer
|
||||
.analyze_json(&program_json)
|
||||
.expect("Program(JSON v0) analysis should succeed");
|
||||
|
||||
let loop_plan = plans
|
||||
.iter()
|
||||
.find(|p| !p.relay_writes.is_empty())
|
||||
.expect("expected a loop plan with relay_writes");
|
||||
|
||||
// i/acc are defined outside the loop but rebound inside loop body -> relay_writes
|
||||
assert!(loop_plan.relay_writes.iter().any(|r| r.name == "i"));
|
||||
assert!(loop_plan.relay_writes.iter().any(|r| r.name == "acc"));
|
||||
|
||||
// n is read-only in loop condition -> capture + condition_capture
|
||||
assert!(loop_plan.captures.iter().any(|c| c.name == "n"));
|
||||
assert!(loop_plan.condition_captures.iter().any(|c| c.name == "n"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,6 +89,74 @@ pub fn plan_to_p2_inputs(
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert OwnershipPlan to P2 lowering inputs, allowing relay_writes (dev-only Phase 60).
|
||||
///
|
||||
/// Rules:
|
||||
/// - carriers = owned_vars where is_written && name != loop_var
|
||||
/// + relay_writes where name != loop_var
|
||||
/// - relay_path.len() > 1 is rejected (single-hop only)
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
pub fn plan_to_p2_inputs_with_relay(
|
||||
plan: &OwnershipPlan,
|
||||
loop_var: &str,
|
||||
) -> Result<P2LoweringInputs, String> {
|
||||
let mut carriers = Vec::new();
|
||||
|
||||
for var in &plan.owned_vars {
|
||||
if var.name == loop_var || !var.is_written {
|
||||
continue;
|
||||
}
|
||||
|
||||
let role = if var.is_condition_only {
|
||||
CarrierRole::ConditionOnly
|
||||
} else {
|
||||
CarrierRole::LoopState
|
||||
};
|
||||
|
||||
carriers.push(CarrierVar {
|
||||
name: var.name.clone(),
|
||||
role,
|
||||
init: CarrierInit::FromHost,
|
||||
host_id: crate::mir::ValueId(0),
|
||||
join_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
for relay in &plan.relay_writes {
|
||||
if relay.name == loop_var {
|
||||
continue;
|
||||
}
|
||||
if relay.relay_path.len() > 1 {
|
||||
return Err(format!(
|
||||
"Phase 60 limitation: only single-hop relay supported for P2. Var='{}' relay_path_len={}",
|
||||
relay.name,
|
||||
relay.relay_path.len()
|
||||
));
|
||||
}
|
||||
|
||||
carriers.push(CarrierVar {
|
||||
name: relay.name.clone(),
|
||||
role: CarrierRole::LoopState,
|
||||
init: CarrierInit::FromHost,
|
||||
host_id: crate::mir::ValueId(0),
|
||||
join_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
let captures: Vec<String> = plan.captures.iter().map(|c| c.name.clone()).collect();
|
||||
let condition_captures: Vec<String> = plan
|
||||
.condition_captures
|
||||
.iter()
|
||||
.map(|c| c.name.clone())
|
||||
.collect();
|
||||
|
||||
Ok(P2LoweringInputs {
|
||||
carriers,
|
||||
captures,
|
||||
condition_captures,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert OwnershipPlan to P3 (if-sum) lowering inputs.
|
||||
///
|
||||
/// P3 patterns have multiple carriers (sum, count, etc.) updated conditionally.
|
||||
@ -152,6 +220,22 @@ pub fn plan_to_p3_inputs(
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert OwnershipPlan to P3 lowering inputs, allowing relay_writes (dev-only Phase 60).
|
||||
///
|
||||
/// Rules are identical to `plan_to_p2_inputs_with_relay`, but output type is P3LoweringInputs.
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
pub fn plan_to_p3_inputs_with_relay(
|
||||
plan: &OwnershipPlan,
|
||||
loop_var: &str,
|
||||
) -> Result<P3LoweringInputs, String> {
|
||||
let p2 = plan_to_p2_inputs_with_relay(plan, loop_var)?;
|
||||
Ok(P3LoweringInputs {
|
||||
carriers: p2.carriers,
|
||||
captures: p2.captures,
|
||||
condition_captures: p2.condition_captures,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -220,6 +304,44 @@ mod tests {
|
||||
.contains("relay_writes not yet supported"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
fn test_relay_single_hop_accepted_in_with_relay() {
|
||||
let mut plan = OwnershipPlan::new(ScopeId(1));
|
||||
plan.owned_vars.push(ScopeOwnedVar {
|
||||
name: "i".to_string(),
|
||||
is_written: true,
|
||||
is_condition_only: false,
|
||||
});
|
||||
plan.relay_writes.push(RelayVar {
|
||||
name: "sum".to_string(),
|
||||
owner_scope: ScopeId(0),
|
||||
relay_path: vec![ScopeId(42)],
|
||||
});
|
||||
|
||||
let inputs = plan_to_p2_inputs_with_relay(&plan, "i").expect("with_relay should accept");
|
||||
assert_eq!(inputs.carriers.len(), 1);
|
||||
assert_eq!(inputs.carriers[0].name, "sum");
|
||||
assert_eq!(inputs.carriers[0].role, CarrierRole::LoopState);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
fn test_relay_multi_hop_rejected_in_with_relay() {
|
||||
let mut plan = OwnershipPlan::new(ScopeId(1));
|
||||
plan.relay_writes.push(RelayVar {
|
||||
name: "outer_var".to_string(),
|
||||
owner_scope: ScopeId(0),
|
||||
relay_path: vec![ScopeId(1), ScopeId(2)],
|
||||
});
|
||||
|
||||
let result = plan_to_p2_inputs_with_relay(&plan, "i");
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.contains("only single-hop relay supported"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
fn test_read_only_vars_not_carriers() {
|
||||
@ -402,4 +524,24 @@ mod tests {
|
||||
.unwrap_err()
|
||||
.contains("relay_writes not yet supported for P3"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "normalized_dev")]
|
||||
fn test_p3_with_relay_accepts_single_hop() {
|
||||
let mut plan = OwnershipPlan::new(ScopeId(1));
|
||||
plan.owned_vars.push(ScopeOwnedVar {
|
||||
name: "i".to_string(),
|
||||
is_written: true,
|
||||
is_condition_only: false,
|
||||
});
|
||||
plan.relay_writes.push(RelayVar {
|
||||
name: "sum".to_string(),
|
||||
owner_scope: ScopeId(0),
|
||||
relay_path: vec![],
|
||||
});
|
||||
|
||||
let inputs = plan_to_p3_inputs_with_relay(&plan, "i").expect("P3 with_relay should accept");
|
||||
assert_eq!(inputs.carriers.len(), 1);
|
||||
assert_eq!(inputs.carriers[0].name, "sum");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user