Files
hakorune/src/mir/join_ir_runner.rs
2025-11-23 08:38:15 +09:00

328 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Minimal executor for JoinIR experiments (Phase 27.2).
//!
//! Purpose: a lightweight runner for A/B-comparing hand-written / minimal JoinIR
//! against the VM.
//! - Supported values: i64 / bool / String / Unit
//! - Supported instructions: Const / BinOp / Compare / BoxCall (StringBox: length,
//!   substring) / Call / Jump / Ret
use std::collections::HashMap;
use crate::mir::join_ir::{
BinOpKind, CompareOp, ConstValue, JoinFuncId, JoinInst, JoinModule, MirLikeInst, VarId,
};
/// A runtime value manipulated by the JoinIR runner.
///
/// Only the value kinds this experimental runner understands are representable.
#[derive(Clone, Debug, PartialEq)]
pub enum JoinValue {
    /// Signed 64-bit integer (from `ConstValue::Integer` and integer arithmetic).
    Int(i64),
    /// Boolean (from `ConstValue::Bool`, comparisons, and `And` / `Or`).
    Bool(bool),
    /// Owned string (from `ConstValue::String`, string `Add`, and `substring`).
    Str(String),
    /// No value; `ConstValue::Null` also evaluates to this.
    Unit,
}
/// Error raised when the runner cannot execute a JoinIR program.
///
/// Carries a human-readable description; `Display` prints it verbatim.
#[derive(Clone, Debug)]
pub struct JoinRuntimeError {
    /// Description of what went wrong.
    pub message: String,
}

impl JoinRuntimeError {
    /// Builds an error from anything convertible into a `String`.
    fn new(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self { message }
    }
}

impl std::fmt::Display for JoinRuntimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for JoinRuntimeError {}
/// Runs the function `entry` of `module` with `args` and returns its result.
///
/// This is the public entry point of the experimental runner. `args` are copied into
/// a fresh argument vector for the entry function.
///
/// # Errors
/// Propagates any [`JoinRuntimeError`] raised while executing `entry` or its callees.
pub fn run_joinir_function(
    module: &JoinModule,
    entry: JoinFuncId,
    args: &[JoinValue],
) -> Result<JoinValue, JoinRuntimeError> {
    let initial_args: Vec<JoinValue> = args.iter().cloned().collect();
    execute_function(module, entry, initial_args)
}
/// Core interpreter loop behind `run_joinir_function`.
///
/// Executes `current_func` with `current_args`.
///
/// Calls:
/// - A `Call` without a `dst` is a tail call. The loop rebinds the current function
///   and arguments and restarts instead of recursing, so tail-call chains do not grow
///   the Rust stack.
/// - A `Call` with a `dst` recurses and stores the callee's result in that variable.
///
/// A function produces its result in one of three ways:
/// - `Ret` returns the named variable, or `Unit` when it has none.
/// - A taken `Jump` returns its first argument, or `Unit` when it has none. The
///   `cont` target is ignored by this runner.
/// - Falling off the end of the body returns `Unit`.
///
/// # Errors
/// Returns an error for any of the following:
/// - an unknown function id;
/// - an arity mismatch;
/// - an unbound variable;
/// - a jump condition that is not bool-compatible;
/// - a `Call` carrying `k_next`;
/// - any error from `eval_compute` or a nested call.
fn execute_function(
    module: &JoinModule,
    mut current_func: JoinFuncId,
    mut current_args: Vec<JoinValue>,
) -> Result<JoinValue, JoinRuntimeError> {
    'exec: loop {
        let func = match module.functions.get(&current_func) {
            Some(found) => found,
            None => {
                return Err(JoinRuntimeError::new(format!(
                    "Function {:?} not found",
                    current_func
                )))
            }
        };

        let expected = func.params.len();
        let got = current_args.len();
        if expected != got {
            return Err(JoinRuntimeError::new(format!(
                "Arity mismatch for {:?}: expected {}, got {}",
                func.id, expected, got
            )));
        }

        // Bind parameters positionally; a repeated parameter keeps its last argument.
        let mut locals: HashMap<VarId, JoinValue> = func
            .params
            .iter()
            .copied()
            .zip(current_args.iter().cloned())
            .collect();

        for inst in func.body.iter() {
            match inst {
                JoinInst::Compute(compute) => eval_compute(compute, &mut locals)?,
                JoinInst::Call {
                    func: target,
                    args,
                    k_next,
                    dst,
                } => {
                    if k_next.is_some() {
                        return Err(JoinRuntimeError::new(
                            "Join continuation (k_next) is not supported in the experimental runner",
                        ));
                    }
                    let resolved_args = materialize_args(args, &locals)?;
                    match dst {
                        Some(dst_var) => {
                            let value = execute_function(module, *target, resolved_args)?;
                            locals.insert(*dst_var, value);
                        }
                        None => {
                            // Tail call: restart the outer loop on the callee.
                            current_func = *target;
                            current_args = resolved_args;
                            continue 'exec;
                        }
                    }
                }
                JoinInst::Jump { cont: _, args, cond } => {
                    let taken = match cond {
                        Some(var) => as_bool(&read_var(&locals, *var)?)?,
                        None => true,
                    };
                    if taken {
                        return match args.first() {
                            Some(first) => read_var(&locals, *first),
                            None => Ok(JoinValue::Unit),
                        };
                    }
                }
                JoinInst::Ret { value } => {
                    return match value {
                        Some(var) => read_var(&locals, *var),
                        None => Ok(JoinValue::Unit),
                    };
                }
            }
        }

        // The body ended without a `Ret` or a taken `Jump`.
        break 'exec Ok(JoinValue::Unit);
    }
}
fn eval_compute(inst: &MirLikeInst, locals: &mut HashMap<VarId, JoinValue>) -> Result<(), JoinRuntimeError> {
match inst {
MirLikeInst::Const { dst, value } => {
let v = match value {
ConstValue::Integer(i) => JoinValue::Int(*i),
ConstValue::Bool(b) => JoinValue::Bool(*b),
ConstValue::String(s) => JoinValue::Str(s.clone()),
ConstValue::Null => JoinValue::Unit,
};
locals.insert(*dst, v);
}
MirLikeInst::BinOp { dst, op, lhs, rhs } => {
let l = read_var(locals, *lhs)?;
let r = read_var(locals, *rhs)?;
let v = match op {
BinOpKind::Add => match (l, r) {
(JoinValue::Int(a), JoinValue::Int(b)) => JoinValue::Int(a + b),
(JoinValue::Str(a), JoinValue::Str(b)) => JoinValue::Str(format!("{a}{b}")),
_ => {
return Err(JoinRuntimeError::new(
"Add supported only for (int,int) or (str,str)",
))
}
},
BinOpKind::Sub => match (l, r) {
(JoinValue::Int(a), JoinValue::Int(b)) => JoinValue::Int(a - b),
_ => return Err(JoinRuntimeError::new("Sub supported only for integers")),
},
BinOpKind::Mul => match (l, r) {
(JoinValue::Int(a), JoinValue::Int(b)) => JoinValue::Int(a * b),
_ => return Err(JoinRuntimeError::new("Mul supported only for integers")),
},
BinOpKind::Div => match (l, r) {
(JoinValue::Int(_), JoinValue::Int(0)) => {
return Err(JoinRuntimeError::new("Division by zero"))
}
(JoinValue::Int(a), JoinValue::Int(b)) => JoinValue::Int(a / b),
_ => return Err(JoinRuntimeError::new("Div supported only for integers")),
},
BinOpKind::Or => match (l, r) {
(JoinValue::Bool(a), JoinValue::Bool(b)) => JoinValue::Bool(a || b),
_ => return Err(JoinRuntimeError::new("Or supported only for bools")),
},
BinOpKind::And => match (l, r) {
(JoinValue::Bool(a), JoinValue::Bool(b)) => JoinValue::Bool(a && b),
_ => return Err(JoinRuntimeError::new("And supported only for bools")),
},
};
locals.insert(*dst, v);
}
MirLikeInst::Compare { dst, op, lhs, rhs } => {
let l = read_var(locals, *lhs)?;
let r = read_var(locals, *rhs)?;
let v = match (l, r) {
(JoinValue::Int(a), JoinValue::Int(b)) => match op {
CompareOp::Lt => a < b,
CompareOp::Le => a <= b,
CompareOp::Gt => a > b,
CompareOp::Ge => a >= b,
CompareOp::Eq => a == b,
CompareOp::Ne => a != b,
},
(JoinValue::Bool(a), JoinValue::Bool(b)) => match op {
CompareOp::Eq => a == b,
CompareOp::Ne => a != b,
_ => {
return Err(JoinRuntimeError::new(
"Bool comparison only supports Eq/Ne in the JoinIR runner",
))
}
},
(JoinValue::Str(a), JoinValue::Str(b)) => match op {
CompareOp::Eq => a == b,
CompareOp::Ne => a != b,
_ => {
return Err(JoinRuntimeError::new(
"String comparison only supports Eq/Ne in the JoinIR runner",
))
}
},
_ => {
return Err(JoinRuntimeError::new(
"Type mismatch in Compare (expected homogeneous operands)",
))
}
};
locals.insert(*dst, JoinValue::Bool(v));
}
MirLikeInst::BoxCall {
dst,
box_name,
method,
args,
} => {
if box_name != "StringBox" {
return Err(JoinRuntimeError::new(format!(
"Unsupported box call target: {}",
box_name
)));
}
match method.as_str() {
"length" => {
let arg = expect_str(&read_var(locals, args[0])?)?;
locals.insert(*dst.as_ref().ok_or_else(|| {
JoinRuntimeError::new("length call requires destination")
})?, JoinValue::Int(arg.len() as i64));
}
"substring" => {
if args.len() != 3 {
return Err(JoinRuntimeError::new(
"substring expects 3 arguments (s, start, end)",
));
}
let s = expect_str(&read_var(locals, args[0])?)?;
let start = expect_int(&read_var(locals, args[1])?)?;
let end = expect_int(&read_var(locals, args[2])?)?;
let slice = safe_substring(&s, start, end)?;
let dst_var = dst.ok_or_else(|| {
JoinRuntimeError::new("substring call requires destination")
})?;
locals.insert(dst_var, JoinValue::Str(slice));
}
_ => {
return Err(JoinRuntimeError::new(format!(
"Unsupported StringBox method: {}",
method
)))
}
}
}
}
Ok(())
}
/// Returns the half-open byte range `start..end` of `s` as an owned `String`.
///
/// Indices are UTF-8 byte offsets, the same unit the StringBox `length` call reports.
///
/// # Errors
/// Returns a [`JoinRuntimeError`] instead of panicking when any of these holds:
/// - either index is negative;
/// - `start > end`;
/// - either index exceeds `s.len()`;
/// - an index falls inside a multi-byte UTF-8 character.
fn safe_substring(s: &str, start: i64, end: i64) -> Result<String, JoinRuntimeError> {
    if start < 0 || end < 0 {
        return Err(JoinRuntimeError::new("substring indices must be non-negative"));
    }
    if start > end {
        return Err(JoinRuntimeError::new("substring start > end"));
    }
    // Bounds are compared in i64 before any cast to usize. On targets with a 32-bit
    // usize, a large i64 index would otherwise truncate and could slip past the check.
    // A `str` is at most isize::MAX bytes, so its length always fits in i64.
    let len = s.len() as i64;
    if start > len || end > len {
        return Err(JoinRuntimeError::new("substring indices out of bounds"));
    }
    // Both indices are now within 0..=s.len(), so these casts are lossless.
    // `str::get` returns None (instead of panicking like `s[a..b]`) when an index
    // splits a multi-byte character.
    s.get(start as usize..end as usize)
        .map(str::to_owned)
        .ok_or_else(|| {
            JoinRuntimeError::new("substring indices are not on UTF-8 character boundaries")
        })
}
/// Looks up `var` in the current frame and returns an owned copy of its value.
///
/// # Errors
/// Returns a [`JoinRuntimeError`] when `var` has not been bound in `locals`.
fn read_var(
    locals: &HashMap<VarId, JoinValue>,
    var: VarId,
) -> Result<JoinValue, JoinRuntimeError> {
    match locals.get(&var) {
        Some(bound) => Ok(bound.clone()),
        None => Err(JoinRuntimeError::new(format!(
            "Variable {:?} not bound",
            var
        ))),
    }
}
/// Reads the value of every variable in `args`, preserving their order.
///
/// # Errors
/// Stops at the first unbound variable and returns its lookup error from `read_var`.
fn materialize_args(
    args: &[VarId],
    locals: &HashMap<VarId, JoinValue>,
) -> Result<Vec<JoinValue>, JoinRuntimeError> {
    let mut values = Vec::with_capacity(args.len());
    for &arg in args {
        values.push(read_var(locals, arg)?);
    }
    Ok(values)
}
/// Interprets `value` as a branch condition.
///
/// Truthiness rules:
/// - `Bool` is used as-is.
/// - `Int` is true when non-zero.
/// - `Unit` is false.
///
/// # Errors
/// Returns a [`JoinRuntimeError`] for `Str`, which has no truthiness in this runner.
fn as_bool(value: &JoinValue) -> Result<bool, JoinRuntimeError> {
    let truthy = match value {
        JoinValue::Bool(flag) => *flag,
        JoinValue::Int(n) => *n != 0,
        JoinValue::Unit => false,
        JoinValue::Str(_) => {
            return Err(JoinRuntimeError::new(format!(
                "Expected bool-compatible value, got {:?}",
                value
            )))
        }
    };
    Ok(truthy)
}
/// Extracts the integer payload of `value`.
///
/// # Errors
/// Returns a [`JoinRuntimeError`] naming the offending value when it is not `Int`.
fn expect_int(value: &JoinValue) -> Result<i64, JoinRuntimeError> {
    if let JoinValue::Int(n) = value {
        return Ok(*n);
    }
    Err(JoinRuntimeError::new(format!(
        "Expected int, got {:?}",
        value
    )))
}
/// Returns an owned copy of the string payload of `value`.
///
/// # Errors
/// Returns a [`JoinRuntimeError`] naming the offending value when it is not `Str`.
fn expect_str(value: &JoinValue) -> Result<String, JoinRuntimeError> {
    if let JoinValue::Str(text) = value {
        Ok(text.to_owned())
    } else {
        Err(JoinRuntimeError::new(format!(
            "Expected string, got {:?}",
            value
        )))
    }
}