Skip to content

Commit

Permalink
improved emission probability calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
mchernys committed Nov 28, 2024
1 parent c153f2e commit 30c7dae
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions src/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct FullHMM <: HMM
n::Int64 # Number of reference sequences
L::Int64 # Length of reference sequences
K::Int64 # Number of mutation rates
S::Matrix{Float64} # Reference sequences
S::Matrix{UInt8} # Reference sequences
switch_probability::Float64 # Probability of switching to a different reference sequence at each site.
end

Expand All @@ -47,23 +47,25 @@ initialstate(hmm::HMM) = 1 / hmm.N
# Transition probability between hidden states.
# Staying in the same state has probability `1 - hmm.switch_probability`;
# otherwise `hmm.switch_probability` is divided by `hmm.N - 1`.
function a(samestate::Bool, hmm::HMM)
    if samestate
        return 1 - hmm.switch_probability
    end
    return hmm.switch_probability / (hmm.N - 1)
end

"""
    get_b(hmm_obs, obs, same_obs_prob, diff_obs_prob)

Return the emission probability of observing symbol `obs` given the reference
symbol `hmm_obs`.

- If `obs` is `0x05` or `0x06`, the result is always `1.0`, whatever `hmm_obs` is.
  NOTE(review): these codes presumably mark gap/unknown symbols -- confirm against
  the sequence encoding used by the callers.
- Otherwise the result is `same_obs_prob` when the symbols are equal and
  `diff_obs_prob` when they differ.
"""
function get_b(
    hmm_obs::UInt8, obs::UInt8, same_obs_prob::Float64, diff_obs_prob::Float64
)
    # Special observation codes carry no information about the reference symbol.
    obs in (0x05, 0x06) && return 1.0
    return hmm_obs == obs ? same_obs_prob : diff_obs_prob
end

# Calculate all observation probabilities for an observation vector.
# Precomputing the whole table uses more memory but takes a bit less time than
# calling a function for each lookup.
function get_bs(hmm::ApproximateHMM, O::Vector{UInt8}, mutation_probabilities::Vector{Float64})
b = Matrix{Float64}(undef, hmm.N, hmm.L)
@inbounds for i in 1:hmm.N # states (reference sequences)
same_obs_prob = 1 - mutation_probabilities[i]
diff_obs_prob = mutation_probabilities[i] / 5
for j in 1:hmm.L # timepoints
hmm_obs = hmm.S[i, j]
if O[j] == 6
prob = 1
elseif hmm_obs == O[j]
prob = same_obs_prob
else
prob = diff_obs_prob
end
b[i, j] = prob
b[i, j] = get_b(hmm.S[i, j], O[j], 1 - mutation_probabilities[i], mutation_probabilities[i] / 3)
end
end
return b
Expand All @@ -74,19 +76,11 @@ function get_bs(hmm::FullHMM, O::Vector{UInt8}, mutation_probabilities::Vector{F
ref2stateindices = stateindicesofref.(1:hmm.n, Ref(hmm))
@inbounds for i in 1:hmm.K # mutation rates
same_obs_prob = 1 - mutation_probabilities[i]
diff_obs_prob = mutation_probabilities[i] / 5
diff_obs_prob = mutation_probabilities[i] / 3
for j in 1:hmm.n # references
ind = ref2stateindices[j][i]
for k in 1:hmm.L # timepoints
hmm_obs = hmm.S[j, k]
obs = O[k]
b[ind, k] = if obs == 6
1.0
elseif hmm_obs == obs
same_obs_prob
else
diff_obs_prob
end
b[ind, k] = get_b(hmm.S[j, k], O[k], same_obs_prob, diff_obs_prob)
end
end
end
Expand Down

0 comments on commit 30c7dae

Please sign in to comment.