Refactor/autodiff/track pnl (#3030)
• composition.py, autodiffcomposition.py and relevant subordinate methods:
 - implement synch and track parameter dictionaries that are passed to relevant methods
 
  - add/rename attributes:
    - PytorchCompositionWrapper:
      - retained_outputs
      - retained_targets
      - retained_losses
      - _nodes_to_execute_after_gradient_calc
    - PytorchMechanismWrapper:
      - value -> output
      - input
  - add methods:
    - synch_with_psyneulink(): centralize copying of params and values to pnl using methods below
    - copy_node_variables_to_psyneulink(): centralize updating of node (mech & comp) variables in PNL
    - copy_node_values_to_psyneulink(): centralize updating of node (mech & comp) values in PNL
    - copy_results_to_psyneulink(): centralize updating of autodiffcomposition.results
    - retain_in_psyneulink(): centralize tracking of pytorch results in PNL using methods below (see the sketch after this list)
    - retain_torch_outputs: keeps record of outputs and copies to AutodiffComposition.pytorch_outputs at end of call to learn()
    - retain_torch_targets: keeps record of targets and copies to AutodiffComposition.pytorch_targets at end of call to learn()
    - retain_torch_losses: keeps record of losses and copies to AutodiffComposition.pytorch_losses at end of call to learn()
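
A minimal sketch of how the centralized entry points listed above might dispatch to the individual copy/retain helpers; the class name, signatures, and bodies are illustrative stand-ins, not the actual PsyNeuLink implementation:

    # Illustrative stand-in for the wrapper's centralized synch/retain dispatch;
    # the real methods live in pytorchwrappers.py and differ in detail.
    class WrapperSynchSketch:
        def __init__(self):
            self.retained_outputs = []
            self.retained_targets = []
            self.retained_losses = []

        def synch_with_psyneulink(self, context=None):
            # Single entry point for copying PyTorch state back to PsyNeuLink.
            self.copy_node_variables_to_psyneulink(context)
            self.copy_node_values_to_psyneulink(context)
            self.copy_results_to_psyneulink(context)

        def retain_in_psyneulink(self, outputs=None, targets=None, losses=None):
            # Single entry point for accumulating PyTorch results to copy at end of learn().
            if outputs is not None:
                self.retained_outputs.append(outputs)
            if targets is not None:
                self.retained_targets.append(targets)
            if losses is not None:
                self.retained_losses.append(losses)

        # Stubs standing in for the real copy_* helpers:
        def copy_node_variables_to_psyneulink(self, context): ...
        def copy_node_values_to_psyneulink(self, context): ...
        def copy_results_to_psyneulink(self, context): ...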


• compositionrunner.py, autodiffcomposition.py, pytorchwrappers.py:
  - move loss tracking from parameter on autodiff to attribute on its pytorch_rep
  - batch_inputs():  add calls to synch_with_psyneulink() and retain_in_psyneulink()
  - batch_function_inputs():
     - needs calls to synch_with_psyneulink() and retain_in_psyneulink()

• composition.py:
  - run(): add _update_results() as a helper method that can be overridden (e.g., by autodiffcomposition) for less frequent updating (sketched below)
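
A hypothetical sketch of the override pattern for _update_results(): the base class records every trial's output, and a subclass such as an autodiff composition can override it to record less frequently. Everything here other than the _update_results name is illustrative:

    # Illustrative only; not the actual Composition/AutodiffComposition code.
    class BaseCompositionSketch:
        def __init__(self):
            self.results = []

        def _update_results(self, trial_output):
            # Default: record every trial's output.
            self.results.append(trial_output)

        def run(self, inputs):
            for trial_input in inputs:
                trial_output = [x * 2 for x in trial_input]  # stand-in for executing a trial
                self._update_results(trial_output)
            return self.results

    class AutodiffSketch(BaseCompositionSketch):
        def _update_results(self, trial_output):
            # Override to defer recording, e.g. copy results back only once per run.
            pass

    print(BaseCompositionSketch().run([[1], [2]]))  # [[2], [4]]
    print(AutodiffSketch().run([[1], [2]]))         # []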

• autodiffcomposition.py
  - restrict calls to copy_weights_to_psyneulink based on copy_parameters_to_psyneulink_after arg/attribute
  - implement handling of optimizations_per_minibatch and copy_parameters_to_psyneulink as attributes and args to learn
  - autodiff_training(): fix bug in call to pytorch_rep.forward()
  - implement synch and track Parameters
  - _manage_synch_and_retain_args()
  - run(): support specification of synch and retain args when called directly
  - autodiff._update_learning_parameters -> do_optimization():
    - calculates loss for current trial
    - calls autodiff_backward() to calculate gradients and update parameters
    - updates tracked_loss over trials
  - autodiff_backward() -> new method called from do_optimization() that calculates and updates the gradients
  - self.loss -> self.loss_function
  - _update_results() - overridden to call pytorch_rep.retain_for_psyneulink(RUN: trial_output)
  - learn():
    - move tracked_loss for each minibatch from parameter on autodiff to attribute on its pytorch_rep
       (since that is already context-dependent, and avoids calls to pnl.parameters._set on every call to forward())
    - synch_with_pnl_options:
         implement as dict to consolidate the synch_projection_matrices_with_torch, synch_node_variables_with_torch, and
         synch_node_values_with_torch options passed to learning methods
    - retain_in_pnl_options:
         implement as dict to consolidate the retain_torch_outputs_in_results, retain_torch_targets, and retain_torch_losses
         options passed to learning methods (see the usage sketch after this list)
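
A sketch of that consolidation, using the individual argument names given above; the dict keys, defaults, and helper name are assumptions for illustration and may not match _manage_synch_and_retain_args() in autodiffcomposition.py:

    from psyneulink.core.globals.keywords import RUN, TRIAL

    # Hypothetical consolidation of the individual learn() args into two option dicts;
    # keys and defaults are assumed for illustration.
    def consolidate_synch_and_retain_args(
            synch_projection_matrices_with_torch=RUN,
            synch_node_variables_with_torch=None,
            synch_node_values_with_torch=TRIAL,
            retain_torch_outputs_in_results=RUN,
            retain_torch_targets=RUN,
            retain_torch_losses=RUN):
        synch_with_pnl_options = {
            'projection_matrices': synch_projection_matrices_with_torch,
            'node_variables': synch_node_variables_with_torch,
            'node_values': synch_node_values_with_torch,
        }
        retain_in_pnl_options = {
            'outputs': retain_torch_outputs_in_results,
            'targets': retain_torch_targets,
            'losses': retain_torch_losses,
        }
        return synch_with_pnl_options, retain_in_pnl_options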

• pytorchwrappers.py
  - subclass PytorchCompositionWrapper from torch.jit.ScriptModule
  - retain_for_psyneulink(): implemented; stores outputs, targets, and losses from PyTorch execution for copying to PsyNeuLink at end of learn().
  - PytorchMechanismWrapper:
      - .value -> .output
      - add .input
  - add/rename attributes:
    - PytorchCompositionWrapper:
      - retained_outputs
      - retained_targets
      - retained_losses
      - _nodes_to_execute_after_gradient_calc
    - PytorchMechanismWrapper:
      - value -> output
      - input
  - add methods:
    - synch_with_psyneulink(): centralize copying of params and values to pnl using methods below
    - copy_node_variables_to_psyneulink(): centralize updating of node (mech & comp) variables in PNL
    - copy_node_values_to_psyneulink(): centralize updating of node (mech & comp) values in PNL
    - copy_results_to_psyneulink(): centralize updating of autodiffcomposition.results
    - retain_in_psyneulink(): centralize tracking of pytorch results in PNL using methods below (see the sketch after this list)
    - retain_torch_outputs: keeps record of outputs and copies to AutodiffComposition.pytorch_outputs at end of call to learn()
    - retain_torch_targets: keeps record of targets and copies to AutodiffComposition.pytorch_targets at end of call to learn()
    - retain_torch_losses: keeps record of losses and copies to AutodiffComposition.pytorch_losses at end of call to learn()
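
A minimal sketch of the retain path: per-trial PyTorch values are detached and accumulated in the retained_* lists, then handed back to the AutodiffComposition once at the end of learn(). The destination attribute names follow the commit message; the control flow and method bodies are illustrative:

    import torch

    # Illustrative accumulation of PyTorch results for end-of-learn() copying;
    # not the actual PytorchCompositionWrapper code.
    class RetainSketch:
        def __init__(self):
            self.retained_outputs = []
            self.retained_targets = []
            self.retained_losses = []

        def retain_for_psyneulink(self, output, target, loss):
            # Detach so the retained copies do not keep the autograd graph alive.
            self.retained_outputs.append(output.detach().cpu().numpy())
            self.retained_targets.append(target.detach().cpu().numpy())
            self.retained_losses.append(loss.detach().item())

        def copy_retained_to_autodiff(self, autodiff_composition):
            # Called once at the end of learn().
            autodiff_composition.pytorch_targets = list(self.retained_targets)
            autodiff_composition.pytorch_losses = list(self.retained_losses)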

• pytorchEMcompositionwrapper.py
  - store_memory():  
     - implement single call to linalg over memory
  - only execute storage_node after last optimization_rep
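
A hypothetical illustration of the single-linalg-call idea: compute a quantity over the whole memory tensor in one batched call rather than looping over entries. The storage rule shown (overwrite the entry with the smallest norm) is an assumption for illustration, not necessarily the rule store_memory() actually uses:

    import torch

    def store_memory_sketch(memory: torch.Tensor, new_entry: torch.Tensor) -> torch.Tensor:
        """memory: (num_entries, entry_len); new_entry: (entry_len,)."""
        # One linalg call over all of memory instead of a per-entry loop.
        norms = torch.linalg.norm(memory, dim=1)
        weakest = int(torch.argmin(norms))
        memory = memory.clone()
        memory[weakest] = new_entry
        return memory

    # Example: overwrite the weakest of 5 entries of length 4.
    mem = store_memory_sketch(torch.rand(5, 4), torch.ones(4))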

• keywords.py
  - implement LearningScale keywords class
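
A minimal sketch of what a LearningScale keywords class could look like; only members that appear in the scripts in this commit (OPTIMIZATION_STEP, TRIAL, RUN) are shown, and the Enum form and values are assumptions rather than the actual keywords.py definition:

    from enum import Enum

    # Illustrative only; member set and values are assumed.
    class LearningScale(Enum):
        OPTIMIZATION_STEP = 'optimization_step'   # after each optimizer step
        TRIAL = 'trial'                           # after each trial
        RUN = 'run'                               # once at the end of a run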

• AutoAssociativeProjection:
  - make dependent on MaskedMappingProjection in preparation for allowing LCAMechanism to modulate auto/hetero parameters

• fix Literals import

• Factorize scripts into:
  - ScriptControl.py
  - TestParams.py
  - [MODEL].py

---------

Co-authored-by: jdcpni <pniintel55>
jdcpni authored Aug 12, 2024
1 parent 310afb1 commit 7047302
Showing 34 changed files with 2,033 additions and 2,298 deletions.

Scripts/Models (Under Development)/EGO/EGO Model - MDP OLD.py: this file was deleted (0 additions, 500 deletions).

@@ -0,0 +1,93 @@
"""
DECLAN Params: **************************************************************************
√ episodic_lr = 1 # learning rate for the episodic pathway
√ temperature = 0.1 # temperature for EM retrieval (lower is more argmax-like)
√ n_optimization_steps = 10 # number of update steps
sim_thresh = 0.8 # threshold for discarding bad seeds -- can probably ignore this for now
Filter runs whose context representations are too uniform (i.e. not similar to "checkerboard" foil)
May need to pad the context reps because there will be 999 reps
def filter_run(run_em, thresh=0.8):
    foil = np.zeros([4,4])
    foil[::2, ::2] = 1
    foil[1::2, 1::2] = 1
    run_em = run_em.reshape(200, 5, 11).mean(axis=1)
    mat = cosine_similarity(run_em, run_em)
    vec = mat[:160, :160].reshape(4, 40, 4, 40).mean(axis=(1, 3)).ravel()
    return cosine_similarity(foil.reshape(1, -1), vec.reshape(1, -1))[0][0]
# Stack the model predictions (should be 999x11), pad with zeros, and reshape into trials for averaging.
em_preds = np.vstack([em_preds, np.zeros([1,11])]).reshape(-1,5,11)
# Stack the ground truth states (should be 999x11), pad with zeros, and reshape into trials for averaging.
ys = np.vstack([data_loader.dataset.ys.cpu().numpy(), np.zeros([1,11])]).reshape(-1,5,11)
# compute the probability as a performance metric
def calc_prob(em_preds, test_ys):
    em_preds, test_ys = em_preds[:, 2:-1, :], test_ys[:, 2:-1, :]
    em_probability = (em_preds*test_ys).sum(-1).mean(-1)
    trial_probs = (em_preds*test_ys)
    return em_probability, trial_probs
Calculate the retrieval probability of the correct response as a performance metric (probs)
probs, trial_probs = calc_prob(em_preds, test_ys)
"""
from psyneulink.core.llvm import ExecutionMode
from psyneulink.core.globals.keywords import ALL, ADAPTIVE, CONTROL, CPU, Loss, MPS, OPTIMIZATION_STEP, RUN, TRIAL

model_params = dict(

# Names:
name = "EGO Model CSW",
state_input_layer_name = "STATE",
previous_state_layer_name = "PREVIOUS STATE",
context_layer_name = 'CONTEXT',
em_name = "EM",
prediction_layer_name = "PREDICTION",

# Structural
state_d = 11, # length of state vector
previous_state_d = 11, # length of state vector
context_d = 11, # length of context vector
memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims
memory_init = (0,.0001), # Initialize memory with random values in interval
# memory_init = None, # Initialize with zeros
concatenate_keys = False,
# concatenate_keys = True,

# environment
# curriculum_type = 'Interleaved',
curriculum_type = 'Blocked',
# num_stims = 100, # Integer or ALL
num_stims = ALL, # Integer or ALL

# Processing
integration_rate = .69, # rate at which state is integrated into new context
# state_weight = 1, # weight of the state used during memory retrieval
# context_weight = 1, # weight of the context used during memory retrieval
state_weight = .5, # weight of the state used during memory retrieval
context_weight = .5, # weight of the context used during memory retrieval
normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
# normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
# softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
# softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
# softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
# softmax_threshold = None, # threshold used to mask out small values in softmax
softmax_threshold = .001, # threshold used to mask out small values in softmax
enable_learning=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
learn_field_weights = False,
loss_spec = Loss.BINARY_CROSS_ENTROPY,
# loss_spec = Loss.MSE,
learning_rate = .5,
# num_optimization_steps = 1,
num_optimization_steps = 10,
synch_weights = RUN,
synch_values = RUN,
synch_results = RUN,
# execution_mode = ExecutionMode.Python,
execution_mode = ExecutionMode.PyTorch,
device = CPU,
# device = MPS,
)
#endregion
@@ -147,8 +147,8 @@
MEMORY_CAPACITY = 5
CONSTRUCT_MODEL = True # THIS MUST BE SET TO True to run the script
DISPLAY_MODEL = ( # Only one of the following can be uncommented:
-None # suppress display of model
-# {} # show simple visual display of model
+# None # suppress display of model
+{} # show simple visual display of model
# {'show_node_structure': True} # show detailed view of node structures and projections
)
RUN_MODEL = True # True => run the model
@@ -404,7 +404,7 @@ def construct_model(model_name:str=MODEL_NAME,
model = construct_model()
assert 'DEBUGGING BREAK POINT'
# print(model.scheduler.consideration_queue)
-# gs.output_graph_image(model.scheduler.graph, 'EGO_comp-scheduler.png')
+# gs.output_graph_image(model.scheduler.graph, 'show_graph OUTPUT/EGO_comp-scheduler.png')

if DISPLAY_MODEL is not None:
if model: