AotPrep collections_hot matmul tuning and bench tweaks

This commit is contained in:
nyash-codex
2025-11-14 13:36:20 +09:00
parent 13f21334c9
commit f1fa182a4b
17 changed files with 760 additions and 219 deletions

View File

@ -427,6 +427,75 @@ static box Main {
if bundles.length() > 0 || bundle_srcs.length() > 0 || require_mods.length() > 0 {
local merged_prefix = BundleResolver.resolve(bundles, bundle_names, bundle_srcs, require_mods)
if merged_prefix == null { return 1 }
// Debug: emit line-map for merged bundles so parse error line can be mapped
{
local dbg = env.get("HAKO_STAGEB_DEBUG")
if dbg != null && ("" + dbg) == "1" {
// Count lines helper (inline)
local total = 0
{
local s2 = merged_prefix
if s2 == null { total = 0 } else {
local i=0; local n=(""+s2).length(); local c=1
loop(i<n){ if (""+s2).substring(i,i+1)=="\n" { c=c+1 } i=i+1 }
total = c
}
}
print("[stageb/line-map] prefix total lines=" + total)
// bundle-src (anonymous)
if bundles != null && bundles.length() > 0 {
local i = 0; local acc = 1
loop(i < bundles.length()) {
local seg = "" + bundles.get(i)
local ln = 0
{
local s2 = seg
if s2 == null { ln = 0 } else {
local ii=0; local nn=(""+s2).length(); local cc=1
loop(ii<nn){ if (""+s2).substring(ii,ii+1)=="\n" { cc=cc+1 } ii=ii+1 }
ln = cc
}
}
local start = acc
local finish = acc + ln - 1
print("[stageb/line-map] bundle-src[#" + i + "] " + start + ".." + finish)
acc = finish + 1
i = i + 1
}
}
// bundle-mod (named)
if bundle_names != null && bundle_srcs != null {
local i2 = 0; local acc2 = 1
if bundles != null {
// count lines of joined bundle-src
local joined = bundles.join("\n")
if joined == null { acc2 = 1 } else {
local ii=0; local nn=(""+joined).length(); local cc=1
loop(ii<nn){ if (""+joined).substring(ii,ii+1)=="\n" { cc=cc+1 } ii=ii+1 }
acc2 = cc + 1
}
}
loop(i2 < bundle_srcs.length()) {
local name = "" + bundle_names.get(i2)
local seg = "" + bundle_srcs.get(i2)
local ln = 0
{
local s2 = seg
if s2 == null { ln = 0 } else {
local ii=0; local nn=(""+s2).length(); local cc=1
loop(ii<nn){ if (""+s2).substring(ii,ii+1)=="\n" { cc=cc+1 } ii=ii+1 }
ln = cc
}
}
local start = acc2
local finish = acc2 + ln - 1
print("[stageb/line-map] bundle-mod[name=" + name + "] " + start + ".." + finish)
acc2 = finish + 1
i2 = i2 + 1
}
}
}
}
body_src = merged_prefix + body_src
}

View File

@ -4,22 +4,54 @@ using selfhost.shared.common.string_helpers as StringHelpers
using selfhost.llvm.ir.aot_prep.helpers.common as AotPrepHelpers
static box AotPrepBinopCSEBox {
// Static helpers (replace anonymous fun)
// Read the string value stored under `key` in a JSON text fragment.
// Searches `text` for the first occurrence of `"<key>":"` and hands the offset
// just past that opening quote to JsonFragBox.read_string_after.
// Returns "" when the key pattern does not occur in `text`.
read_field(text, key) {
    local pattern = "\"" + key + "\":\""
    local at = text.indexOf(pattern)
    if at >= 0 {
        // Value starts right after the opening quote that ends the pattern.
        local value_start = at + pattern.length()
        return JsonFragBox.read_string_after(text, value_start)
    }
    return ""
}
// Read the unquoted value stored under `key` in a JSON text fragment.
// Searches `text` for the first occurrence of `"<key>":` and passes the offset
// immediately after the colon to StringHelpers.read_digits.
// Returns "" when the key pattern does not occur in `text`.
// NOTE(review): read_digits presumably returns the run of digits found at that
// offset (as a string) - confirm against StringHelpers.
read_digits_field(text, key) {
    local pattern = "\"" + key + "\":"
    local at = text.indexOf(pattern)
    if at >= 0 {
        local digits_start = at + pattern.length()
        return StringHelpers.read_digits(text, digits_start)
    }
    return ""
}
// Follow a chain of copy instructions back towards its origin.
//   copy_src: map queried with has/get; each entry maps a copy's dst id to its src id
//   vid:      id to start from
// Walks copy_src starting at `vid` and returns the last id reached. The walk
// stops when the id is "", when the id has no entry in copy_src, or once
// 12 links have been followed.
// NOTE(review): the 12-hop cap presumably protects against cyclic copy chains.
resolve_copy(copy_src, vid) {
    local node = vid
    local hops = 0
    loop(true) {
        if hops >= 12 { break }
        if node == "" { break }
        if !copy_src.has(node) { break }
        node = copy_src.get(node)
        hops = hops + 1
    }
    return node
}
// Build the canonical lookup key "op:lhs:rhs" for a binop.
//   op:       operation name as read from the instruction
//   lhs, rhs: operand ids; each is first resolved through copy_src via resolve_copy
//   copy_src: dst -> src map of copy instructions
// Returns "" when either operand is empty, before or after copy resolution.
// For "+", "add", "*" and "mul" the operands are ordered so the numerically
// smaller id comes first; the swap only happens when StringHelpers.to_i64
// yields a non-null value for both operands.
canon_binop(op, lhs, rhs, copy_src) {
    if lhs == "" { return "" }
    if rhs == "" { return "" }
    local a = AotPrepBinopCSEBox.resolve_copy(copy_src, lhs)
    local b = AotPrepBinopCSEBox.resolve_copy(copy_src, rhs)
    if a == "" { return "" }
    if b == "" { return "" }
    // Default: keep the operands in their original order.
    local left = a
    local right = b
    if op == "+" || op == "add" || op == "*" || op == "mul" {
        local a_num = StringHelpers.to_i64(a)
        local b_num = StringHelpers.to_i64(b)
        if a_num != null && b_num != null && b_num < a_num {
            left = b
            right = a
        }
    }
    return op + ":" + left + ":" + right
}
run(json) {
if json == null { return null }
local pos = 0
local out = json
local read_field = fun(text, key) {
local needle = "\"" + key + "\":\""
local idx = text.indexOf(needle)
if idx < 0 { return "" }
return JsonFragBox.read_string_after(text, idx + needle.length())
}
local read_digits_field = fun(text, key) {
local needle = "\"" + key + "\":"
local idx = text.indexOf(needle)
if idx < 0 { return "" }
return StringHelpers.read_digits(text, idx + needle.length())
}
// Track seen canonicalized binops within each block
// Key format: op:lhs:rhs with copy chains resolved and commutativity normalized
// Reset per instructions-array span
loop(true) {
local key = "\"instructions\":["
local kinst = out.indexOf(key, pos)
@ -29,83 +61,60 @@ static box AotPrepBinopCSEBox {
local rb = JsonFragBox._seek_array_end(out, lb)
if rb < 0 { break }
local body = out.substring(lb+1, rb)
local insts = []
local i = 0
// Pass-1: collect copy sources
local copy_src = new MapBox()
local seen = new MapBox()
local i1 = 0
loop(true) {
local os = body.indexOf("{", i)
local os = body.indexOf("{", i1)
if os < 0 { break }
local oe = AotPrepHelpers._seek_object_end(body, os)
if oe < 0 { break }
insts.push(body.substring(os, oe+1))
i = oe + 1
local inst1 = body.substring(os, oe+1)
local op1 = AotPrepBinopCSEBox.read_field(inst1, "op")
if op1 == "copy" {
local dst1 = AotPrepBinopCSEBox.read_digits_field(inst1, "dst")
local src1 = AotPrepBinopCSEBox.read_digits_field(inst1, "src")
if dst1 != "" && src1 != "" { copy_src.set(dst1, src1) }
}
i1 = oe + 1
}
local copy_src = {}
// Pass-2: build new body with CSE
local new_body = ""
local first = 1
local append_item = fun(item) {
if first == 0 { new_body = new_body + "," }
new_body = new_body + item
first = 0
}
for inst in insts {
local op = read_field(inst, "op")
if op == "copy" {
local dst = read_digits_field(inst, "dst")
local src = read_digits_field(inst, "src")
if dst != "" && src != "" {
copy_src[dst] = src
}
}
}
local resolve_copy = fun(vid) {
local current = vid
local depth = 0
loop(true) {
if current == "" { break }
if !copy_src.contains(current) { break }
current = copy_src[current]
depth = depth + 1
if depth >= 12 { break }
}
return current
}
local canon_binop = fun(op, lhs, rhs) {
if lhs == "" || rhs == "" { return "" }
local key_lhs = resolve_copy(lhs)
local key_rhs = resolve_copy(rhs)
if key_lhs == "" || key_rhs == "" { return "" }
if op == "+" || op == "add" || op == "*" || op == "mul" {
local li = StringHelpers.to_i64(key_lhs)
local ri = StringHelpers.to_i64(key_rhs)
if li != null && ri != null && ri < li {
local tmp = key_lhs
key_lhs = key_rhs
key_rhs = tmp
}
}
return op + ":" + key_lhs + ":" + key_rhs
}
for inst in insts {
local op = read_field(inst, "op")
local i2 = 0
loop(true) {
local os2 = body.indexOf("{", i2)
if os2 < 0 { break }
local oe2 = AotPrepHelpers._seek_object_end(body, os2)
if oe2 < 0 { break }
local inst = body.substring(os2, oe2+1)
local op = AotPrepBinopCSEBox.read_field(inst, "op")
if op == "binop" {
local operation = read_field(inst, "operation")
local lhs = read_digits_field(inst, "lhs")
local rhs = read_digits_field(inst, "rhs")
local key = canon_binop(operation, lhs, rhs)
if key != "" && seen.contains(key) {
local dst = read_digits_field(inst, "dst")
local operation = AotPrepBinopCSEBox.read_field(inst, "operation")
local lhs = AotPrepBinopCSEBox.read_digits_field(inst, "lhs")
local rhs = AotPrepBinopCSEBox.read_digits_field(inst, "rhs")
local key = AotPrepBinopCSEBox.canon_binop(operation, lhs, rhs, copy_src)
if key != "" && seen.has(key) {
local dst = AotPrepBinopCSEBox.read_digits_field(inst, "dst")
if dst != "" {
append_item("{\"op\":\"copy\",\"dst\":" + dst + ",\"src\":" + seen[key] + "}")
local item = "{\"op\":\"copy\",\"dst\":" + dst + ",\"src\":" + seen.get(key) + "}"
if first == 0 { new_body = new_body + "," }
new_body = new_body + item
first = 0
continue
}
} else if key != "" {
local dst = read_digits_field(inst, "dst")
local dst = AotPrepBinopCSEBox.read_digits_field(inst, "dst")
if dst != "" {
seen[key] = dst
seen.set(key, dst)
}
}
}
append_item(inst)
if first == 0 { new_body = new_body + "," }
new_body = new_body + inst
first = 0
i2 = oe2 + 1
}
out = out.substring(0, lb+1) + new_body + out.substring(rb, out.length())
pos = lb + new_body.length() + 1

View File

@ -17,7 +17,7 @@ static box AotPrepConstDedupBox {
local body = out.substring(lb+1, rb)
local i = 0
local new_body = ""
local first_vid_by_value = {}
local first_vid_by_value = new MapBox()
loop(i < body.length()) {
local os = body.indexOf("{", i)
if os < 0 {
@ -37,12 +37,12 @@ static box AotPrepConstDedupBox {
local kval = obj.indexOf("\"value\":{\"type\":\"i64\",\"value\":")
local vals = (kval>=0 ? StringHelpers.read_digits(obj, kval+30) : "")
if dsts != "" && vals != "" {
if first_vid_by_value.contains(vals) {
local src = first_vid_by_value[vals]
if first_vid_by_value.has(vals) {
local src = first_vid_by_value.get(vals)
local repl = "{\"op\":\"copy\",\"dst\":" + dsts + ",\"src\":" + src + "}"
new_body = new_body + repl
} else {
first_vid_by_value[vals] = dsts
first_vid_by_value.set(vals, dsts)
new_body = new_body + obj
}
} else {

View File

@ -4,42 +4,29 @@ using selfhost.shared.common.string_helpers as StringHelpers
using selfhost.llvm.ir.aot_prep.helpers.common as AotPrepHelpers // for evaluate_binop_constant
static box AotPrepLoopHoistBox {
// Static helpers (replace anonymous fun)
// Read the string value stored under `key` in a JSON text fragment.
// Searches `text` for the first occurrence of `"<key>":"` and hands the offset
// just past that opening quote to JsonFragBox.read_string_after.
// Returns "" when the key pattern does not occur in `text`.
read_field(text, key) {
    local pattern = "\"" + key + "\":\""
    local at = text.indexOf(pattern)
    if at >= 0 {
        // Value starts right after the opening quote that ends the pattern.
        local value_start = at + pattern.length()
        return JsonFragBox.read_string_after(text, value_start)
    }
    return ""
}
// Read the unquoted value stored under `key` in a JSON text fragment.
// Searches `text` for the first occurrence of `"<key>":` and passes the offset
// immediately after the colon to StringHelpers.read_digits.
// Returns "" when the key pattern does not occur in `text`.
// NOTE(review): read_digits presumably returns the run of digits found at that
// offset (as a string) - confirm against StringHelpers.
read_digits_field(text, key) {
    local pattern = "\"" + key + "\":"
    local at = text.indexOf(pattern)
    if at >= 0 {
        local digits_start = at + pattern.length()
        return StringHelpers.read_digits(text, digits_start)
    }
    return ""
}
// Read the literal of an i64 const instruction from its JSON text.
// Searches `text` for the first occurrence of `"value":{"type":"i64","value":`
// and passes the offset right after it to StringHelpers.read_digits.
// Returns "" when that pattern does not occur in `text`.
// NOTE(review): whether read_digits accepts a leading minus sign is not visible
// here - confirm how negative constants are handled.
read_const_value(text) {
    local pattern = "\"value\":{\"type\":\"i64\",\"value\":"
    local at = text.indexOf(pattern)
    if at >= 0 {
        local digits_start = at + pattern.length()
        return StringHelpers.read_digits(text, digits_start)
    }
    return ""
}
run(json) {
if json == null { return null }
local out = json
local pos = 0
local build_items = fun(body) {
local items = []
local i = 0
loop(true) {
local os = body.indexOf("{", i)
if os < 0 { break }
local oe = AotPrepHelpers._seek_object_end(body, os)
if oe < 0 { break }
items.push(body.substring(os, oe+1))
i = oe + 1
}
return items
}
local read_field = fun(text, key) {
local needle = "\"" + key + "\":\""
local idx = text.indexOf(needle)
if idx < 0 { return "" }
return JsonFragBox.read_string_after(text, idx + needle.length())
}
local read_digits_field = fun(text, key) {
local needle = "\"" + key + "\":"
local idx = text.indexOf(needle)
if idx < 0 { return "" }
return StringHelpers.read_digits(text, idx + needle.length())
}
local read_const_value = fun(text) {
local needle = "\"value\":{\"type\":\"i64\",\"value\":"
local idx = text.indexOf(needle)
if idx < 0 { return "" }
return StringHelpers.read_digits(text, idx + needle.length())
}
local ref_fields = ["lhs", "rhs", "cond", "target"]
loop(true) {
local key = "\"instructions\":["
local kinst = out.indexOf(key, pos)
@ -49,62 +36,125 @@ static box AotPrepLoopHoistBox {
local rb = JsonFragBox._seek_array_end(out, lb)
if rb < 0 { break }
local body = out.substring(lb+1, rb)
local insts = build_items(body)
local const_defs = {}
local const_vals = {}
for inst in insts {
local op = read_field(inst, "op")
if op == "const" {
local dst = read_digits_field(inst, "dst")
local val = read_const_value(inst)
if dst != "" && val != "" { const_defs[dst] = inst; const_vals[dst] = val }
local const_defs = new MapBox()
local const_vals = new MapBox()
// Pass-1: collect const defs
{
local i0 = 0
loop(true) {
local os0 = body.indexOf("{", i0)
if os0 < 0 { break }
local oe0 = AotPrepHelpers._seek_object_end(body, os0)
if oe0 < 0 { break }
local inst0 = body.substring(os0, oe0+1)
local op0 = AotPrepLoopHoistBox.read_field(inst0, "op")
if op0 == "const" {
local dst0 = AotPrepLoopHoistBox.read_digits_field(inst0, "dst")
local val0 = AotPrepLoopHoistBox.read_const_value(inst0)
if dst0 != "" && val0 != "" { const_defs.set(dst0, inst0); const_vals.set(dst0, val0) }
}
i0 = oe0 + 1
}
}
local folded = true
while folded {
folded = false
for inst in insts {
local op = read_field(inst, "op")
if op != "binop" { continue }
local dst = read_digits_field(inst, "dst")
if dst == "" || const_vals.contains(dst) { continue }
local lhs = read_digits_field(inst, "lhs")
local rhs = read_digits_field(inst, "rhs")
local operation = read_field(inst, "operation")
if lhs == "" || rhs == "" || operation == "" { continue }
local lhs_val = const_vals.contains(lhs) ? const_vals[lhs] : ""
local rhs_val = const_vals.contains(rhs) ? const_vals[rhs] : ""
if lhs_val == "" || rhs_val == "" { continue }
local computed = AotPrepHelpers.evaluate_binop_constant(operation, lhs_val, rhs_val)
if computed == "" { continue }
const_defs[dst] = inst
const_vals[dst] = computed
folded = true
local i1 = 0
loop(true) {
local os1 = body.indexOf("{", i1)
if os1 < 0 { break }
local oe1 = AotPrepHelpers._seek_object_end(body, os1)
if oe1 < 0 { break }
local inst1 = body.substring(os1, oe1+1)
local op1 = AotPrepLoopHoistBox.read_field(inst1, "op")
if op1 == "binop" {
local dst1 = AotPrepLoopHoistBox.read_digits_field(inst1, "dst")
if dst1 != "" && !const_vals.has(dst1) {
local lhs1 = AotPrepLoopHoistBox.read_digits_field(inst1, "lhs")
local rhs1 = AotPrepLoopHoistBox.read_digits_field(inst1, "rhs")
local operation1 = AotPrepLoopHoistBox.read_field(inst1, "operation")
if lhs1 != "" && rhs1 != "" && operation1 != "" {
local lhs_val1 = const_vals.has(lhs1) ? const_vals.get(lhs1) : ""
local rhs_val1 = const_vals.has(rhs1) ? const_vals.get(rhs1) : ""
if lhs_val1 != "" && rhs_val1 != "" {
local computed1 = AotPrepHelpers.evaluate_binop_constant(operation1, lhs_val1, rhs_val1)
if computed1 != "" {
const_defs.set(dst1, inst1)
const_vals.set(dst1, computed1)
folded = true
}
}
}
}
}
i1 = oe1 + 1
}
}
local needed = {}
for inst in insts {
local op = read_field(inst, "op")
if op == "const" { continue }
for field in ref_fields {
local ref = read_digits_field(inst, field)
if ref != "" && const_defs.contains(ref) { needed[ref] = true }
local needed = new MapBox()
{
local i2 = 0
loop(true) {
local os2 = body.indexOf("{", i2)
if os2 < 0 { break }
local oe2 = AotPrepHelpers._seek_object_end(body, os2)
if oe2 < 0 { break }
local inst2 = body.substring(os2, oe2+1)
local op2 = AotPrepLoopHoistBox.read_field(inst2, "op")
if op2 != "const" {
// check lhs, rhs, cond, target
local rf = AotPrepLoopHoistBox.read_digits_field(inst2, "lhs")
if rf != "" && const_defs.has(rf) { needed.set(rf, true) }
rf = AotPrepLoopHoistBox.read_digits_field(inst2, "rhs")
if rf != "" && const_defs.has(rf) { needed.set(rf, true) }
rf = AotPrepLoopHoistBox.read_digits_field(inst2, "cond")
if rf != "" && const_defs.has(rf) { needed.set(rf, true) }
rf = AotPrepLoopHoistBox.read_digits_field(inst2, "target")
if rf != "" && const_defs.has(rf) { needed.set(rf, true) }
}
i2 = oe2 + 1
}
}
if needed.size() == 0 { pos = rb + 1; continue }
local hoist_items = []
local keep_items = []
for inst in insts {
local dst = read_digits_field(inst, "dst")
if dst != "" && needed.contains(dst) && const_defs.contains(dst) { hoist_items.push(inst); continue }
keep_items.push(inst)
}
if hoist_items.size() == 0 { pos = rb + 1; continue }
// Build merged: hoist first, then keep (two scans)
local any_hoist = 0
local merged = ""
local first = 1
local append_item = fun(item) { if first == 0 { merged = merged + "," } merged = merged + item; first = 0 }
for item in hoist_items { append_item(item) }
for item in keep_items { append_item(item) }
{
local i3 = 0
loop(true) {
local os3 = body.indexOf("{", i3)
if os3 < 0 { break }
local oe3 = AotPrepHelpers._seek_object_end(body, os3)
if oe3 < 0 { break }
local inst3 = body.substring(os3, oe3+1)
local dst3 = AotPrepLoopHoistBox.read_digits_field(inst3, "dst")
if dst3 != "" && needed.has(dst3) && const_defs.has(dst3) {
if first == 0 { merged = merged + "," }
merged = merged + inst3
first = 0
any_hoist = 1
}
i3 = oe3 + 1
}
}
if any_hoist == 0 { pos = rb + 1; continue }
{
local i4 = 0
loop(true) {
local os4 = body.indexOf("{", i4)
if os4 < 0 { break }
local oe4 = AotPrepHelpers._seek_object_end(body, os4)
if oe4 < 0 { break }
local inst4 = body.substring(os4, oe4+1)
local dst4 = AotPrepLoopHoistBox.read_digits_field(inst4, "dst")
if !(dst4 != "" && needed.has(dst4) && const_defs.has(dst4)) {
if first == 0 { merged = merged + "," }
merged = merged + inst4
first = 0
}
i4 = oe4 + 1
}
}
out = out.substring(0, lb+1) + merged + out.substring(rb, out.length())
pos = lb + merged.length() + 1
}