Files
hakmem/core/hakmem_ace_ucb1.h

72 lines
2.4 KiB
C
Raw Normal View History

/* hakmem_ace_ucb1.h - UCB1 Multi-Armed Bandit Learning
*
* Upper Confidence Bound (UCB1) :
* - arm
* - Exploration () Exploitation ()
* - UCB value = avg_reward + c * sqrt(log(total_pulls) / n_pulls)
*/
#ifndef HAKMEM_ACE_UCB1_H
#define HAKMEM_ACE_UCB1_H
#include <stdint.h>
#include <stdbool.h>
/* 最大アーム数 */
#define UCB1_MAX_ARMS 16
/* UCB1 アーム1つの候補値 */
struct hkm_ucb1_arm {
uint32_t value; /* 候補値(例: TLS capacity = 32, 64, 128... */
double avg_reward; /* 平均報酬 */
uint32_t n_pulls; /* 選択回数 */
};
/* UCB1 バンディット */
struct hkm_ucb1_bandit {
struct hkm_ucb1_arm arms[UCB1_MAX_ARMS]; /* アームの配列 */
uint32_t n_arms; /* アーム数 */
uint32_t total_pulls; /* 総選択回数 */
double exploration_bonus; /* 探索ボーナス(通常 sqrt(2) */
};
/* ========== API ========== */
/* 初期化: 候補値の配列を渡す
* : uint32_t candidates[] = {32, 64, 128, 256};
* hkm_ucb1_init(&bandit, candidates, 4, 1.414); // sqrt(2)
*/
void hkm_ucb1_init(struct hkm_ucb1_bandit *bandit,
const uint32_t *candidates,
uint32_t n_candidates,
double exploration_bonus);
/* アーム選択: UCB値が最大のアームを選ぶ
* : (0 ... n_arms-1)
*/
int hkm_ucb1_select(struct hkm_ucb1_bandit *bandit);
/* 報酬更新: 選択したアームの報酬を記録
* arm_idx: hkm_ucb1_select()
* reward:
*/
void hkm_ucb1_update(struct hkm_ucb1_bandit *bandit,
int arm_idx,
double reward);
/* アームの値を取得inline*/
static inline uint32_t hkm_ucb1_get_value(const struct hkm_ucb1_bandit *bandit, int arm_idx) {
if (arm_idx < 0 || arm_idx >= (int)bandit->n_arms) {
return 0;
}
return bandit->arms[arm_idx].value;
}
/* 最良アーム取得(現時点で最高平均報酬)*/
int hkm_ucb1_best_arm(const struct hkm_ucb1_bandit *bandit);
/* デバッグ出力 */
void hkm_ucb1_print(const struct hkm_ucb1_bandit *bandit, const char *name);
#endif /* HAKMEM_ACE_UCB1_H */