Files
hakmem/core/hakmem_ace_ucb1.c

125 lines
3.5 KiB
C
Raw Normal View History

/* hakmem_ace_ucb1.c - UCB1 Multi-Armed Bandit Implementation */
#include "hakmem_ace_ucb1.h"
#include <stdio.h>
#include <string.h>
#include <math.h>
/* ========== Initialization ========== */
/*
 * Initialize a UCB1 bandit with one arm per candidate value.
 *
 * bandit            - state to initialize; it is zeroed first.
 *                     A NULL pointer makes the call a no-op.
 * candidates        - array of at least n_candidates values; candidates[i]
 *                     is copied into arms[i].value. If NULL, the bandit is
 *                     left with zero arms instead of dereferencing it.
 * n_candidates      - number of entries to take from candidates; values
 *                     above UCB1_MAX_ARMS are silently clamped.
 * exploration_bonus - stored in bandit->exploration_bonus.
 *
 * Every arm starts with avg_reward = 0.0 and n_pulls = 0.
 */
void hkm_ucb1_init(struct hkm_ucb1_bandit *bandit,
                   const uint32_t *candidates,
                   uint32_t n_candidates,
                   double exploration_bonus) {
    if (bandit == NULL) {
        return;
    }

    memset(bandit, 0, sizeof(*bandit));

    /* No source array means no arms can be populated. */
    if (candidates == NULL) {
        n_candidates = 0;
    }
    if (n_candidates > UCB1_MAX_ARMS) {
        n_candidates = UCB1_MAX_ARMS;
    }

    bandit->n_arms = n_candidates;
    bandit->total_pulls = 0;
    bandit->exploration_bonus = exploration_bonus;

    for (uint32_t i = 0; i < n_candidates; i++) {
        bandit->arms[i].value = candidates[i];
        /* Set explicitly rather than relying on all-bits-zero being 0.0. */
        bandit->arms[i].avg_reward = 0.0;
        bandit->arms[i].n_pulls = 0;
    }
}
/* ========== UCB value computation ========== */
/*
 * UCB1 score of a single arm:
 *
 *     avg_reward + c * sqrt(log(total_pulls) / n_pulls)
 *
 * where c is bandit->exploration_bonus. An arm that has never been pulled
 * scores INFINITY so that it is tried before any pulled arm.
 * arm_idx is not range-checked here.
 */
static inline double compute_ucb_value(const struct hkm_ucb1_bandit *bandit, int arm_idx) {
    const struct hkm_ucb1_arm *candidate = bandit->arms + arm_idx;

    if (candidate->n_pulls != 0) {
        const double mean = candidate->avg_reward;
        const double total = (double)bandit->total_pulls;
        const double pulls = (double)candidate->n_pulls;
        /* Exploration term shrinks as this arm accumulates pulls. */
        const double bonus = bandit->exploration_bonus * sqrt(log(total) / pulls);
        return mean + bonus;
    }

    /* Unexplored arm: force it to the front of the ranking. */
    return INFINITY;
}
/* ========== Arm selection ========== */
/*
 * Pick the arm with the highest UCB score.
 *
 * Returns the arm index, or -1 when the bandit has no arms. The comparison
 * is strict, so when several arms share the top score the lowest index wins.
 */
int hkm_ucb1_select(struct hkm_ucb1_bandit *bandit) {
    if (bandit->n_arms == 0) {
        return -1;
    }

    /* Arm 0 is the initial leader; later arms must strictly beat it. */
    int winner = 0;
    double winner_score = compute_ucb_value(bandit, 0);

    uint32_t idx = 1;
    while (idx < bandit->n_arms) {
        const double score = compute_ucb_value(bandit, idx);
        if (score > winner_score) {
            winner = (int)idx;
            winner_score = score;
        }
        idx++;
    }

    return winner;
}
/* ========== Reward update ========== */
/*
 * Fold one observed reward into an arm's running average and bump the
 * per-arm and global pull counters.
 *
 * bandit  - bandit state; a NULL pointer makes the call a no-op.
 * arm_idx - index of the arm that was pulled; calls with an index outside
 *           [0, n_arms) are ignored.
 * reward  - observed reward; NaN and +/-infinity are ignored.
 */
void hkm_ucb1_update(struct hkm_ucb1_bandit *bandit,
                     int arm_idx,
                     double reward) {
    if (bandit == NULL) {
        return;
    }
    if (arm_idx < 0 || arm_idx >= (int)bandit->n_arms) {
        return;
    }
    /*
     * A non-finite sample would stay in avg_reward forever: NaN propagates
     * through every later update, and any `>` comparison against NaN is
     * false, which breaks ranking by score. Drop such samples instead.
     */
    if (!isfinite(reward)) {
        return;
    }

    struct hkm_ucb1_arm *arm = &bandit->arms[arm_idx];

    /*
     * Incremental mean:
     *   new_avg = old_avg + (reward - old_avg) / (n + 1)
     * This is algebraically equal to (old_avg * n + reward) / (n + 1) but
     * does not form the old_avg * n product, whose magnitude grows with n.
     */
    double n = (double)arm->n_pulls;
    arm->avg_reward += (reward - arm->avg_reward) / (n + 1.0);

    arm->n_pulls++;
    bandit->total_pulls++;
}
/* ========== Best-arm lookup ========== */
/*
 * Return the index of the arm with the highest average reward (pure
 * exploitation, no exploration term), or -1 when the bandit has no arms.
 * Ties go to the lowest index because the comparison is strict.
 *
 * NOTE(review): arms are compared on their stored avg_reward regardless of
 * n_pulls, so an arm that was never pulled takes part with whatever
 * avg_reward it currently holds - confirm that is intended.
 */
int hkm_ucb1_best_arm(const struct hkm_ucb1_bandit *bandit) {
    if (bandit->n_arms == 0) {
        return -1;
    }

    int leader = 0;
    for (uint32_t idx = 1; idx < bandit->n_arms; idx++) {
        /* Compare against the current leader's stored average directly. */
        if (bandit->arms[idx].avg_reward > bandit->arms[leader].avg_reward) {
            leader = (int)idx;
        }
    }
    return leader;
}
/* ========== Debug output ========== */
/*
 * Dump the bandit state to stderr: one header line with the label `name`,
 * the total pull count and the exploration coefficient, followed by one
 * line per arm (index, value, average reward, pull count).
 */
void hkm_ucb1_print(const struct hkm_ucb1_bandit *bandit, const char *name) {
    FILE *out = stderr;

    fprintf(out, "[UCB1 %s] total_pulls=%u, c=%.3f\n",
            name, bandit->total_pulls, bandit->exploration_bonus);

    uint32_t idx = 0;
    while (idx < bandit->n_arms) {
        const struct hkm_ucb1_arm *cur = &bandit->arms[idx];
        fprintf(out, " arm[%u] value=%u avg_reward=%.4f n_pulls=%u\n",
                idx, cur->value, cur->avg_reward, cur->n_pulls);
        idx++;
    }
}