# pycrysfmlEnvironment.py
#Make the repository root (three directories above this file's folder)
#importable before the local package imports below
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
import numpy as np
import matplotlib as mpl
mpl.use('Agg')  #select a non-interactive backend
import fswig_hklgen as H
import hkl_model as Mod
import sxtal_model as S
import bumps.names as bumps
import bumps.fitters as fitter
from bumps.formatnum import format_uncertainty_pm
import tensorforce  #needed for from_spec below
from tensorforce.environments import Environment
#Tensorforce Environment representation of
#the pycrysfml 'game'
class PycrysfmlEnvironment(Environment):

    def __init__(self, observedFile, infoFile, backgFile=None, sxtal=True):
        #Read data
        self.spaceGroup, self.crystalCell, self.atomList = H.readInfo(infoFile)
        #readIntFile returns the wavelength, reflection list, observed
        #structure factors squared (sfs2), and their measurement errors
        wavelength, refList, sfs2, error = S.readIntFile(observedFile, kind="int", cell=self.crystalCell)
        self.wavelength = wavelength
        self.refList = refList
        self.sfs2 = sfs2
        self.error = error
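        #Precompute each reflection's two-theta angle from its scattering
        #vector magnitude (calcS) at the given wavelength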
        self.tt = [H.twoTheta(H.calcS(self.crystalCell, ref.hkl), wavelength) for ref in refList]
        self.backg = None
        self.exclusions = []
        self.reset()

    def reset(self):
        #Make a cell
        cell = Mod.makeCell(self.crystalCell, self.spaceGroup.xtalSystem)
        #TODO: make model thru tensorforce, not here
        #Define a model
        self.model = S.Model([], [], self.backg, self.wavelength, self.spaceGroup, cell,
                             [self.atomList], self.exclusions,
                             scale=0.06298, error=[], extinction=[0.0001054])
        #Initialize the z coordinate of the first atom in the model
        #and restrict its fit range
        self.model.atomListModel.atomModels[0].z.value = 0.3
        self.model.atomListModel.atomModels[0].z.range(0, 0.5)
        self.visited = []
        self.observed = []
        self.remainingActions = list(range(len(self.refList)))
        self.totReward = 0
        self.prevChisq = None
        self.step = 0
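        #Binary state vector: one entry per reflection, set to 1 once measured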
        self.state = np.zeros(len(self.refList))
        self.stateList = []
        return self.state

    def fit(self, model):
        #Create a problem from the model with bumps,
        #then fit and solve it
        problem = bumps.FitProblem(model)
        fitted = fitter.LevenbergMarquardtFit(problem)
        x, dx = fitted.solve()
        return x, dx, problem.chisq()
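
    #A reporting helper (a sketch added here, not in the original module):
    #formats each fitted parameter as "value +/- uncertainty" with bumps'
    #format_uncertainty_pm, imported above. The name reportFit is hypothetical.
    def reportFit(self, x, dx):
        return [format_uncertainty_pm(value, err) for value, err in zip(x, dx)]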

    def execute(self, actions):
        self.step += 1
        action = int(actions)
        #Negative reward for repeating an already measured reflection;
        #the episode still ends once step passes 300
        if self.state[action] == 1:
            self.totReward -= 0.15
            return self.state, (self.step > 300), -0.15
        self.state[action] = 1
        #No repeats: record the new reflection and remove it from the pool
        self.visited.append(self.refList[action])
        self.remainingActions.remove(action)
        #Find the data for this hkl value and add it to the model
        self.model.refList = H.ReflectionList(self.visited)
        self.model._set_reflections()
        self.model.error.append(self.error[action])
        self.model.tt = np.append(self.model.tt, [self.tt[action]])
        self.observed.append(self.sfs2[action])
        self.model._set_observations(self.observed)
        self.model.update()
        reward = -0.1
        #The fit needs more observations than free parameters, so wait
        #until enough reflections have been measured before fitting
        if len(self.visited) > 11:
            x, dx, chisq = self.fit(self.model)
            #Reward any improvement in chi-squared
            if self.prevChisq is not None and chisq < self.prevChisq:
                reward = 0.1
            self.prevChisq = chisq
        self.totReward += reward
        #Stop early once chi-squared has dropped below 1
        #(prevChisq is used because chisq is only defined after a fit)
        if self.prevChisq is not None and self.step > 50 and self.prevChisq < 1:
            return self.state, True, 0.5
        #Otherwise the episode ends when every reflection has been
        #measured or after 300 steps
        terminal = len(self.remainingActions) == 0 or self.step > 300
        return self.state, terminal, reward

    @property
    def states(self):
        return dict(shape=self.state.shape, type='float')

    @property
    def actions(self):
        return dict(num_actions=len(self.refList), type='int')

    @staticmethod
    def from_spec(spec, kwargs):
        """
        Creates an environment from a specification dict.
        """
        env = tensorforce.util.get_object(
            obj=spec,
            predefined_objects=tensorforce.environments.environments,
            kwargs=kwargs
        )
        assert isinstance(env, Environment)
        return env
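
#Usage sketch (an addition, not part of the original module): exercises the
#environment with a random policy so the reward loop can be smoke-tested
#without a tensorforce agent. The data file names are hypothetical
#placeholders for a real .int/.cfl pair.
if __name__ == "__main__":
    import numpy.random as npr
    env = PycrysfmlEnvironment("observed.int", "crystal_info.cfl")  #hypothetical paths
    state = env.reset()
    terminal = False
    while not terminal:
        #Pick a random reflection index; execute() penalizes repeats
        state, terminal, reward = env.execute(npr.randint(len(env.refList)))
    print("total reward:", env.totReward)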