diff --git a/igfold/model/IgFold.py b/igfold/model/IgFold.py index c6c7c3a..8cde62b 100644 --- a/igfold/model/IgFold.py +++ b/igfold/model/IgFold.py @@ -314,8 +314,8 @@ def forward( cum_seq_lens = np.cumsum([0] + seq_lens) for sl_i, sl in enumerate(seq_lens): align_mask_ = align_mask.clone() - align_mask_[:, :cum_seq_lens[sl_i]] = False - align_mask_[:, cum_seq_lens[sl_i + 1]:] = False + align_mask_[:, :4*cum_seq_lens[sl_i]] = False + align_mask_[:, 4*cum_seq_lens[sl_i + 1]:] = False res_batch_mask_ = res_batch_mask.clone() res_batch_mask_[:, :cum_seq_lens[sl_i]] = False res_batch_mask_[:, cum_seq_lens[sl_i + 1]:] = False @@ -485,4 +485,4 @@ def gradient_refine( output.coords = coords output.prmsd = prmsd - return output \ No newline at end of file + return output