-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathTaskEnvironment.py
1678 lines (1494 loc) · 60.2 KB
/
TaskEnvironment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# ======================================
# Environments that implement tasks
# ======================================
#
# Key OpenAI Gym defines:
# (1) step()
# (2) reset()
import numpy as np
import time
import matplotlib.pyplot as plt
import pettingzoo
from gymnasium.spaces import Box, Space, Dict
# https://github.com/Farama-Foundation/PettingZoo
from types import FunctionType
from typing import List, Union
from functools import partial
import warnings
from copy import copy, deepcopy
import random
from ratinabox.Environment import Environment
from ratinabox.Agent import Agent
class TaskEnvironment(Environment, pettingzoo.ParallelEnv):
"""
Environment with task structure: there is a goal, and when the
goal is reached, it terminates an episode, and starts a new episode
(reset). This environment can be static or dynamic, depending on whether
update() is implemented.
In order to be more useful with other Reinforcement Learning pacakges, this
environment inherits from both ratinabox.Environment and openai's widely
used gym environment.
# Inputs
--------
*pos : list
Positional arguments to pass to Environment
verbose : bool
Whether to print out information about the environment
render_mode : str
How to render the environment. Options are 'matplotlib', 'pygame', or
'none'
render_every : int
How often to render the environment (in time steps)
render_every_framestep : int
How often to render the environment (in framesteps)
teleport_on_reset : bool
Whether to teleport agents to random positions on reset
save_expired_rewards : bool
Whether to save expired rewards in the environment
goals : list
List of goals to replenish the goal cache with on reset
goalcachekws : dict
Keyword arguments to pass to GoalCache
episode_termination_delay : float
How long to wait before terminating an episode after the goal is reached
**kws :
Keyword arguments to pass to Environment
"""
default_params = {} #for RatInABox
metadata = {"render_modes": ["matplotlib", "none"], "name": "TaskEnvironment-RiaB"}
def __init__(
self,
*pos,
dt=0.01,
render_mode="matplotlib",
render_every=None,
render_every_framestep=2,
teleport_on_reset=False,
save_expired_rewards=False,
goals=[], # one can pass in goal objects directly here
goalcachekws=dict(),
rewardcachekws=dict(),
episode_terminate_delay=0,
verbose=False,
**kws,
):
super().__init__(*pos, **kws)
self.dynamic = {"walls": [], "objects": []}
self.Ags: dict[str, Agent] = {} # dict of agents in the environment
self.goal_cache: GoalCache = GoalCache(self, **goalcachekws)
# replenish from this list of goals on reset
self.goal_cache.reset_goals = goals if isinstance(goals, list) else [goals]
self.t = 0 # current time
self.dt = dt # time step
self.history = {"t": []} # history of the environment
if render_every is None and render_every_framestep is not None:
self.render_every = render_every_framestep # How often to render
elif render_every is not None:
self.render_every = render_every / self.dt
self.verbose = verbose
self.render_mode: str = render_mode # options 'matplotlib'|'pygame'|'none'
self._stable_render_objects: dict = {} # objects that are stable across
# a rendering type
# ----------------------------------------------
# Agent-related task config
# ----------------------------------------------
self.teleport_on_reset = teleport_on_reset # Whether to teleport
# agents to random
# ----------------------------------------------
# Setup gym primatives
# ----------------------------------------------
# Setup observation space from the Environment space
self.observation_spaces: Dict[Space] = Dict({})
self.action_spaces: Dict[Space] = Dict({})
self.agent_names: List[str] = []
self.agents: List[str] = [] # pettingzoo variable
# that tracks all agents who are
# still active in an episode
self.infos: dict = {} # pettingzoo returns infos in step()
self.observation_lambda = {} # lambda functions to attain an agents
# observation information -- a vector
# of whatever info in the agent defines
# its current observation -- DEFAULT: pos
# Episode history
self.episodes: dict = {} # Written to upon completion of an episode
self.episodes["episode"] = []
self.episodes["start"] = []
self.episodes["end"] = []
self.episodes["duration"] = []
self.episodes["meta_info"] = []
self.episode = 0
# Episode state and option
self.episode_state = {"delayed_term": False}
self.episode_terminate_delay = episode_terminate_delay
# Reward cache specifics
self.reward_caches: dict[str, RewardCache] = {}
self.save_expired_rewards = save_expired_rewards
self.expired_rewards: List[RewardCache] = []
self.rewardcachekws = rewardcachekws
def observation_space(self, agent_name: str):
return self.observation_spaces[agent_name]
def action_space(self, agent_name: str):
return self.action_spaces[agent_name]
def add_agents(
self,
agents: Union[dict, List[Agent], Agent],
names: Union[None, List] = None,
maxvel: float = 50.0,
**kws,
):
"""
Add agents to the environment
For each agent, we add its action space (expressed as velocities it can
take) to the environment's action space.
Parameters
----------
agents : Dict[Agent] | List[Agent] | Agent
The agents to add to the environment
names : List[str] | None
The names of the agents. If None, then the names are generated
maxvel : float
The maximum velocity that the agents can take
"""
if not isinstance(agents, (list, Agent)):
raise TypeError("agents must be a list of agents or an agent type")
if isinstance(agents, Agent):
agents = [agents]
if not ([agent.dt == self.dt for agent in agents]):
raise NotImplementedError(
"Does not yet support agents with different dt from envrionment"
)
if isinstance(agents, dict):
names = list(agents.keys())
agents = list(agents.values())
elif names is None:
start = len(self.Ags)
names = ["agent_" + str(start + i) for i in range(len(agents))]
# Enlist agents
for i, (name, agent) in enumerate(zip(names, agents)):
self.Ags[name] = agent
self.agent_names.append(name)
agent.name = name # attach name to agent
# Add the agent's action space to the environment's action spaces
# dict
D = int(self.dimensionality[0])
self.action_spaces[name] = Box(low=-maxvel, high=maxvel, shape=(D,))
# Add the agent's observation space to the environment's
# observation spaces dict
ext = [self.extent[i : i + 2] for i in np.arange(0, len(self.extent), 2)]
lows, highs = np.array(list(zip(*ext)), dtype=float)
self.observation_spaces[name] = Box(low=lows, high=highs, dtype=float)
self.observation_lambda[name] = lambda agent: agent.pos
# Attach a reward cache for the agent
cache = RewardCache(**self.rewardcachekws)
self.reward_caches[name] = cache
agent.reward = cache
# Ready the goal_cache for the agent
self.goal_cache.add_agent(agent)
# Set the agents time to the environment time
agent.t = self.t # agent clock is aligned to environment,
# in case a new agent is placed in the env
# on a later episode
self.infos[name] = {} # pettingzoo requirement
self.reset() # reset the environment with new agent
def remove_agents(self, agents):
"""
Remove agents from the environment
Parameters
----------
agents
"""
agents = self._agentnames(agents)
for name in agents:
self.reward_caches.pop(name)
self.observation_spaces.spaces.pop(name)
self.action_spaces.spaces.pop(name)
self.Ags.pop(name)
self.agent_names.remove(name)
if name in self.agents:
self.agents.remove(name)
self.reset()
def _agentnames(self, agents=None) -> list[str]:
"""
Convenience function for generally hanlding all the ways that
users might want to specify agents, names, numbers, or objects
themselves. Also as a "scalar" or list of such thing. This makes
several functions that call this robust to ways users specify
agents.
"""
if isinstance(agents, Agent):
agents: list[str] = [agents.name]
if isinstance(agents, int):
agents = [self.agent_names[agents]]
elif isinstance(agents, str):
agents = [agents]
elif isinstance(agents, list):
new: list[str] = []
for agent in agents:
if isinstance(agent, int):
new.append(self.agent_names[agent])
elif isinstance(agent, Agent):
new.append(agent.name)
elif isinstance(agent, str):
new.append(agent)
else:
raise TypeError("agent must be an Agent, int, or str")
agents = new
elif agents is None:
agents = self.agent_names
return agents
def _dict(self, V) -> dict:
"""
Convert a list of values to a dictionary of values keyed by agent name
"""
return (
{name: v for (name, v) in zip(self.agent_names, V)}
if hasattr(V, "__iter__")
else {name: V for name in self.agent_names}
)
def _is_terminal_state(self):
"""Whether the current state is a terminal state"""
# Check our objectives
test_goal = 0
# Loop through objectives, checking if they are satisfied
rewards, agents = self.goal_cache.check(remove_finished=True)
for reward, agent in zip(rewards, agents):
self.reward_caches[agent].append(reward)
if self.verbose >= 2:
print("GOALS:", self.goal_cache.goals)
# Return if no objectives left
no_objectives_left = len(self.goal_cache) == 0
return no_objectives_left
def _is_truncated_state(self):
"""
whether the current state is a truncated state,
see https://gymnasium.farama.org/api/env/#gymnasium.Env.step
default is false: an environment by default will have a terminal state,
ending the episode, whereon users should call reset(), but not a
trucation state ending the mdp.
"""
return False
def seed(self, seed=None):
"""Seed the random number generator"""
np.random.seed(seed)
def reset(self, seed=None, episode_meta_info=False, options=None):
"""How to reset the task when finished"""
if seed is not None:
self.seed(seed)
if self.verbose:
print("Resetting")
if len(self.episodes["start"]) > 0:
self.write_end_episode(episode_meta_info=episode_meta_info)
# Reset active non-terminated agents
self.agents = copy(self.agent_names)
# Clear rendering cache
self.clear_render_cache()
# If teleport on reset, randomly pick new location for agents
if self.teleport_on_reset:
for agent_name, agent in self.Ags.items():
# agent.update()
agent.pos = self.sample_positions(1)[
0
] # random position in the environment
if len(agent.history["pos"]) > 0:
agent.history["pos"][-1] = agent.pos
# Increment episode counter
if len(self.episodes["duration"]) and self.episodes["duration"][-1] == 0:
for key in self.episodes:
self.episodes[key].pop()
else:
self.episode += 1
self.write_start_episode()
# Restore agents to active state (pettingzoo variable)
self.agents = copy(self.agent_names)
# print("Active agents: ", self.agents)
# Reset goals
self.goal_cache.reset()
# Episode state trackers
# we have not applied a delayed terminate
self.episode_state["delayed_term"] = False
return self.get_observation(), self.infos
def update(self, update_agents=False):
"""
How to update the task over time --- update things
directly connected to the task
"""
self.t += self.dt # base task class only has a clock
self.history["t"].append(self.t)
def step(
self,
actions: Union[dict, np.array, None] = None,
dt=None,
drift_to_random_strength_ratio=1,
*pos,
**kws,
):
"""
step()
step() functions in Gymnasium paradigm usually take an action space
action, and return the next state, reward, whether the state is
terminal, and an information dict
different from update(), which updates this environment. this function
executes a full step on the environment with an action from the agents
https://pettingzoo.farama.org/api/parallel/#pettingzoo.utils.env.ParallelEnv.step
"""
# If the user passed drift_velocity, update the agents
if actions is not None:
if len(self.agents) == 0:
raise AttributeError(
"Action is given, but there are no "
"active agents. If there are no agents, try adding an "
"agent with .add_agents(). If there are agents, "
"try .reset() to restore inactive agents w/o goals to "
"active."
)
actions = actions if isinstance(actions, dict) else self._dict(actions)
else:
# Move agents randomly on None
actions = self._dict([None for _ in range(len(self.Ags))])
if not isinstance(drift_to_random_strength_ratio, dict):
drift_to_random_strength_ratio = self._dict(drift_to_random_strength_ratio)
for agent, action in zip(self.agents, actions.values()):
Ag = self.Ags[agent]
dt = dt if dt is not None else Ag.dt
if action is not None:
action = np.array(action).ravel()
action[np.isnan(action)] = 0
strength = drift_to_random_strength_ratio[agent]
Ag.update(
dt=dt, drift_velocity=action, drift_to_random_strength_ratio=strength
)
# Update the reward caches for time decay of existing rewards
for reward_cache in self.reward_caches.values():
reward_cache.update()
# Udpate the environment, which can add new rewards to caches
self.update(*pos, **kws)
# Return the next state, reward, whether the state is terminal,
terminal = self._is_terminal_state()
# Episode termination delay?
if (
terminal
and self.episode_terminate_delay
and self.episode_state["delayed_term"] == False
):
unrewarded_episode_padding = TimeElapsedGoal(
self,
reward=no_reward_default,
wait_time=self.episode_terminate_delay,
verbose=False,
)
self.episode_state["delayed_term"] = True
self.goal_cache.append(unrewarded_episode_padding)
terminal = self._is_terminal_state()
# If any terminal agents, remove from set of active agents
truncations = self._dict(self._is_truncated_state())
for agent, term in self._dict(self._is_terminal_state()).items():
if term and agent in self.agents or truncations[agent]:
self.agents.remove(agent)
# Create pettingzoo outputs
outs = (
self.get_observation(),
self.get_reward(),
self._dict(terminal),
self._dict(self._is_truncated_state()),
self._dict([self.infos]),
)
if self.verbose:
print(f"🐀 action @ {self.t}:", actions)
print(f"🌍 step @ {self.t}:", outs)
return outs
def step1(self, action=None, *pos, **kws):
"""
shortcut for stepping when only 1 agent exists...makes it behave
like gymnasium instead of pettingzoo
"""
results = self.step({self.agent_names[0]: action}, *pos, **kws)
results = [x[self.agent_names[0]] for x in results]
return results
def get_observation(self):
"""Get the current state of the environment"""
return {
name: self.observation_lambda[name](agent)
for name, agent in self.Ags.items()
}
def get_reward(self):
"""Get the current reward state of each agent"""
return {name: agent.reward.get_total() for name, agent in self.Ags.items()}
def set_observation(
self,
agents: Union[List, str, Agent],
spaces: Union[List, Space],
observation_lambdass: Union[List, FunctionType],
):
"""
Set the observation space and observation function for an agent(s)
- The space is a gym.Space that describes the set of possible
values an agents observation can take
- The lambda takes an agent argument and returns a tuple/list of
numbers regardings the agents position. users can set the lambda to
extract whatever attributes of the agent encode its state
The default for agents is there position. But if you would like to
change the observation to cell firing or velocity, you can do that
here.
Input
----
agents: List, str, Agent
The agent(s) to change the observation space for
spaces: List, gym.Space
The observation space(s) to change to. If a list, then it
must be a list of gymnasium spaces (these just describe the
full range of values an observation can take, and RL libraries
often use these to sample the space.)
observation_lambdass: List, Function
The observation function(s) to change to...these should take an
agent and output the vector of numbers describing what you
consider your agents' state. you can set the function to grab
whatever you'd like about the agent: it's position, velocity,
"""
agents = self._agentnames(agents)
if not isinstance(spaces, list):
spaces = [spaces]
if not isinstance(observation_lambdass, list):
observation_lambdass = [observation_lambdass]
if len(spaces) != len(observation_lambdass):
raise ValueError(
"observation space and observation lambda " "must be the same length"
)
for ag, sp, obs in zip(agents, spaces, observation_lambdass):
print("Changing observation space for {ag}")
self.observation_spaces[ag] = sp
self.observation_lambda[ag] = obs
# ----------------------------------------------
# Reading and writing episode data
# ----------------------------------------------
def _current_episode_start(self):
return 0 if not len(self.episodes["start"]) else self.episodes["end"][-1]
def write_start_episode(self):
self.episodes["episode"].append(self.episode)
self.episodes["start"].append(self._current_episode_start())
if self.verbose:
print("starting episode {}".format(self.episode))
print("episode start time: {}".format(self.episodes["start"][-1]))
def write_end_episode(self, episode_meta_info="none"):
self.episodes["end"].append(self.t)
self.episodes["duration"].append(self.t - self.episodes["start"][-1])
self.episodes["meta_info"].append(episode_meta_info)
if self.verbose:
print("ending episode {}".format(self.episode))
print("episode end time: {}".format(self.episodes["end"][-1]))
print("episode duration: {}".format(self.episodes["duration"][-1]))
# ----------------------------------------------
# Rendering
# ----------------------------------------------
def render(self, render_mode=None, *pos, **kws):
"""
Render the environment
"""
if render_mode is None:
render_mode = self.render_mode
# if self.verbose:
# print("rendering environment with mode: {}".format(render_mode))
if render_mode == "matplotlib":
out = self._render_matplotlib(*pos, **kws)
assert out is not None
return out
elif render_mode == "pygame":
return self._render_pygame(*pos, **kws)
elif render_mode == "none":
pass
else:
raise ValueError("method must be 'matplotlib' or 'pygame'")
def _render_matplotlib(self, *pos, agentkws: dict = dict(), **kws):
"""
Render the environment using matplotlib
`
Inputs
------
agentkws: dict
keyword arguments to pass to the agent's render method
"""
R, fig, ax = self._get_mpl_render_cache()
if np.mod(self.t, self.render_every) < self.dt:
# Skip rendering unless this is redraw time
return fig, ax
else:
# Render the environment
self._render_mpl_env()
# Render the agents
self._render_mpl_agents(**agentkws)
return fig, ax
def _get_mpl_render_cache(self):
if "matplotlib" not in self._stable_render_objects:
R = self._stable_render_objects["matplotlib"] = {}
else:
R = self._stable_render_objects["matplotlib"]
if "fig" not in R:
fig, ax = plt.subplots(1, 1)
R["fig"] = fig
R["ax"] = ax
else:
fig, ax = R["fig"], R["ax"]
return R, fig, ax
def _render_mpl_env(self):
R, fig, ax = self._get_mpl_render_cache()
if "environment" not in R:
R["environment"] = self.plot_environment(fig=fig, ax=ax, autosave=False)
R["title"] = fig.suptitle(
"t={:.2f}\nepisode={}".format(self.t, self.episode)
)
else:
R["title"].set_text("t={:.2f}\nepisode={}".format(self.t, self.episode))
def _render_mpl_agents(self, framerate=60, alpha=0.7, t_start="episode", **kws):
"""
Render the agents 🐀
Inputs
------
framerate: float
the framerate at which to render the agents
alpha: float
the alpha value to use for the agents
t_start: float
the time at which to start rendering the agents
- "episode" : start at the beginning of the current episode
- "all" : start at the beginning of the first episode
- float : start at the given time
**kws
keyword arguments to pass to the agent's style (point size, color)
see _agent_style
"""
R, fig, ax = self._get_mpl_render_cache()
initialize = "agents" not in R
if t_start == "episode":
t_start = self.episodes["start"][-1]
elif t_start == "all" or t_start is None:
t_start = self.episodes["start"][0]
def get_agent_props(agent, color):
t = np.array(agent.history["t"])
startid = np.nanargmin(np.abs(t - (t_start)))
skiprate = int((1.0 / framerate) // agent.dt)
trajectory = np.array(agent.history["pos"][startid::skiprate])
t = t[startid::skiprate]
c, s = self._agent_style(
agent, t, color, startid=startid, skiprate=skiprate, **kws
)
return trajectory, c, s
if initialize or len(R["agents"]) != len(self.Ags):
R["agents"] = []
for i, agent in enumerate(self.Ags.values()):
if len(agent.history["t"]):
trajectory, c, s = get_agent_props(agent, i)
ax.scatter(
*trajectory.T, s=s, alpha=alpha, zorder=0, c=c, linewidth=0
)
R["agents"].append(ax.collections[-1])
else:
for i, agent in enumerate(self.Ags.values()):
scat = R["agents"][i]
trajectory, c, s = get_agent_props(agent, i)
scat.set_offsets(trajectory)
scat.set_facecolors(c)
scat.set_edgecolors(c)
scat.set_sizes(s)
@staticmethod
def _agent_style(
agent: Agent,
time,
color=0,
skiprate=1,
startid=0,
point_size: bool = 15,
decay_point_size: bool = False,
plot_agent: bool = True,
decay_point_timescale: int = 10,
):
if isinstance(color, int):
color = plt.rcParams["axes.prop_cycle"].by_key()["color"][color]
s = point_size * np.ones_like(time)
if decay_point_size == True:
s = point_size * np.exp((time - time[-1]) / decay_point_timescale)
s[(time[-1] - time) > (1.5 * decay_point_timescale)] *= 0
c = [color] * len(time)
if plot_agent == True:
s[-1] = 40
c[-1] = "r"
return c, s
def _render_pygame(self, *pos, **kws):
pass
def clear_render_cache(self):
"""
clear_render_cache
clears the cache of objects held for render()
"""
if "matplotlib" in self._stable_render_objects:
R = self._stable_render_objects["matplotlib"]
R["ax"].cla()
for item in set(R.keys()) - set(("fig", "ax")):
R.pop(item)
def close(self):
"""gymnasium close() method"""
self.clear_render_cache()
if "fig" in self._stable_render_objects:
if isinstance(self._stable_render_objects["fig"], plt.Figure):
plt.close(self._stable_render_objects["fig"])
class Reward:
"""
When an task goal is triggered, reward goal is attached an Agent's
reward:list. This object tracks the dynamics of the reward applied to the
agent.
This implementation allows rewards to be applied:
- externally (through a task environment)
- or internally (through the agent's internal neuronal dynamics),
e.g. through a set of neurons tracking rewards, attached to the agent
This tracker specifies what the animals reward value should be at a given
time while the reward is activate
"""
decay_preset = {
"constant": lambda a, x: a,
"linear": lambda a, x: a * x,
"exponential": lambda a, x: a * np.exp(x),
"none": lambda a, x: 0,
}
decay_knobs_preset = {
"linear": [1],
"constant": [1],
"exponential": [2],
"none": [0],
}
def __init__(
self,
init_state=1,
dt=0.01,
expire_clock=None,
decay=None,
decay_knobs=[],
external_drive: Union[FunctionType, None] = None,
external_drive_strength=1,
name=None,
):
"""
Parameters
----------
init_state : float
initial reward value
dt : float
timestep
expire_clock : float|None
time until reward expires, if None, reward never expires
decay : str|function|None
decay function, or decay preset name, or None
decay_knobs : list
decay function knobs
external_drive : function|None
external drive function, or None. can be used to attach a goal
gradient or reward ramping signal
external_drive_strength : float
strength of external drive, how quickly the reward follows the
external drive
"""
self.state = (
init_state if not isinstance(init_state, FunctionType) else init_state()
)
self.dt = dt
self.expire_clock = (
expire_clock if isinstance(expire_clock, (int, float)) else dt
)
if isinstance(decay, str):
self.preset = decay
self.decay_knobs = decay_knobs or self.decay_knobs_preset[self.preset]
self.decay = partial(self.decay_preset[self.preset], *self.decay_knobs)
else:
self.preset = "custom" if decay is not None else "constant"
self.decay_knobs = decay_knobs or self.decay_knobs_preset[self.preset]
self.decay = decay or self.decay_preset["constant"]
self.external_drive = external_drive
self.external_drive_strength = external_drive_strength
self.history = {"state": [], "expire_clock": []}
self.name = (
name
if name is not None
else self.__class__.__name__ + " " + str(hash(self))[:5]
)
# if a goal provides a reward, then this attribute is used to track
# the goal that provided the reward
self.goal: Union[None, Goal] = None # optional store goal linked to
# reward
def update(self):
"""
update reward,
grows towards the gradient target value from its initial value, if a
target_value() function is defined. otherwise, reward is only
controlled by decay from some initial value. if decay is 0, and target
gradient is not defined then its constant, until the reward expire time
is reached.
# Returns
True if reward is still active, False if reward has expired
"""
self.state = self.state + self.get_delta() * self.dt
self.expire_clock -= self.dt
self.history["state"].append(self.state)
self.history["expire_clock"].append(self.expire_clock)
return not (self.expire_clock <= 0)
def get_delta(self, state=None):
"""\delta(reward) for a dt"""
state = self.state if state is None else state
if self.external_drive is not None:
target_gradient = self.external_drive()
strength = self.external_drive_strength
change = strength * (target_gradient - state) - self.decay(state)
else:
change = -(self.decay(state))
return change
def plot_theoretical_reward(self, timerange=(0, 1), name=None):
"""
plot the reward dynamics : shows the user how their parameters of
interest setup reward dynamics, without updating the object
"""
rewards = [self.state]
name = self.name if name is None else name
timesteps = np.arange(timerange[0], timerange[1], self.dt)
pre_expire_timesteps = np.arange(
timerange[0], self.expire_clock + self.dt, self.dt
)
for t in pre_expire_timesteps[1:]:
r = rewards[-1] + self.get_delta(state=rewards[-1]) * self.dt
rewards.append(r)
plt.plot(
pre_expire_timesteps,
rewards[: len(timesteps)],
label=f"reward={self.preset}, " f"knobs={self.decay_knobs}",
)
y1 = np.min((self.state, 0, np.min(plt.gca().get_ylim())))
y2 = np.max((self.state, 0, np.max(plt.gca().get_ylim())))
plt.ylim((y1, y2))
plt.xlim((timerange[0], timerange[1]))
plt.axvspan(0, self.expire_clock, color="r", alpha=0.2)
plt.text(
np.mean((plt.gca().get_xlim()[0], self.expire_clock)),
np.mean(plt.gca().get_ylim()),
f"{name}\nactive",
backgroundcolor="black",
color="white",
)
plt.text(
np.mean((self.expire_clock, plt.gca().get_xlim()[-1])),
np.mean(plt.gca().get_ylim()),
f"{name}\nexpires",
backgroundcolor="black",
color="white",
)
plt.gca().set(xlabel="time (s)", ylabel=f"{name} signal")
return plt.gcf(), plt.gca()
class RewardCache:
"""
RewardCache
A cache of all `active` rewards attached to an agent
# Parameters
default_reward_level : float
default reward level for all rewards in the cache
verbose : bool, int
print reward cache related details
"""
def __init__(self, default_reward_level=0, verbose=False):
self.default_reward_level = default_reward_level
self.cache: List[Reward] = []
self.verbose = verbose
self.stats = {
"total_steps_active": 0,
"total_steps_inactive": 0,
"max": -np.inf,
"min": np.inf,
"uniq_rewards": [],
"uniq_goals": [],
}
def append(self, reward: Reward, copymode=True):
assert isinstance(reward, Reward), "reward must be a Reward object"
if reward is not None:
if copymode:
reward = copy(reward)
if reward.name not in self.stats["uniq_rewards"]:
self.stats["uniq_rewards"].append(reward.name)
if reward.goal.name not in self.stats["uniq_goals"]:
self.stats["uniq_goals"].append(reward.goal.name)
self.cache.append(reward)
def update(self):
"""Update"""
# If any rewards ...
if self.cache:
self.stats["total_steps_active"] += 1
# Iterate through each reward, updating
for reward in self.cache:
reward_still_active = reward.update()
if not reward_still_active:
self.cache.remove(reward)
if self.verbose:
print("Reward removed from cache")
# Else, increment inactivity tracker
else:
self.stats["total_steps_inactive"] += 1
def get_total(self):
"""
If there are any active rewards, return the sum of their values.
"""
r = sum([reward.state for reward in self.cache]) + self.default_reward_level
assert not np.isnan(r), "reward is nan"
if r > self.stats["max"]:
self.stats["max"] = r
if r < self.stats["min"]:
self.stats["min"] = r
return r
def get_fraction(self):
"""Return the fraction of the total reward value relative to the max
and min values so far experienced."""
r = self.get_total()
return (r - self.stats["min":]) / (self.stats["max"] - self.stats["min"])
reward_default = Reward(1, 0.01, expire_clock=1, decay="linear")
no_reward_default = Reward(
0, 0.01, expire_clock=0.1, decay="none"
) # A reward object which doesn't give any reward (for use in goals where no reward is then given)
class Goal:
"""
Abstract `Objective` class that can be used to define finishing coditions
for a task
"""
def __init__(
self,
env: Union[None, TaskEnvironment] = None,
reward=reward_default,
name=None,
**kws,
):
self.env = env
self.reward = reward
self.reward.goal = self
self.name = (
name
if name is not None
else self.__class__.__name__ + " " + str(hash(random.random()))[:5]
)
def __hash__(self):
"""hash for uniquely identifying a goal"""
hashes = []
for value in self.__dict__.values():
try:
hashes.append(hash(value))
except:
pass
return hash(tuple(hashes))
def check(self, agents=None):
"""
Check if the goal is satisfied for agents and report which agents
satisfied the goal and if any rewards are rendered
"""
raise NotImplementedError("check() must be implemented")
def __call__(self):
"""
Can be used to report its value to the environment
(Not required -- just a convenience)
"""
pass