Pickling state / pickling whole class issue #6
Hello, I would also like to run the fit separately from the predict method, not least because fitting takes some time. Based on the previous post, I have adapted an example to show my approach. I would be grateful for any hints on how to do this correctly.
My error message: Mean accuracy over 10 runs:
Just to link things together: that last post has now been posted as a question on StackOverflow:
Hi @bmreiniger and @andife! Thanks for bringing up this issue. The get_state and set_state methods only operate on the state of the Tsetlin Automata in the C part of the code.
As you can see from the code above, the number of features and classes is obtained automatically from the input data X and Y. So one possible trick is to call fit on the training data first, with epochs=0. fit then sets the number of classes and features and creates the TM in C, without running any epochs over the data. After calling fit, you can call set_state to initialize the Tsetlin Automata from the saved state. I am planning to add methods for this. Let me know if it works out!
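To illustrate why the call order matters, here is a minimal sketch built on a hypothetical stand-in class, `ToyTM` (not the pyTsetlinMachine class). It only mimics the behaviour described above: the number of classes and features is set inside `fit`, so `set_state` on a fresh object fails with an `AttributeError`, while `fit(..., epochs=0)` followed by `set_state` restores the saved state. The internal state layout and the "training" update are placeholders chosen for illustration.

```python
import pickle

import numpy as np


class ToyTM:
    """Hypothetical stand-in for a TM whose dimensions are inferred in fit()."""

    def __init__(self, number_of_clauses):
        self.number_of_clauses = number_of_clauses

    def fit(self, X, Y, epochs=100):
        # Dimensions are inferred from the data, as described for fit() above.
        self.number_of_features = X.shape[1]
        self.number_of_classes = int(np.max(Y) + 1)
        # Allocate the (toy) automata state; with epochs=0 nothing is trained.
        self.automata = [
            np.zeros((self.number_of_clauses, self.number_of_features), dtype=np.int32)
            for _ in range(self.number_of_classes)
        ]
        for _ in range(epochs):
            for c in range(self.number_of_classes):
                self.automata[c] += 1  # placeholder for real learning

    def get_state(self):
        return [a.copy() for a in self.automata]

    def set_state(self, state):
        # Like the set_state in the traceback below, this needs number_of_classes.
        for i in range(self.number_of_classes):
            self.automata[i][:] = state[i]


X = np.zeros((4, 3), dtype=np.uint32)
Y = np.array([0, 1, 2, 1], dtype=np.uint32)

tm = ToyTM(number_of_clauses=5)
tm.fit(X, Y, epochs=2)
saved = pickle.dumps(tm.get_state())

fresh = ToyTM(number_of_clauses=5)
try:
    fresh.set_state(pickle.loads(saved))  # fails: dimensions not set yet
except AttributeError as err:
    print("set_state before fit:", err)

fresh.fit(X, Y, epochs=0)  # sets dimensions, trains nothing
fresh.set_state(pickle.loads(saved))  # now the saved state can be restored
```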
Hello @olegranmo, @bmreiniger, thanks for the solution to this issue. It works! (I used the code below for testing.) More generally, I'm wondering whether I really need to load the original dataset, or whether a synthetic one that just contains the same number of features and classes would do. This would also correspond to the approach of https://github.com/cair/pyTsetlinMachine/pull/4/files.
Great that it works, @andife! Yes, I guess you could, for instance, just use one example X with the correct number of features and the largest y-value (number of classes - 1).
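A minimal sketch of such a synthetic initialization set, using NumPy. The helper name `make_init_data` is hypothetical, and the values 784 features and 10 classes are only an assumed MNIST-like setup. It assumes `fit` infers the number of classes as the largest label plus one, which is what the suggestion above implies.

```python
import numpy as np


def make_init_data(number_of_features, number_of_classes):
    """Hypothetical helper: a one-example dataset that fixes the TM dimensions.

    X has the required number of features; Y holds only the largest class
    label (number_of_classes - 1), so a fit that infers the class count as
    max(Y) + 1 ends up with the right number of classes.
    """
    X = np.zeros((1, number_of_features), dtype=np.uint32)
    Y = np.array([number_of_classes - 1], dtype=np.uint32)
    return X, Y


# Assumed MNIST-like dimensions, for illustration only.
X_init, Y_init = make_init_data(number_of_features=784, number_of_classes=10)
print(X_init.shape, Y_init)  # (1, 784) [9]
```

With the real class you would then call `fit(X_init, Y_init, epochs=0)` followed by `set_state(state)`, as suggested earlier in this thread; whether this works for every variant of the library has not been checked here.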
@olegranmo It seems this approach does not work for the CUDA version.
With the non-CUDA version, the results of the reloaded TM are the same as at the end of training:
#1 Accuracy: 94.27% Training: 35.89s Testing: 21.07s
but with the CUDA version, reloading has no effect; it is as if the script starts from scratch:
#1 Accuracy: 92.80% Training: 9.59s Testing: 1.08s
Hi @devop01, thanks for reporting! I have started adding pickle support; it is now complete for PyTsetlinMachine.
Hi again @devop01, I just added pickle support for PyTsetlinMachineCUDA!
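As a general illustration of how whole-object pickling can work when a class wraps a resource that `pickle` cannot serialize (for a Python wrapper around C code, presumably a pointer into C memory), here is a sketch using Python's standard `__getstate__`/`__setstate__` hooks. `HandleTM` is hypothetical, a `threading.Lock` stands in for the unpicklable handle, and this is not a description of how PyTsetlinMachine actually implements its pickle support.

```python
import pickle
import threading


class HandleTM:
    """Hypothetical class holding an unpicklable resource next to plain data."""

    def __init__(self, number_of_clauses):
        self.number_of_clauses = number_of_clauses
        self.weights = [0] * number_of_clauses
        # A lock stands in for a C-side handle: pickle refuses to serialize it.
        self._handle = threading.Lock()

    def __getstate__(self):
        # Copy the instance dict and drop the unpicklable handle.
        state = self.__dict__.copy()
        del state["_handle"]
        return state

    def __setstate__(self, state):
        # Restore the plain attributes, then recreate the handle.
        self.__dict__.update(state)
        self._handle = threading.Lock()


tm = HandleTM(number_of_clauses=4)
tm.weights[2] = 7
restored = pickle.loads(pickle.dumps(tm))
print(restored.weights)  # [0, 0, 7, 0]
```

The design choice is to serialize only the data needed to rebuild the object and to recreate the resource on load; a real wrapper would also have to copy the C-side data into that serialized state.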
Hi @olegranmo, thank you for the new version :) Accuracy over 25 epochs: #1 Accuracy: 92.89% Training: 9.78s Testing: 1.06s. Below is the full code I used:
Hi @devop01, this happens because the local voting tallies used for asynchronous parallel learning are not stored as part of the state, so everything is reinitialized when you start training again. Will fix this in the next update!
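The general point, that a checkpoint must capture every piece of state the training loop reads, can be shown with a toy. `ToyTrainer` and its `tally` are hypothetical: the tally only stands in for the voting tallies mentioned above, and the real Tsetlin Machine update rule is not modelled. Resuming from the full state reproduces an uninterrupted run, while restoring the weights alone does not.

```python
import copy


class ToyTrainer:
    """Hypothetical trainer whose updates depend on a running tally."""

    def __init__(self):
        self.weights = [0, 0]
        self.tally = 0  # stands in for the local voting tallies

    def step(self):
        # The update depends on the tally, so dropping it changes what follows.
        self.tally += 1
        self.weights[self.tally % 2] += self.tally

    def full_state(self):
        return copy.deepcopy({"weights": self.weights, "tally": self.tally})

    def load_full_state(self, state):
        self.weights = list(state["weights"])
        self.tally = state["tally"]

    def load_weights_only(self, state):
        self.weights = list(state["weights"])  # tally stays at its initial value


def run(steps, trainer):
    for _ in range(steps):
        trainer.step()
    return trainer


reference = run(4, ToyTrainer())  # 4 uninterrupted steps
checkpoint = run(2, ToyTrainer()).full_state()  # stop after 2 steps

full = ToyTrainer()
full.load_full_state(checkpoint)
run(2, full)  # resume with weights and tally

partial = ToyTrainer()
partial.load_weights_only(checkpoint)
run(2, partial)  # resume with weights only

print(reference.weights, full.weights, partial.weights)  # [6, 4] [6, 4] [4, 2]
```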
As far as I understand, there are problems with saving the state of a trained TM.
This is important for research and development, since it makes creating checkpoints easy.
I tried to pickle MultiClassTsetlinMachine (from pyTsetlinMachineParallel.tm) as well as the PyCUDA version, and both failed.
Then I tried to pickle just the TM state (from the MNIST example):
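e.g.:
import pickle
from pyTsetlinMachineParallel.tm import MultiClassTsetlinMachine

tm = MultiClassTsetlinMachine(2000, 50, 10.0)
# ... tm is trained as in the MNIST example ...
with open(..., "wb") as f:
    pickle.dump(tm.get_state(), f)

# new session
with open(..., "rb") as f:
    state = pickle.load(f)
tm2 = MultiClassTsetlinMachine(2000, 50, 10.0)
tm2.set_state(state)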
Result:
Traceback (most recent call last):
File "<pyshell#15>", line 1, in <module>
tm.set_state(state)
File "/.../TsetlinMachine/pyTsetlinMachineParallel/tm.py", line 314, in set_state
for i in range(self.number_of_classes):
AttributeError: 'MultiClassTsetlinMachine' object has no attribute 'number_of_classes'
If I set the number of classes manually:
tm2.number_of_classes = 10
tm2.set_state(state)
the program hangs, and after some time I get:
=============================== RESTART: Shell ===============================