Problem about update_memory #25

Open
Void-JackLee opened this issue Sep 7, 2023 · 1 comment

Void-JackLee commented Sep 7, 2023

Hi @emalgorithm, I ran into a problem while reading your code.

When `memory_update_at_start=True`, `msg_agg` and `msg_func` are computed twice: once before `compute_embedding` and once after it. Before `compute_embedding`, `get_updated_memory` computes the updated memory of all nodes. After `compute_embedding`, `update_memory` computes it again, but only for the positive nodes.
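To make the double computation concrete, here is a toy sketch. `ToyMemory` and its mean-based update rule are hypothetical stand-ins for TGN's memory, message function, aggregator and memory updater; the only property it tries to model is that `get_updated_memory` returns an updated *copy* for every node, while `update_memory` writes the same result back for the given nodes only.

```python
from statistics import mean


class ToyMemory:
    """Hypothetical stand-in for TGN's memory module: one float per node plus
    a per-node list of raw messages that have not been applied yet."""

    def __init__(self, num_nodes):
        self.memory = [0.0] * num_nodes
        self.messages = {node: [] for node in range(num_nodes)}

    def _apply_messages(self, node):
        # Hypothetical update rule: add the mean of the pending messages
        # (the real model uses its message function, aggregator and a
        # learned memory updater instead).
        pending = self.messages[node]
        return self.memory[node] + mean(pending) if pending else self.memory[node]

    def get_updated_memory(self):
        # First computation: updated memory for ALL nodes, returned as a copy.
        # self.memory itself is left unchanged.
        return [self._apply_messages(node) for node in range(len(self.memory))]

    def update_memory(self, nodes):
        # Second computation: the same update, written back for `nodes` only.
        for node in nodes:
            self.memory[node] = self._apply_messages(node)

    def clear_messages(self, nodes):
        for node in nodes:
            self.messages[node] = []


mem = ToyMemory(3)
mem.messages[0] = [1.0, 3.0]     # pending messages for node 0
mem.messages[2] = [4.0]          # pending message for node 2
print(mem.get_updated_memory())  # [2.0, 0.0, 4.0], used for the embeddings
print(mem.memory)                # [0.0, 0.0, 0.0], nothing persisted yet
mem.update_memory([0])           # node 0 is the only positive in this batch
mem.clear_messages([0])
print(mem.memory)                # [2.0, 0.0, 0.0]
```

In this toy, node 2's message is used for the embeddings but stays in the store, because node 2 was not a positive.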

```python
if self.memory_update_at_start:
  # Persist the updates to the memory only for sources and destinations (since now we have
  # new messages for them)
  self.update_memory(positives, self.memory.messages)

  assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
    "Something wrong in how the memory was updated"

  # Remove messages for the positives since we have already updated the memory using them
  self.memory.clear_messages(positives)

unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                              source_node_embedding,
                                                              destination_nodes,
                                                              destination_node_embedding,
                                                              edge_times, edge_idxs)
unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                        destination_node_embedding,
                                                                        source_nodes,
                                                                        source_node_embedding,
                                                                        edge_times, edge_idxs)
if self.memory_update_at_start:
  self.memory.store_raw_messages(unique_sources, source_id_to_messages)
  self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
else:
  self.update_memory(unique_sources, source_id_to_messages)
  self.update_memory(unique_destinations, destination_id_to_messages)
```

The comment here says "Persist the updates to the memory only for sources and destinations (since now we have new messages for them)", but the messages from this batch are only stored after the memory update, so `update_memory` is actually updating the memory with the messages from the last batch. In other words, `update_memory(positives, self.memory.messages)` updates this batch's positive nodes using messages from the last batch. I don't understand why the code does this; maybe it's a bug?

I think the code should either update the memory of all nodes here (or record the last batch's positive nodes), or update the memory directly in `get_updated_memory` (i.e. replace that call with `update_memory`).
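A toy simulation of the behaviour in question, under the assumption (suggested by the `clear_messages(positives)` and `store_raw_messages` calls in the quoted code) that each node's raw messages stay in the store until that node is itself updated. Everything here (`run_batch`, `updated_view`, the additive "updater") is hypothetical and only mimics the order of calls; unlike the real `get_updated_memory`, `updated_view` only covers the batch's nodes. Under that assumption, the last batch's messages for nodes outside the current batch look deferred rather than dropped, though this does not prove anything about the real implementation.

```python
from collections import defaultdict

# Hypothetical toy model, not the TGN code: memory is one number per node and
# "applying" messages just adds them up.
memory = defaultdict(float)   # persisted memory
pending = defaultdict(list)   # node -> raw messages not yet applied


def updated_view(nodes):
    """Plays the role of get_updated_memory: apply pending messages, don't persist."""
    return {n: memory[n] + sum(pending[n]) for n in nodes}


def run_batch(edges):
    """One step with memory_update_at_start=True; returns the memory used for embeddings."""
    positives = sorted({n for edge in edges for n in edge})
    view = updated_view(positives)
    # Like update_memory(positives, ...) + clear_messages(positives):
    # only this batch's nodes are persisted; other nodes keep their messages.
    for n in positives:
        memory[n] += sum(pending[n])
        pending[n].clear()
    # Like store_raw_messages: this batch's events wait for a later batch.
    for src, dst in edges:
        pending[src].append(1.0)
        pending[dst].append(1.0)
    return view


print(run_batch([(0, 1)]))    # {0: 0.0, 1: 0.0}
print(run_batch([(0, 2)]))    # {0: 1.0, 2: 0.0}; node 1 is not a positive here
print(memory[1], pending[1])  # 0.0 [1.0]  node 1's message is deferred, not lost
print(run_batch([(1, 2)]))    # {1: 1.0, 2: 1.0}; node 1's batch-1 message is applied now
```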

shadow150519 commented Mar 8, 2024

According to the paper, the memory has to be updated before the embeddings are computed so that the memory component receives gradients. However, the memory cannot be updated with the n-th batch, because that would cause information leakage (the model would already have seen the interactions it is asked to predict); the (n-1)-th batch is used instead. Only with `memory_update_at_start=True` does the memory component get gradients and have its parameters updated.

> So here comes a problem that update_memory(positives, self.memory.messages) was updating positive nodes in this batch, and updated messages was from last batch. I don't understand why the code is doing this, maybe it's a bug?

The proper order is:

  1. `get_updated_memory` uses the (n-1)-th batch's events but does not actually update the memory; it only returns an updated copy. This extra computation lets the memory take part in the embedding computation, so that the memory component can receive gradients.
  2. Compute the embeddings of the positive nodes.
  3. Use the (n-1)-th batch's messages to update the memory.
  4. Convert the n-th batch's events into raw messages.
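The four steps can be traced with a toy training loop. This is a hypothetical sketch, not the TGN training code: `train_step` and the set-of-batch-ids "memory" are made up so that the ordering, and the absence of information leakage, can be checked directly. Gradients are not modelled at all; step 2 is only a placeholder.

```python
from collections import defaultdict

# Hypothetical toy: instead of vectors, each node's memory is the set of batch
# ids whose messages have been applied to it, so we can see what it "knows".
memory = defaultdict(set)    # persisted memory
pending = defaultdict(set)   # raw messages waiting to be applied (by batch id)


def train_step(batch_id, edges):
    positives = sorted({n for edge in edges for n in edge})
    # 1. get_updated_memory: apply the pending (batch n-1 or older) messages
    #    to a copy; in TGN this is the path through which the memory gets gradients.
    view = {n: memory[n] | pending[n] for n in positives}
    # 2. Compute the positive nodes' embeddings/loss from `view` (omitted here).
    # 3. update_memory: persist the same messages for the positives, then clear them.
    for n in positives:
        memory[n] |= pending[n]
        pending[n].clear()
        assert memory[n] == view[n]  # mirrors the torch.allclose check in the quoted code
    # 4. Convert batch n's events into raw messages for the next step.
    for src, dst in edges:
        pending[src].add(batch_id)
        pending[dst].add(batch_id)
    return view


batches = [[(0, 1)], [(0, 1)], [(1, 2)]]
views = [train_step(i, edges) for i, edges in enumerate(batches)]
for i, view in enumerate(views):
    # The memory used for batch i never contains batch i itself: no leakage.
    assert all(i not in seen for seen in view.values())
print(views[2])  # {1: {0, 1}, 2: set()}
```

Swapping steps 1 and 4 in this loop would put `batch_id` into `view` and make the leakage assertion fail.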
