72 lines
2.4 KiB
C
72 lines
2.4 KiB
C
|
|
/* hakmem_ace_ucb1.h - UCB1 Multi-Armed Bandit Learning
|
|||
|
|
*
|
|||
|
|
* Upper Confidence Bound (UCB1) アルゴリズム:
|
|||
|
|
* - 複数の候補値(arm)から最適なものを学習
|
|||
|
|
* - Exploration (探索) と Exploitation (活用) のバランス
|
|||
|
|
* - UCB value = avg_reward + c * sqrt(log(total_pulls) / n_pulls)
|
|||
|
|
*/
|
|||
|
|
|
|||
|
|
#ifndef HAKMEM_ACE_UCB1_H
|
|||
|
|
#define HAKMEM_ACE_UCB1_H
|
|||
|
|
|
|||
|
|
#include <stdint.h>
|
|||
|
|
#include <stdbool.h>
|
|||
|
|
|
|||
|
|
/* 最大アーム数 */
|
|||
|
|
#define UCB1_MAX_ARMS 16
|
|||
|
|
|
|||
|
|
/* UCB1 アーム(1つの候補値) */
|
|||
|
|
struct hkm_ucb1_arm {
|
|||
|
|
uint32_t value; /* 候補値(例: TLS capacity = 32, 64, 128...) */
|
|||
|
|
double avg_reward; /* 平均報酬 */
|
|||
|
|
uint32_t n_pulls; /* 選択回数 */
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
/* UCB1 バンディット */
|
|||
|
|
struct hkm_ucb1_bandit {
|
|||
|
|
struct hkm_ucb1_arm arms[UCB1_MAX_ARMS]; /* アームの配列 */
|
|||
|
|
uint32_t n_arms; /* アーム数 */
|
|||
|
|
uint32_t total_pulls; /* 総選択回数 */
|
|||
|
|
double exploration_bonus; /* 探索ボーナス(通常 sqrt(2)) */
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
/* ========== API ========== */
|
|||
|
|
|
|||
|
|
/* 初期化: 候補値の配列を渡す
|
|||
|
|
* 例: uint32_t candidates[] = {32, 64, 128, 256};
|
|||
|
|
* hkm_ucb1_init(&bandit, candidates, 4, 1.414); // sqrt(2)
|
|||
|
|
*/
|
|||
|
|
void hkm_ucb1_init(struct hkm_ucb1_bandit *bandit,
|
|||
|
|
const uint32_t *candidates,
|
|||
|
|
uint32_t n_candidates,
|
|||
|
|
double exploration_bonus);
|
|||
|
|
|
|||
|
|
/* アーム選択: UCB値が最大のアームを選ぶ
|
|||
|
|
* 戻り値: アームのインデックス (0 ... n_arms-1)
|
|||
|
|
*/
|
|||
|
|
int hkm_ucb1_select(struct hkm_ucb1_bandit *bandit);
|
|||
|
|
|
|||
|
|
/* 報酬更新: 選択したアームの報酬を記録
|
|||
|
|
* arm_idx: hkm_ucb1_select() の戻り値
|
|||
|
|
* reward: 報酬値(高いほど良い)
|
|||
|
|
*/
|
|||
|
|
void hkm_ucb1_update(struct hkm_ucb1_bandit *bandit,
|
|||
|
|
int arm_idx,
|
|||
|
|
double reward);
|
|||
|
|
|
|||
|
|
/* アームの値を取得(inline)*/
|
|||
|
|
static inline uint32_t hkm_ucb1_get_value(const struct hkm_ucb1_bandit *bandit, int arm_idx) {
|
|||
|
|
if (arm_idx < 0 || arm_idx >= (int)bandit->n_arms) {
|
|||
|
|
return 0;
|
|||
|
|
}
|
|||
|
|
return bandit->arms[arm_idx].value;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/* 最良アーム取得(現時点で最高平均報酬)*/
|
|||
|
|
int hkm_ucb1_best_arm(const struct hkm_ucb1_bandit *bandit);
|
|||
|
|
|
|||
|
|
/* デバッグ出力 */
|
|||
|
|
void hkm_ucb1_print(const struct hkm_ucb1_bandit *bandit, const char *name);
|
|||
|
|
|
|||
|
|
#endif /* HAKMEM_ACE_UCB1_H */
|