Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the path handling more platform agnostic. #88

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions helper_scripts/parse_multiple_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ def main(args):

import numpy as np
import os, time, gzip, json
import glob
import glob
from os import path

folder_with_pdbs_path = args.input_path
save_path = args.output_path
Expand Down Expand Up @@ -138,8 +139,7 @@ def parse_PDB_biounits(x, atoms=['N','CA','C'], chain=None):
coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
my_dict['coords_chain_'+letter]=coords_dict_chain
s += 1
fi = biounit.rfind("/")
my_dict['name']=biounit[(fi+1):-4]
my_dict['name']=path.basename(biounit)[:-4]
my_dict['num_of_chains'] = s
my_dict['seq'] = concat_seq
if s < len(chain_alphabet):
Expand Down
33 changes: 15 additions & 18 deletions protein_mpnn_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,22 @@ def main(args):


if args.path_to_model_weights:
model_folder_path = args.path_to_model_weights
if model_folder_path[-1] != '/':
model_folder_path = model_folder_path + '/'
# This adds the trailing slash if missing
model_folder_path = os.path.join(args.path_to_model_weights, '')
else:
file_path = os.path.realpath(__file__)
k = file_path.rfind("/")
file_path = os.path.dirname(os.path.realpath(__file__))
if args.ca_only:
print("Using CA-ProteinMPNN!")
model_folder_path = file_path[:k] + '/ca_model_weights/'
model_folder_path = os.path.join(file_path, 'ca_model_weights', '')
if args.use_soluble_model:
print("WARNING: CA-SolubleMPNN is not available yet")
sys.exit()
else:
if args.use_soluble_model:
print("Using ProteinMPNN trained on soluble proteins only!")
model_folder_path = file_path[:k] + '/soluble_model_weights/'
model_folder_path = os.path.join(file_path, 'soluble_model_weights', '')
else:
model_folder_path = file_path[:k] + '/vanilla_model_weights/'
model_folder_path = os.path.join(file_path, 'vanilla_model_weights', '')

checkpoint_path = model_folder_path + f'{args.model_name}.pt'
folder_for_outputs = args.out_folder
Expand Down Expand Up @@ -188,9 +186,8 @@ def main(args):
print(f'Training noise level: {noise_level_print}A')

# Build paths for experiment
base_folder = folder_for_outputs
if base_folder[-1] != '/':
base_folder = base_folder + '/'
# Add trailing slash if missing
base_folder = os.path.join(folder_for_outputs, "")
if not os.path.exists(base_folder):
os.makedirs(base_folder)

Expand Down Expand Up @@ -243,9 +240,9 @@ def main(args):
loop_c = len(fasta_seqs)
for fc in range(1+loop_c):
if fc == 0:
structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + f'_pdb'
structure_sequence_score_file = os.path.join(base_folder, 'score_only', batch_clones[0]['name'] + f'_pdb')
else:
structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + f'_fasta_{fc}'
structure_sequence_score_file = os.path.join(base_folder, 'score_only', batch_clones[0]['name'] + f'_fasta_{fc}')
native_score_list = []
global_native_score_list = []
if fc > 0:
Expand Down Expand Up @@ -285,7 +282,7 @@ def main(args):
elif args.conditional_probs_only:
if print_all:
print(f'Calculating conditional probabilities for {name_}')
conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
conditional_probs_only_file = os.path.join(base_folder, 'conditional_probs_only', batch_clones[0]['name'])
log_conditional_probs_list = []
for j in range(NUM_BATCHES):
randn_1 = torch.randn(chain_M.shape, device=X.device)
Expand All @@ -297,7 +294,7 @@ def main(args):
elif args.unconditional_probs_only:
if print_all:
print(f'Calculating sequence unconditional probabilities for {name_}')
unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name']
unconditional_probs_only_file = os.path.join(base_folder, 'unconditional_probs_only', batch_clones[0]['name'])
log_unconditional_probs_list = []
for j in range(NUM_BATCHES):
log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
Expand All @@ -314,9 +311,9 @@ def main(args):
global_scores = _scores(S, log_probs, mask) #score the whole structure-sequence
global_native_score = global_scores.cpu().data.numpy()
# Generate some sequences
ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
ali_file = os.path.join(base_folder, 'seqs', batch_clones[0]['name'] + '.fa')
score_file = os.path.join(base_folder, 'scores', batch_clones[0]['name'] + '.npz')
probs_file = os.path.join(base_folder, 'probs', batch_clones[0]['name'] + '.npz')
if print_all:
print(f'Generating sequences for: {name_}')
t0 = time.time()
Expand Down
4 changes: 2 additions & 2 deletions protein_mpnn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json, time, os, sys, glob
import shutil
import numpy as np
from os import path
import torch
from torch import optim
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -177,8 +178,7 @@ def parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False):
coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
my_dict['coords_chain_'+letter]=coords_dict_chain
s += 1
fi = biounit.rfind("/")
my_dict['name']=biounit[(fi+1):-4]
my_dict['name']=path.basename(biounit)[:-4]
my_dict['num_of_chains'] = s
my_dict['seq'] = concat_seq
if s <= len(chain_alphabet):
Expand Down