/* hakmem_ace_ucb1.c - UCB1 Multi-Armed Bandit Implementation */
|
||
|
|
|
||
|
|
#include "hakmem_ace_ucb1.h"
|
||
|
|
#include <stdio.h>
|
||
|
|
#include <string.h>
|
||
|
|
#include <math.h>
|
||
|
|
|
||
|
|
/* ========== 初期化 ========== */
|
||
|
|
|
||
|
|
void hkm_ucb1_init(struct hkm_ucb1_bandit *bandit,
|
||
|
|
const uint32_t *candidates,
|
||
|
|
uint32_t n_candidates,
|
||
|
|
double exploration_bonus) {
|
||
|
|
memset(bandit, 0, sizeof(*bandit));
|
||
|
|
|
||
|
|
if (n_candidates > UCB1_MAX_ARMS) {
|
||
|
|
n_candidates = UCB1_MAX_ARMS;
|
||
|
|
}
|
||
|
|
|
||
|
|
bandit->n_arms = n_candidates;
|
||
|
|
bandit->total_pulls = 0;
|
||
|
|
bandit->exploration_bonus = exploration_bonus;
|
||
|
|
|
||
|
|
for (uint32_t i = 0; i < n_candidates; i++) {
|
||
|
|
bandit->arms[i].value = candidates[i];
|
||
|
|
bandit->arms[i].avg_reward = 0.0;
|
||
|
|
bandit->arms[i].n_pulls = 0;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/* ========== UCB値計算 ========== */
|
||
|
|
|
||
|
|
static inline double compute_ucb_value(const struct hkm_ucb1_bandit *bandit, int arm_idx) {
|
||
|
|
const struct hkm_ucb1_arm *arm = &bandit->arms[arm_idx];
|
||
|
|
|
||
|
|
/* 初回選択時は無限大(探索優先) */
|
||
|
|
if (arm->n_pulls == 0) {
|
||
|
|
return INFINITY;
|
||
|
|
}
|
||
|
|
|
||
|
|
/* UCB = avg_reward + c * sqrt(log(total) / n) */
|
||
|
|
double exploitation = arm->avg_reward;
|
||
|
|
double exploration = bandit->exploration_bonus *
|
||
|
|
sqrt(log((double)bandit->total_pulls) / (double)arm->n_pulls);
|
||
|
|
|
||
|
|
return exploitation + exploration;
|
||
|
|
}
|
||
|
|
|
||
|
|
/* ========== アーム選択 ========== */
|
||
|
|
|
||
|
|
int hkm_ucb1_select(struct hkm_ucb1_bandit *bandit) {
|
||
|
|
if (bandit->n_arms == 0) {
|
||
|
|
return -1;
|
||
|
|
}
|
||
|
|
|
||
|
|
/* 各アームのUCB値を計算し、最大のものを選ぶ */
|
||
|
|
int best_arm = 0;
|
||
|
|
double best_ucb = compute_ucb_value(bandit, 0);
|
||
|
|
|
||
|
|
for (uint32_t i = 1; i < bandit->n_arms; i++) {
|
||
|
|
double ucb = compute_ucb_value(bandit, i);
|
||
|
|
if (ucb > best_ucb) {
|
||
|
|
best_ucb = ucb;
|
||
|
|
best_arm = (int)i;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return best_arm;
|
||
|
|
}
|
||
|
|
|
||
|
|
/* ========== 報酬更新 ========== */
|
||
|
|
|
||
|
|
void hkm_ucb1_update(struct hkm_ucb1_bandit *bandit,
|
||
|
|
int arm_idx,
|
||
|
|
double reward) {
|
||
|
|
if (arm_idx < 0 || arm_idx >= (int)bandit->n_arms) {
|
||
|
|
return;
|
||
|
|
}
|
||
|
|
|
||
|
|
struct hkm_ucb1_arm *arm = &bandit->arms[arm_idx];
|
||
|
|
|
||
|
|
/* Running average:
|
||
|
|
* new_avg = (old_avg * n + reward) / (n + 1)
|
||
|
|
* = old_avg + (reward - old_avg) / (n + 1)
|
||
|
|
*/
|
||
|
|
double n = (double)arm->n_pulls;
|
||
|
|
arm->avg_reward = (arm->avg_reward * n + reward) / (n + 1.0);
|
||
|
|
arm->n_pulls++;
|
||
|
|
bandit->total_pulls++;
|
||
|
|
}
|
||
|
|
|
||
|
|
/* ========== 最良アーム取得 ========== */
|
||
|
|
|
||
|
|
int hkm_ucb1_best_arm(const struct hkm_ucb1_bandit *bandit) {
|
||
|
|
if (bandit->n_arms == 0) {
|
||
|
|
return -1;
|
||
|
|
}
|
||
|
|
|
||
|
|
/* 最高平均報酬のアームを返す */
|
||
|
|
int best = 0;
|
||
|
|
double best_avg = bandit->arms[0].avg_reward;
|
||
|
|
|
||
|
|
for (uint32_t i = 1; i < bandit->n_arms; i++) {
|
||
|
|
if (bandit->arms[i].avg_reward > best_avg) {
|
||
|
|
best_avg = bandit->arms[i].avg_reward;
|
||
|
|
best = (int)i;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return best;
|
||
|
|
}
|
||
|
|
|
||
|
|
/* ========== デバッグ出力 ========== */
|
||
|
|
|
||
|
|
void hkm_ucb1_print(const struct hkm_ucb1_bandit *bandit, const char *name) {
|
||
|
|
fprintf(stderr, "[UCB1 %s] total_pulls=%u, c=%.3f\n",
|
||
|
|
name, bandit->total_pulls, bandit->exploration_bonus);
|
||
|
|
|
||
|
|
for (uint32_t i = 0; i < bandit->n_arms; i++) {
|
||
|
|
const struct hkm_ucb1_arm *arm = &bandit->arms[i];
|
||
|
|
fprintf(stderr, " arm[%u] value=%u avg_reward=%.4f n_pulls=%u\n",
|
||
|
|
i, arm->value, arm->avg_reward, arm->n_pulls);
|
||
|
|
}
|
||
|
|
}
|