Commit 368c84e: Fix noise learning
pedrofale committed Aug 18, 2024
1 parent c5fbebb
Showing 7 changed files with 662 additions and 199 deletions.
639 changes: 483 additions & 156 deletions notebooks/vonmises.ipynb

Large diffs are not rendered by default.

197 changes: 167 additions & 30 deletions scatrex/models/trajectory/node.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions scatrex/models/trajectory/node_opt.py
@@ -89,8 +89,9 @@ def sample_factor_weights(key, mu, log_std): # KxG
 
 @jax.jit
 def obs_weights_logp(sample, mean, log_std): # single sample, NxK
-    return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample)) # sum across obs and dimensions
-obs_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(obs_weights_logp, argnums=0)) # Take grad wrt to sample (NxK)
+    return tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample) # elementwise log-probabilities
+univ_obs_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(obs_weights_logp, argnums=0)) # Take grad wrt a single sample entry
+obs_weights_logp_val_and_grad = jax.jit(jax.vmap(jax.vmap(univ_obs_weights_logp_val_and_grad, in_axes=(0, None, None)), in_axes=(0, None, None))) # Take grad wrt sample (NxK)
 mc_obs_weights_logp_val_and_grad = jax.jit(jax.vmap(obs_weights_logp_val_and_grad, in_axes=(0, None, None))) # Multiple-sample value_and_grad: SxNxK
 
 @jax.jit
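The new code swaps the gradient of a summed log-probability for per-element gradients: value_and_grad is taken on a scalar log-prob and then vmapped over both axes of the sample. A minimal sketch of the pattern, with toy shapes (the scalar mean and log-std are placeholders, not the model's parameters):

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def univ_logp(x, mean, log_std):  # scalar sample -> scalar log-prob
    return tfd.Normal(mean, jnp.exp(log_std)).log_prob(x)

# value_and_grad of the scalar function, vmapped twice to recover the NxK structure
univ_vg = jax.value_and_grad(univ_logp, argnums=0)
elementwise_vg = jax.vmap(jax.vmap(univ_vg, in_axes=(0, None, None)), in_axes=(0, None, None))

x = jnp.ones((5, 3))  # toy N=5 observations, K=3 factors
vals, grads = elementwise_vg(x, 0.0, 0.0)
print(vals.shape, grads.shape)  # (5, 3) (5, 3)

Since each term of the old sum depends on a single sample entry, the gradients coincide with those of the summed version; what the change exposes is the per-element log-probability values alongside them.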
7 changes: 2 additions & 5 deletions scatrex/ntssb/ntssb.py
@@ -1347,7 +1347,7 @@ def learn_model(self, n_epochs, seed=42, memoized=True, update_roots=True, updat
                 local_states = self.update_local_params(subkey, batch_idx=batch_idx, adaptive=adaptive, states=local_states, i=it,
                                                         param_names=locals_names, update_globals=update_globals, **kwargs)
                 if memoized:
-                    self.update_sufficient_statistics(batch_idx=batch_idx)
+                    self.update_sufficient_statistics(batch_idx=batch_idx)
                 self.update_node_params(subkey, i=it, adaptive=adaptive, memoized=memoized, **kwargs)
                 if update_roots:
                     self.update_root_node_params(subkey, memoized=memoized, batch_idx=batch_idx, adaptive=adaptive, i=it, **kwargs)
@@ -1479,11 +1479,8 @@ def descend(root, local_grads=None):
             sum_E_log_1_psi += E_log_1_psi
 
             # Go down
-            child_log_probs, child_local_grads = descend(child, local_grads=local_grads)
+            child_log_probs, _ = descend(child, local_grads=local_grads)
             logqs.extend(child_log_probs)
-            if child_local_grads is not None:
-                for ii, grads in enumerate(list(child_local_grads)):
-                    local_grads[ii] += grads
 
         return logqs, local_grads  # batch-sized
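The deleted accumulation suggests that local_grads is a single accumulator mutated in place as the recursion descends, so adding each child's returned gradients on top would count every subtree twice. A toy illustration of the pitfall (not the project's code):

def descend(node, acc):
    # every node adds its contribution into the shared accumulator
    acc[0] += node["grad"]
    for child in node.get("children", []):
        # the recursive call mutates the same accumulator in place,
        # so its return value must not be added back in
        _ = descend(child, acc)
    return acc

acc = [0.0]
descend({"grad": 1.0, "children": [{"grad": 2.0}, {"grad": 3.0}]}, acc)
print(acc[0])  # 6.0; re-adding the returned accumulators would overcount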
1 change: 1 addition & 0 deletions scatrex/ntssb/search.py
@@ -375,6 +375,7 @@ def merge(self, key, moves_per_tssb=1, memoized=True, update_globals=False, n_ep
         self.proposed_tree = deepcopy(self.tree)
 
     def prune_reattach(self, key, proposed_tssb, tssb, n_tries=5, memoized=True, update_names=True, **learn_kwargs):
+        # TODO: instead of sampling the target uniformly, sample proportionally to the similarity of node states. This can be pre-computed and re-used for sampling, so it only adds a small O(n^2) cost at the start of the MCMC.
         changed = False
         accepted = False
         if tssb.n_nodes > 1:
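A hedged sketch of what that TODO could look like (hypothetical helpers, not part of this commit): compute pairwise node-state similarities once up front, then draw reattachment targets in proportion to their similarity to the pruned node.

import jax
import jax.numpy as jnp

def precompute_similarities(states):  # states: n_nodes x d array of node states
    # O(n^2) once: turn squared distances into positive similarity weights
    d2 = jnp.sum((states[:, None, :] - states[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-d2)

def sample_target(key, sims, source_idx):
    # draw a target proportionally to its similarity to the source node,
    # excluding the source itself
    weights = sims[source_idx].at[source_idx].set(0.0)
    return jax.random.choice(key, sims.shape[0], p=weights / weights.sum())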
8 changes: 4 additions & 4 deletions scatrex/ntssb/tssb.py
@@ -809,11 +809,11 @@ def descend(root, local_grads=None):
             sum_E_log_1_psi += E_log_1_psi
 
             # Go down
-            child_log_probs, child_local_param_grads = descend(child, local_grads=local_grads)
+            child_log_probs, _ = descend(child, local_grads=local_grads)
             logqs.extend(child_log_probs)
-            if child_local_param_grads is not None:
-                for i, grads in enumerate(list(child_local_param_grads)):
-                    local_grads[i] += grads
+            # if child_local_param_grads is not None:
+            #     for i, grads in enumerate(list(child_local_param_grads)):
+            #         local_grads[i] += grads
             # local_grads += child_local_param_grads
 
         return logqs, local_grads
4 changes: 2 additions & 2 deletions scatrex/plotting/scatterplot.py
@@ -29,7 +29,7 @@ def sub_descend(sub_root, graph):
             graph.add_edge(parent, child, alpha=prob, ls='--')
             nx.draw_networkx_edges(graph, pos, edgelist=[(parent, child)], edge_color=sub_root['color'], alpha=prob, style='--')
             if edge_labels and prob > 0.01:
-                nx.draw_networkx_edge_labels(graph, pos, font_color=sub_root['color'], edge_labels={(parent, child):f"{prob:.3f}"}, font_size=int(font_size/2), alpha=float(prob))
+                nx.draw_networkx_edge_labels(graph, pos, font_color=sub_root['color'], edge_labels={(parent, child):f"{prob:.3f}"}, font_size=int(font_size/2), alpha=float(prob), bbox=dict(alpha=0))
         for child in sub_root['children']:
             sub_descend(child, graph)
@@ -83,7 +83,7 @@ def plot_tree(tree, G = None, param_key='param', data=None, labels=True, alpha=0
             nx.draw_networkx_edges(G, pos, edgelist=[(parent, node)], edge_color=tree_dict[parent]['color'], alpha=parent_probs.loc[node, parent]*alpha, **edge_options)
             if edge_labels and parent_probs.loc[node, parent] > 0.01:
                 nx.draw_networkx_edge_labels(G, pos, edge_labels={(parent, node):f'{parent_probs.loc[node, parent]:.3f}'}, font_color=tree_dict[parent]['color'],
-                                             font_size=int(font_size/2), alpha=parent_probs.loc[node, parent]*alpha)
+                                             font_size=int(font_size/2), alpha=parent_probs.loc[node, parent]*alpha, bbox=dict(alpha=0))
         else:
             parent = tree_dict[node]['parent']
             G.add_edge(parent, node)
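Both call sites gain bbox=dict(alpha=0), which makes the background box behind each edge label fully transparent; by default Matplotlib draws the labels on an opaque white box that can hide faint edges. A minimal standalone illustration with a toy graph and made-up labels:

import matplotlib.pyplot as plt
import networkx as nx

G = nx.path_graph(3)
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx_edges(G, pos, alpha=0.3)
nx.draw_networkx_edge_labels(G, pos, edge_labels={(0, 1): "0.950", (1, 2): "0.050"},
                             bbox=dict(alpha=0))  # transparent background box
plt.show()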
