/* hakmem_ace_ucb1.c - UCB1 Multi-Armed Bandit Implementation */

#include "hakmem_ace_ucb1.h"

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* ========== Initialization ========== */

/*
 * Reset `bandit` and register one arm per candidate value.
 *
 * @param bandit            bandit state to initialize (fully zeroed first)
 * @param candidates        candidate values, one per arm
 * @param n_candidates      number of entries in `candidates`; entries beyond
 *                          UCB1_MAX_ARMS are ignored
 * @param exploration_bonus exploration coefficient c used in the UCB score
 */
void hkm_ucb1_init(struct hkm_ucb1_bandit *bandit, const uint32_t *candidates,
                   uint32_t n_candidates, double exploration_bonus)
{
    memset(bandit, 0, sizeof(*bandit));

    /* Clamp so we never write past the fixed-size arms[] array. */
    if (n_candidates > UCB1_MAX_ARMS) {
        n_candidates = UCB1_MAX_ARMS;
    }

    bandit->n_arms = n_candidates;
    bandit->total_pulls = 0;
    bandit->exploration_bonus = exploration_bonus;

    for (uint32_t i = 0; i < n_candidates; i++) {
        bandit->arms[i].value = candidates[i];
        bandit->arms[i].avg_reward = 0.0;
        bandit->arms[i].n_pulls = 0;
    }
}

/* ========== UCB value computation ========== */

/*
 * UCB score of arm `arm_idx`: avg_reward + c * sqrt(ln(total_pulls) / n_pulls).
 * Returns +INFINITY for an arm that has never been pulled.
 * Caller guarantees 0 <= arm_idx < bandit->n_arms.
 */
static inline double compute_ucb_value(const struct hkm_ucb1_bandit *bandit, int arm_idx)
{
    const struct hkm_ucb1_arm *arm = &bandit->arms[arm_idx];

    /* Never-pulled arms score +inf so every arm is tried at least once
     * before exploitation begins (exploration first). */
    if (arm->n_pulls == 0) {
        return INFINITY;
    }

    /* In normal use total_pulls >= n_pulls >= 1 here. Clamp anyway: if the
     * counter were ever 0 (e.g. wrapped), log(0) = -inf would make sqrt()
     * return NaN, and NaN silently loses every comparison in select. */
    double total = (double)bandit->total_pulls;
    if (total < 1.0) {
        total = 1.0;
    }

    double exploitation = arm->avg_reward;
    double exploration = bandit->exploration_bonus *
                         sqrt(log(total) / (double)arm->n_pulls);
    return exploitation + exploration;
}

/* ========== Arm selection ========== */

/*
 * Choose the next arm to pull: the arm with the largest UCB score.
 * Ties go to the lowest index, so unpulled arms are tried in index order.
 *
 * @return arm index, or -1 if the bandit has no arms
 */
int hkm_ucb1_select(struct hkm_ucb1_bandit *bandit)
{
    if (bandit->n_arms == 0) {
        return -1;
    }

    int best_arm = 0;
    double best_ucb = compute_ucb_value(bandit, 0);
    for (uint32_t i = 1; i < bandit->n_arms; i++) {
        double ucb = compute_ucb_value(bandit, (int)i);
        if (ucb > best_ucb) {
            best_ucb = ucb;
            best_arm = (int)i;
        }
    }
    return best_arm;
}

/* ========== Reward update ========== */

/*
 * Record `reward` for arm `arm_idx` and advance the pull counters.
 * Out-of-range indices and non-finite rewards (NaN/inf) are ignored: a single
 * NaN would otherwise poison the arm's average permanently and freeze
 * selection, since NaN compares false against everything.
 */
void hkm_ucb1_update(struct hkm_ucb1_bandit *bandit, int arm_idx, double reward)
{
    if (arm_idx < 0 || arm_idx >= (int)bandit->n_arms) {
        return;
    }
    if (!isfinite(reward)) {
        return;
    }

    struct hkm_ucb1_arm *arm = &bandit->arms[arm_idx];

    /* Incremental running mean:
     *   new_avg = old_avg + (reward - old_avg) / (n + 1)
     * Algebraically equal to (old_avg * n + reward) / (n + 1), but avoids the
     * ever-growing old_avg * n product and its precision loss. */
    arm->n_pulls++;
    arm->avg_reward += (reward - arm->avg_reward) / (double)arm->n_pulls;
    bandit->total_pulls++;
}

/* ========== Best arm ========== */

/*
 * Return the arm with the highest average reward among arms that have been
 * pulled at least once. Unpulled arms hold the 0.0 placeholder from init and
 * are skipped; otherwise they would beat pulled arms whose rewards are all
 * negative. If no arm has been pulled yet, returns 0.
 *
 * @return arm index, or -1 if the bandit has no arms
 */
int hkm_ucb1_best_arm(const struct hkm_ucb1_bandit *bandit)
{
    if (bandit->n_arms == 0) {
        return -1;
    }

    int best = -1;
    double best_avg = 0.0;
    for (uint32_t i = 0; i < bandit->n_arms; i++) {
        const struct hkm_ucb1_arm *arm = &bandit->arms[i];
        if (arm->n_pulls == 0) {
            continue;
        }
        if (best < 0 || arm->avg_reward > best_avg) {
            best_avg = arm->avg_reward;
            best = (int)i;
        }
    }
    return best < 0 ? 0 : best;
}

/* ========== Debug output ========== */

/*
 * Dump the bandit state to stderr, labelled with `name` (NULL-safe).
 */
void hkm_ucb1_print(const struct hkm_ucb1_bandit *bandit, const char *name)
{
    fprintf(stderr, "[UCB1 %s] total_pulls=%u, c=%.3f\n",
            name ? name : "(null)", bandit->total_pulls, bandit->exploration_bonus);
    for (uint32_t i = 0; i < bandit->n_arms; i++) {
        const struct hkm_ucb1_arm *arm = &bandit->arms[i];
        fprintf(stderr, " arm[%u] value=%u avg_reward=%.4f n_pulls=%u\n",
                i, arm->value, arm->avg_reward, arm->n_pulls);
    }
}