-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Branch length estimation #50
base: refactor
Are you sure you want to change the base?
Conversation
…ngths to conform to API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice PR! I left you several comments that I hope are helpful to you. All tests pass, and I think the code is very well written.
Let me know when you've addressed these comments and I'll be happy to take another pass.
import copy | ||
from typing import List, Tuple | ||
|
||
import cvxpy as cp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need to add this to the setup.py, my tests break because this wasn't installed during make install
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed!
QQ: I noticed many of the requirements have a minimum version requirement, e.g. numpy > 1.17
, is there a reason for that? I was thinking to just add cvxpy
without a version requirement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having version requirements is helpful for a couple of reasons -
(1) Sometimes, APIs are different between versions (e.g., networkx changes the way it calls some functionality in versions <2 and >2, I believe)
(2) Sometimes, earlier versions of packages throw warnings (this is especially relevant if, for example, a particular package relies on another package that has updated and is deprecating some feature). Requiring a minimal version (or maximal version, for that matter) may help you get rid of warnings being thrown around.
In general, it's likely that having an outdated version of a particular package will be fine and the worst thing that will happen is that a warning might be thrown. However, I think it's good for software reproducibility so that when a new user comes along they can create a very similar environment to the developer. Hope this makes sense!
log_likelihood: The log-likelihood of the training data under the | ||
estimated model. | ||
log_loss: The log-loss of the training data under the estimated model. | ||
This is the log likhelihood plus the regularization terms. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo -> "likelihood"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
|
||
def estimate_branch_lengths(self, tree: Tree) -> None: | ||
r""" | ||
See base class. The only caveat is that this method raises if it fails |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
incomplete sentence: "this method raises [an error?]..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've always used "raises" for short to mean that it "raises an error", but it might not be proper english.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah - you might have more software engineering lingo than I do! If it's not too much of a problem, let's be more verbose here in the docstring.
Can you also add a Raises
section to the docstring (ala Args
and Returns
)?
def dfs(node: int, tree: Tree): | ||
node_state = tree.get_state(node) | ||
for child in tree.children(node): | ||
# Compute the state of the child | ||
child_state = "" | ||
edge_length = tree.get_age(node) - tree.get_age(child) | ||
# print(f"{node} -> {child}, length {edge_length}") | ||
assert edge_length >= 0 | ||
for i in range(num_characters): | ||
# See what happens to character i | ||
if node_state[i] != "0": | ||
# The character has already mutated; there in nothing | ||
# to do | ||
child_state += node_state[i] | ||
continue | ||
else: | ||
# Determine if the character will mutate. | ||
mutates = ( | ||
np.random.exponential(1.0 / mutation_rate) | ||
< edge_length | ||
) | ||
if mutates: | ||
child_state += "1" | ||
else: | ||
child_state += "0" | ||
tree.set_state(child, child_state) | ||
dfs(child, tree) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A general comment - a depth-first traversal function already exists in networkx.
But, if you want to do this a special way, I think this belongs as a method for the Tree class, not as a nested function here. Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I should look into the function from networkx that you mention. What exactly does it do? Can you please point me to it?
We need to do two things in overlay_lineage_tracing_data
: (1) iterate over all edges down the tree (the tree traversal) (2) do some custom computation on each edge as we visit it. If we were to any of this behavior to the tree class I guess is would be (1)? IMO (2) should rest in the LineageTracingSimulator
since it is something that varies.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that's a good point. The action items at each internal node in the DF traversal will change so it's not necessarily a general method for the tree class.
However, having a DFS method in the class that just returns nodes in post-order might be worthwhile and could cut down on some duplicated code in the other classes that operate on the CassiopeiaTree
object. I'll add this a todo in the class.
np.random.exponential(1.0 / mutation_rate) | ||
< edge_length |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could be misremembering, but I believe np.random.exponential
already inverts the scale parameter (i.e. np.random.exponential(b)
draws from 1/b * exp(x/b)
). [You can read the docs here: https://numpy.org/doc/stable/reference/random/generated/numpy.random.exponential.html -- indeed, it looks like the function inverts the scale parameter for you]
If this is the case, do you want to pass instead use np.random.exponential(mutation_rate)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.random.exponential
uses the scale parameterization, so we should invert the rate as in np.random.exponential(1.0 / mutation_rate)
, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry - I'm a bit confused. If the density function you want is lambda * exp(x*lambda)
(where lambda
is the mutation rate) then you are good to go. O.w. you should just pass in lambda
to the exponential density function.
cassiopeia/tools/tree.py
Outdated
def reconstruct_ancestral_states(self): | ||
r""" | ||
Reconstructs ancestral states with maximum parsimony. | ||
""" | ||
root = self.root() | ||
|
||
def dfs(v: int) -> None: | ||
children = self.children(v) | ||
n_children = len(children) | ||
if n_children == 0: | ||
return | ||
for child in children: | ||
dfs(child) | ||
children_states = [self.get_state(child) for child in children] | ||
n_characters = len(children_states[0]) | ||
state = "" | ||
for character_id in range(n_characters): | ||
states_for_this_character = set( | ||
[ | ||
children_states[i][character_id] | ||
for i in range(n_children) | ||
] | ||
) | ||
if len(states_for_this_character) == 1: | ||
state += states_for_this_character.pop() | ||
else: | ||
state += "0" | ||
self.set_state(v, state) | ||
if v == root: | ||
# Reset state to all zeros! | ||
self.set_state(v, "0" * n_characters) | ||
|
||
dfs(root) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Richard wrote a function that does this in cassiopeia.solver.solver_utilities
. You might just want to take that because it takes into account missing data, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are right! I just found it easier to code this up myself then to figure out how to get the right inputs to annotate_ancestral_characters
, because e.g. the Tree class I am using holds character vectors as a str
, not as a List[int]
:( Probably a bad design choice of mine? This is one thing I wanted to discuss! I left the character vector as a str
because it met all my needs, and changing it to List[int]
meant changing the code prematurely. (We definitely want to have only one implementation of parsimony reconstruction when we merge, the question is what's the API/where it lives.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense! This is definitely something out of scope of your branch length estimator.
But, yes I feel that using lists to represent the character states of a sample is more desirable. One major reason for this is the separation of characters in a string is not always intuitive (e.g., if you have double-digit or triple-digit state representations). One can get around this by setting a delimiter (e.g., the |
symbol) but if you always end up splitting the string to change something and then rejoining, why not just ditch the string data structure altogether for this?
… cuts_p <= x <= cuts_v
I migrated my branch length estimation code here, under
cassiopeia/tools
. This PR should be see as a big draft to get discussions started around APIs. For example, note that my APIs currently depend on aTree
class, rather than onnetworkx.DiGraph
.Note also that my code is totally self-contained right now, segregated from any utilities that Cassiopeia might already provide, such as ancestral state reconstruction via maximum parsimony (I have a method for this in the
Tree
class which I just implemented de-novo).