-
Notifications
You must be signed in to change notification settings - Fork 0
/
bandit_player.rb
93 lines (75 loc) · 1.86 KB
/
bandit_player.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
module Bandit
  # A player that repeatedly picks an action on a bandit, collects the reward,
  # and remembers every reward it has seen per action and per bandit.
  #
  # On each play the player either takes the action whose remembered average
  # reward is highest, or takes an action uniformly at random. Which of the two
  # happens is decided by comparing +rand+ with +how_greedy+.
  class Player
    # @param how_greedy [Numeric] threshold compared against +rand+ on every
    #   play: when +rand < how_greedy+ the best known action is taken,
    #   otherwise a random one.
    def initialize(how_greedy:)
      @how_greedy = how_greedy
    end

    # Plays one round against +bandit+ and records the outcome.
    #
    # @param bandit [#id, #action_values, #receive] +id+ keys this player's
    #   memory of the bandit, +action_values+ enumerates the action ids that
    #   seed an empty memory, and +receive(action_id)+ returns the reward.
    # @return [Object] the reward returned by +bandit.receive+
    def plays(bandit)
      memory = find_or_initialize_bandit_memories(bandit)
      action = greedy_this_time ? memory.best_action : memory.random_action
      reward = bandit.receive(action.id)
      memory.re_evaluate(action.id, with: reward)
      reward
    end

    # @return [String] the string form of +how_greedy+
    def id
      @how_greedy.to_s
    end

    private

    # Looks up this player's memory for +bandit+ (creating it on first use)
    # and seeds it with the bandit's actions while it is still empty.
    #
    # @return [BanditMemory]
    def find_or_initialize_bandit_memories(bandit)
      memory = bandit_memories[bandit.id]
      memory.actions = bandit.action_values if memory.empty?
      memory
    end

    # @return [Hash{Object => BanditMemory}] memories keyed by bandit id; the
    #   default block stores a fresh BanditMemory for each unseen key.
    def bandit_memories
      @bandit_memories ||= Hash.new { |memories, bandit_id| memories[bandit_id] = BanditMemory.new }
    end

    # @return [Boolean] true when this play should take the best known action
    def greedy_this_time
      rand < @how_greedy
    end

    # Everything a player remembers about one bandit: one Action record per
    # action id, each holding the rewards observed for it.
    class BanditMemory
      # Records +with+ as a new observed reward for the action with id +action+.
      #
      # @param action [Object] id of an action already present in #actions
      # @param with [Numeric] the reward received for that action
      # @return [Array] the action's updated list of rewards
      def re_evaluate(action, with:)
        actions[action].values << with
      end

      # @return [Boolean] true when no actions have been registered yet
      def empty?
        actions.empty?
      end

      # Picks the action with the highest average reward; ties are broken at
      # random.
      #
      # @return [Action, nil] nil when no actions are registered, mirroring
      #   #random_action
      def best_action
        candidates = actions.values
        return if candidates.empty?

        top_average = candidates.map(&:average).max
        candidates.select { |candidate| candidate.average == top_average }.sample
      end

      # @return [Action, nil] a randomly sampled action, nil when none exist
      def random_action
        actions.values.sample
      end

      # @return [Hash{Object => Action}] registered actions keyed by action id
      def actions
        @actions ||= {}
      end

      # Replaces the registered actions with a fresh Action per given id.
      #
      # @param action_values [Enumerable] the action ids to register
      def actions=(action_values)
        @actions = action_values.each_with_object({}) do |action_id, registry|
          registry[action_id] = Action.new(id: action_id)
        end
      end

      # One action of a bandit together with the rewards observed for it.
      class Action
        # @return [Object] the id this action was created with
        attr_reader :id

        # @param id [Object] identifier of the action
        def initialize(id:)
          @id = id
        end

        # @return [Array] rewards observed so far (mutable, appended to by
        #   BanditMemory#re_evaluate)
        def values
          @values ||= []
        end

        # Mean of the observed rewards.
        #
        # Uses floating-point division so that Integer rewards are not
        # truncated (e.g. rewards [0, 1] average to 0.5 rather than 0).
        #
        # @return [Float] the mean reward, or 0.0 when nothing was observed yet
        def average
          return 0.0 if values.empty?

          values.sum.fdiv(values.length)
        end
      end
    end
  end
end