// Package qlearning is an experimental set of interfaces and helpers to
// implement the Q-learning algorithm in Go.
//
// This is highly experimental and should be considered a toy.
//
// See https://github.com/ecooper/qlearning/tree/master/examples for
// implementation examples.
package qlearning

import (
	"fmt"
	"math/rand"
	"time"
)

// State is an interface wrapping the current state of the model.
type State interface {

	// String returns a string representation of the given state.
	// Implementers should take care to ensure that this is a consistent
	// hash for a given state.
	String() string

	// Next provides a slice of possible Actions that could be applied to
	// a state.
	Next() []Action
}

// Action is an interface wrapping an action that can be applied to the
// model's current state.
//
// BUG (ecooper): A state should apply an action, not the other way
// around.
type Action interface {
	String() string
	Apply(State) State
}
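
// The types below are not part of the original package; they are a minimal,
// hedged sketch of how the State and Action interfaces might be implemented,
// assuming a toy "count up to a target" game. The names counterState and
// incrementAction are hypothetical.
type counterState struct {
	count, target int
}

// String returns a consistent hash for the state: the current count.
func (s counterState) String() string { return fmt.Sprintf("%d", s.count) }

// Next lists the actions available from any state: add 1 or add 2.
func (s counterState) Next() []Action {
	return []Action{incrementAction(1), incrementAction(2)}
}

// incrementAction adds a fixed amount to the counter.
type incrementAction int

func (a incrementAction) String() string { return fmt.Sprintf("+%d", int(a)) }

// Apply returns the state reached after incrementing the counter.
func (a incrementAction) Apply(state State) State {
	s := state.(counterState)
	return counterState{count: s.count + int(a), target: s.target}
}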

// Rewarder is an interface wrapping the ability to provide a reward
// for the execution of an action in a given state.
type Rewarder interface {

	// Reward calculates the reward value for a given action in a given
	// state.
	Reward(action *StateAction) float32
}
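
// A matching Rewarder sketch for the hypothetical counter game above, again
// an illustration rather than part of the original code: reaching the target
// exactly is rewarded, overshooting it is penalized.
type counterRewarder struct{}

func (counterRewarder) Reward(sa *StateAction) float32 {
	s := sa.Action.Apply(sa.State).(counterState)
	switch {
	case s.count == s.target:
		return 1
	case s.count > s.target:
		return -1
	default:
		return 0
	}
}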

// Agent is an interface wrapping a model's agent, which is able to learn
// from actions and to report the current Q-value of an action at a given
// state.
type Agent interface {

	// Learn updates the model for a given state and action, using the
	// provided Rewarder implementation.
	Learn(*StateAction, Rewarder)

	// Value returns the current Q-value for a State and Action.
	Value(State, Action) float32

	// String returns a string representation of the Agent.
	String() string
}

// StateAction is a struct grouping an Action with a given State. Additionally,
// a Value can be associated with a StateAction, which is typically the Q-value.
type StateAction struct {
	State  State
	Action Action
	Value  float32
}

// NewStateAction creates a new StateAction for a State and Action.
func NewStateAction(state State, action Action, val float32) *StateAction {
	return &StateAction{
		State:  state,
		Action: action,
		Value:  val,
	}
}

// Next uses an Agent and State to find the highest scored Action.
//
// In the case of Q-value ties for a set of actions, one of the tied
// actions is selected at random.
func Next(agent Agent, state State) *StateAction {
	best := make([]*StateAction, 0)
	bestVal := float32(0.0)

	for _, action := range state.Next() {
		val := agent.Value(state, action)

		if len(best) == 0 {
			best = append(best, NewStateAction(state, action, val))
			bestVal = val
		} else {
			if val > bestVal {
				best = []*StateAction{NewStateAction(state, action, val)}
				bestVal = val
			} else if val == bestVal {
				best = append(best, NewStateAction(state, action, val))
			}
		}
	}

	return best[rand.Intn(len(best))]
}
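
// exampleTrainingLoop is an illustrative sketch, not part of the original
// package, showing how Next, Learn, and a SimpleAgent might fit together for
// the hypothetical counter game defined above. The learning rate, discount
// factor, and episode count here are arbitrary.
func exampleTrainingLoop() Agent {
	agent := NewSimpleAgent(0.5, 0.9)
	rewarder := counterRewarder{}

	for episode := 0; episode < 100; episode++ {
		var state State = counterState{count: 0, target: 5}

		// Step until the counter reaches or passes the target.
		for state.(counterState).count < 5 {
			// Pick the best-known action, learn from it, then apply it.
			sa := Next(agent, state)
			agent.Learn(sa, rewarder)
			state = sa.Action.Apply(state)
		}
	}

	return agent
}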

// SimpleAgent is an Agent implementation that stores Q-values in a
// map of maps.
type SimpleAgent struct {
	q  map[string]map[string]float32
	lr float32
	d  float32
}

// NewSimpleAgent creates a SimpleAgent with the provided learning rate
// and discount factor.
func NewSimpleAgent(lr, d float32) *SimpleAgent {
	return &SimpleAgent{
		q:  make(map[string]map[string]float32),
		d:  d,
		lr: lr,
	}
}

// getActions returns the current Q-values for a given state.
func (agent *SimpleAgent) getActions(state string) map[string]float32 {
	if _, ok := agent.q[state]; !ok {
		agent.q[state] = make(map[string]float32)
	}

	return agent.q[state]
}

// Learn updates the existing Q-value for the given State and Action
// using the Rewarder, applying the standard Q-learning update:
//
//	Q(s, a) <- Q(s, a) + lr * (reward + d*max(Q(s', a')) - Q(s, a))
//
// See https://en.wikipedia.org/wiki/Q-learning#Algorithm
func (agent *SimpleAgent) Learn(action *StateAction, reward Rewarder) {
	current := action.State.String()
	next := action.Action.Apply(action.State).String()

	actions := agent.getActions(current)

	// Estimate the best attainable value from the next state: the largest
	// stored Q-value for that state, never below the default of 0.
	maxNextVal := float32(0.0)
	for _, v := range agent.getActions(next) {
		if v > maxNextVal {
			maxNextVal = v
		}
	}

	currentVal := actions[action.Action.String()]
	actions[action.Action.String()] = currentVal + agent.lr*(reward.Reward(action)+agent.d*maxNextVal-currentVal)
}
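
// As a hedged numeric illustration of the update above (not taken from the
// original source): with lr = 0.5, d = 0.9, a current Q-value of 0, a reward
// of 1, and a best next-state value of 0.4, the new Q-value is
// 0 + 0.5*(1 + 0.9*0.4 - 0) = 0.68.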

// Value gets the current Q-value for a State and Action.
func (agent *SimpleAgent) Value(state State, action Action) float32 {
	return agent.getActions(state.String())[action.String()]
}

// String returns the current Q-value map as a printed string.
//
// BUG (ecooper): This is useless.
func (agent *SimpleAgent) String() string {
	return fmt.Sprintf("%v", agent.q)
}

// init seeds the package-level random number generator used by Next.
func init() {
	rand.Seed(time.Now().UTC().UnixNano())
}