
Verify / Add Tests of Log-Probability Calculation #149

Merged
merged 49 commits into main from gordon/log_prob_red on Aug 29, 2023

Conversation

gordonkoehn
Collaborator

@gordonkoehn gordonkoehn commented Aug 14, 2023

Verifying that the log probability is calculated correctly.

The calculation is split into three parts:

  • mutation likelihood
    $$P(D_{ij} | A(T)_{ik})$$
  • log mutation likelihood
    $$\log P(D_{ij} | A(T)_{ik})$$
  • log-probability
    $$\sum_{j=1}^m \log \sum_{k=1}^{n+1} \exp \sum_{i=1}^{n} \log P(D_{ij} | A(T)_{ik})$$
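Spelled out in code, the quantity above can be computed with a stable log-sum-exp. A minimal NumPy sketch (the array name `log_p`, its axis order, and the toy sizes are illustrative, not the package's actual layout):

```python
import numpy as np

# log_p[i, j, k] = log P(D_ij | A(T)_ik); axes: mutation i, cell j, attachment k.
rng = np.random.default_rng(0)
n, m, n_attach = 3, 2, 4  # mutations, cells, attachment points (n + 1)
log_p = np.log(rng.uniform(0.1, 1.0, size=(n, m, n_attach)))

inner = log_p.sum(axis=0)                 # sum over mutations i, shape (m, n_attach)
shift = inner.max(axis=1, keepdims=True)  # stabilise the exp (log-sum-exp trick)
# sum_j log sum_k exp(...):
log_prob = (shift[:, 0] + np.log(np.exp(inner - shift).sum(axis=1))).sum()
```

Being explicit about which axis each reduction runs over is exactly what the tests below try to verify.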

@gordonkoehn
Collaborator Author

I went through the mutation likelihood calculations - there are examples and exact by-hand calculations for unit tests. This is likely to be correct.

@gordonkoehn
Collaborator Author

The log mutation likelihood is just the element-wise logarithm of the above - nothing error-prone here.

@gordonkoehn
Collaborator Author

The calculation of the log-probability uses jnp.einsum - this may be a place where we sum over the wrong dimension.
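One cheap way to catch that kind of bug: on square axes, a wrong einsum subscript produces the right shape but wrong values, so a loop-based reference is needed to expose it. A toy sketch (array names and sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
# Square axes: a wrong-axis sum gives the right *shape* but wrong values.
A = rng.normal(size=(3, 3, 3))  # axes: (mutation i, cell j, attachment k)

intended = np.einsum("ijk->jk", A)  # sum over mutations i
mistaken = np.einsum("ijk->ik", A)  # oops: sums over cells j -- same shape!

# A loop-based reference exposes the bug:
reference = sum(A[i] for i in range(3))
```

Shape checks alone pass in both cases, which is why comparing values against a slow reference matters.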

@gordonkoehn
Collaborator Author

To do: manually calculate the log probability with 2 cells and 3 mutations on paper.

@gordonkoehn
Collaborator Author

gordonkoehn commented Aug 15, 2023

I manually calculated the log probability for an error-free mutation matrix, with 2 cells and 3 mutations.

The result agrees with the one by our current implementation up to at least 10 decimals.

See test: test_logprobability_fn_exact_m2n3

ManualLogProb.pdf

This should be a sufficient test that we sum over the correct dimensions.

I am confident that the current implementation does what I do on paper here.

My next biggest doubt is whether I understand SCITE correctly.
I remember struggling quite a bit to figure this out, largely due to the dimensions.

@pawel-czyz Could you perhaps go through these calculations to verify my understanding against yours?

@gordonkoehn gordonkoehn marked this pull request as ready for review August 15, 2023 13:41
@pawel-czyz
Member

pawel-czyz commented Aug 15, 2023

Impressive work! I looked at your derivation and I didn't spot a mistake. However, I'm far from being an oracle with perfect recall 😉

Let me tell a story from the good ol' days, when I was still young and interested in competitive programming.

We had to construct fast algorithms in C++. When I say fast, I mean really fast. Optimizations were endless and we usually had time pressure of several hours, so it was easy to introduce a programming bug.

We always checked them on simple examples, then hoped they would work on the tests which the organisers used to score our programs (not available to us until the competition had finished), and moved on to the next problem.

I hate to say it, but annoyingly often they scored $\approx 0$ points. It turned out that the best programmers actually had a different habit.

They first generated a lot of artificial data sets matching the problem description (some of them were supposed to be tricky, e.g., to test for off-by-one errors) and implemented the simplest (and usually unacceptably slow!) baseline possible. Then they had the baseline running while they were coding the fast (and complex) solution. It often turned out that (a) there existed some rare ("pathological") cases not covered by the simple, calculable-by-hand tests, and (b) the majority of the tests used by the organizers to evaluate the algorithm consisted of these tricky cases (i.e., although the tricky tests make up a small fraction of all tests possible to generate, the organizers made sure that they were 70-90% of the tests from which one's score is calculated).

In our case we may face a similar situation: although the fast algorithm works on most trees and cells (and on all cases easy to calculate by hand), perhaps there exist some pathological examples on which it gives a wrong answer.
Our population of interest ("the tests that matter for the final score in the competition") consists of exactly these rare pathological examples (trees that are different, but seem to give the same likelihood), so another implementation, using techniques as simple and robust (and possibly slow) as possible, would be good to compare against.
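The "simple, slow baseline" strategy can be sketched for the log-likelihood at hand with pure Python loops. All names here are hypothetical, not the package's actual API:

```python
import math

def loglik_slow(data, attach_vecs, log_p_obs):
    """Pure-loop reference log-likelihood.

    data[j][i]: observed state of mutation i in cell j;
    attach_vecs[k][i]: true state of mutation i implied by attachment k;
    log_p_obs(d, a): log P(observe d | truth a), e.g. built from fpr/fnr.
    """
    total = 0.0
    for cell in data:
        per_attach = [
            sum(log_p_obs(d, a) for d, a in zip(cell, vec))
            for vec in attach_vecs
        ]
        mx = max(per_attach)  # log-sum-exp over attachments, done stably
        total += mx + math.log(sum(math.exp(v - mx) for v in per_attach))
    return total
```

Being this literal makes the function easy to verify by hand, which is the point: it serves as the oracle that the fast, vectorized version is checked against.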

@pawel-czyz
Member

Another validation technique came to my mind: take the data set with 5 mutations and 200+ cells, small noise, and two trees differing in 0-2 edges (or another pair of trees with minor differences) giving the same total loglikelihood.

What are the loglikelihoods of individual cells? Perhaps this can pinpoint if there's an issue...

@pawel-czyz
Member

pawel-czyz commented Aug 21, 2023

Surprisingly, both log-prob functions pass this test. I ran it for quite a few small trees.

I like this test, it's a good sanity check!

The scales are indeed worrying. There are two things which may help resolve the issue:

  • An implementation in terms of very, very small functions, so that everything is well-tested at each stage: i.e., `(true tree, attachment index) -> true mutations vector of one cell` as one function, `(true mutations vector, rates, noisy mutations vector) -> loglikelihood` as another, and the marginalisation over attachments in yet another function.
  • Perhaps it'd be good to bring in the normalization terms of the loglikelihood (that is, make the probability of cell attachment $1/n_\text{nodes}$, rather than $1$). Then $\sum_\text{data} P(\text{data} \mid \text{tree}, \text{rates}) = 1$. For, e.g., 1 cell and 10 mutations there are $2^{10}$ options for the data vector, so it should be tractable to enumerate all of them and check that the sum is right.
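The enumeration check in the second bullet can be sketched for a toy tree as follows. The attachment vectors and the noise rates below are made-up numbers; with the uniform $1/n_\text{nodes}$ attachment prior, the probabilities over all data vectors must sum to 1:

```python
import itertools
import math

# Toy setup: 3 mutations, 4 attachment vectors (one per node of a chain tree);
# fpr/fnr are hypothetical false-positive/false-negative rates.
fpr, fnr = 0.01, 0.1
attach_vecs = [(0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1)]

def p_obs(d, a):
    """P(observe d | true state a) under the fpr/fnr noise model."""
    if a == 0:
        return fpr if d == 1 else 1 - fpr
    return 1 - fnr if d == 1 else fnr

total = 0.0
for data in itertools.product([0, 1], repeat=3):
    # P(data | tree, rates) with uniform attachment prior 1/len(attach_vecs)
    p = sum(
        math.prod(p_obs(d, a) for d, a in zip(data, vec))
        for vec in attach_vecs
    ) / len(attach_vecs)
    total += p
```

If an implementation sums over the wrong axis or drops the attachment prior, this total drifts away from 1, which makes the check a useful smoke test.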

@gordonkoehn
Collaborator Author

@pawel-czyz

This investigation was a rollercoaster. Here is the summary. I implemented two additional versions of the log_probability_fn - one partially and one completely orthogonal to it; the latter uses only loops and vectors.

I have validated it all by hand for small trees and am 99% certain that the SLOW orthogonal implementation does what I understand SCITE to do. Put it this way: I am less certain that I understand SCITE correctly.

Besides the orthogonal test between all three implementations, I have implemented tests for:

  1. Do we score a true tree higher, given it is identifiable (many cells, no noise), than a tree one MCMC move away?
  2. Do we score a tree higher under the perfect mutation matrix compared to the noisy one?

For small trees and cell numbers, the world is beautiful - all tests check out.

I have only run the orthogonal test for many cells so far, and the picture is confusing to me. I check their outputs to a $10^{-6}$ tolerance. Is that too rigorous for JAX / NumPy / base Python? I have no good foundation for this tolerance value. What shall we allow?

For few cells and mutations, all implementations agree. For 20 mutations and 200 cells, they don't, though their values aren't far off.

I am uncertain if this is just the normal trouble with floating-point computation or if this may be the reason for the unidentifiability problem we observed for 200 cells and 5 mutations.

I'll investigate.
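One possible source of the small disagreements at 200 cells x 20 mutations is plain floating-point accumulation: JAX defaults to float32, while NumPy reference code typically runs in float64. A quick sketch of the effect (made-up magnitudes; a guess at the cause, not a confirmed diagnosis):

```python
import numpy as np

rng = np.random.default_rng(2)
# 200 cells x 20 mutations worth of per-entry log-likelihood terms.
terms = rng.uniform(-5.0, 0.0, size=(200, 20))

sum64 = terms.sum(dtype=np.float64)
sum32 = float(terms.astype(np.float32).sum(dtype=np.float32))

abs_err = abs(sum32 - sum64)
rel_err = abs_err / abs(sum64)
# float32 keeps ~7 significant digits, so for a sum of magnitude ~1e4 an
# absolute deviation well above 1e-6 is plausible; a relative tolerance
# (e.g. np.allclose's rtol) is a safer comparison than a fixed absolute one.
```

This suggests comparing the implementations with a relative rather than an absolute tolerance when cell and mutation counts grow.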

@gordonkoehn gordonkoehn requested review from pawel-czyz and removed request for pawel-czyz August 21, 2023 20:52
@pawel-czyz
Member

pawel-czyz commented Aug 22, 2023

This investigation was a rollercoaster. Here is the summary. I implemented two additional versions of the log_probability_fn - one partially and one completely orthogonal to it; the latter uses only loops and vectors.

Cool work!

I have validated it all by hand for small trees and am 99% certain that the SLOW orthogonal implementation does what I understand SCITE to do. Put it this way: I am less certain that I understand SCITE correctly.

Let's discuss it with a whiteboard, then 🙂

Besides the orthogonal test between all three implementations, I have implemented tests for:

Do we score a true tree higher, given it is identifiable (many cells, no noise), than a tree one MCMC move away?
Do we score a tree higher under the perfect mutation matrix compared to the noisy one?
For small trees and cell numbers, the world is beautiful - all tests check out.

Great!

I have only run the orthogonal test for many cells so far, and the picture is confusing to me. I check their outputs to a $10^{-6}$ tolerance. Is that too rigorous for JAX / NumPy / base Python? I have no good foundation for this tolerance value.
What shall we allow?

Log-probability up to $10^{-6}$ is definitely good enough for me! (For the probability itself, not really - probabilities of most trees should be close to 0.)

For few cells and mutations, all implementations agree. For 20 mutations and 200 cells, they don't, though their values aren't far off.

I am uncertain if this is just the normal trouble with floating-point computation or if this may be the reason for the unidentifiability problem we observed for 200 cells and 5 mutations.
I'll investigate.

I've added some comments on how to make it more numerically stable 🙂

Review threads:
  • src/pyggdrasil/tree_inference/_logprob.py
  • tests/tree_inference/test_logprob.py
  • tests/tree_inference/test_logprob_validator.py
@gordonkoehn
Collaborator Author

gordonkoehn commented Aug 23, 2023

  • Add a test to check for a different log probability after swapping nodes?

Collaborator Author

@gordonkoehn gordonkoehn left a comment


Continuing from the log_prob tests, I have implemented the tree reordering to fix the issue with the calculation of the fast log_prob.

Proposed trees will now be reordered to have 0, ..., N labels, compatible with the fast calculation of the log_probability_fn. To that end, I have implemented a fast _reorder_tree function based on a permutation matrix operation.
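For illustration, relabelling an adjacency matrix via a permutation matrix can look like this (a generic sketch; the function name and signature are hypothetical, not the actual `_reorder_tree`):

```python
import numpy as np

def reorder_adjacency(adj, perm):
    """Relabel nodes of adjacency matrix `adj` so that old node perm[i]
    becomes new node i, via P @ adj @ P.T with a permutation matrix P."""
    n = len(perm)
    P = np.zeros((n, n), dtype=adj.dtype)
    P[np.arange(n), perm] = 1  # row i selects old node perm[i]
    return P @ adj @ P.T       # entry (i, j) becomes adj[perm[i], perm[j]]
```

Expressing the relabelling as two matrix products keeps it vectorized, which is what makes it fast enough to run on every proposed tree.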

To further clarify the usage of ordered trees, I have added a subclass of Tree, called OrderedTree, that asserts the labels.
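As a sketch of the idea (the real `Tree`/`OrderedTree` classes in the package will differ in detail; this is only the shape of the assertion):

```python
import numpy as np

class Tree:
    def __init__(self, adjacency, labels):
        self.adjacency = np.asarray(adjacency)
        self.labels = np.asarray(labels)

class OrderedTree(Tree):
    """Tree whose labels are exactly 0, ..., N in order -- the layout
    the fast log-probability calculation assumes."""
    def __init__(self, adjacency, labels):
        labels = np.asarray(labels)
        if not np.array_equal(labels, np.arange(labels.size)):
            raise ValueError("OrderedTree requires labels 0..N in order")
        super().__init__(adjacency, labels)
```

Failing loudly at construction time means a mislabelled tree can never silently reach the fast code path.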

Please let me know what you think :)

Next I'll get onto the MARK04 experiment: performance of HUNTRESS with many trees.

Review threads:
  • src/pyggdrasil/tree_inference/_logprob.py
  • src/pyggdrasil/tree_inference/_mcmc.py
  • src/pyggdrasil/tree_inference/_tree.py
  • tests/tree_inference/test_logprob.py
  • tests/tree_inference/test_tree.py
Member

@pawel-czyz pawel-czyz left a comment


Well done! 🥳 Wonderful investigation and fix 🙂

I wonder: how many posterior modes do there seem to be now? Would it be possible to re-run the experiments?

Review threads:
  • src/pyggdrasil/tree_inference/_mcmc.py
  • src/pyggdrasil/tree_inference/_tree.py
  • tests/tree_inference/test_logprob.py
  • tests/tree_inference/test_tree.py
@gordonkoehn gordonkoehn merged commit d95ee28 into main Aug 29, 2023
1 check passed
@gordonkoehn gordonkoehn deleted the gordon/log_prob_red branch August 29, 2023 08:29