Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protein_mpnn_utils.py cleanup #20

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 42 additions & 51 deletions protein_mpnn_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from __future__ import print_function
import json, time, os, sys, glob
import shutil
import json, time
from string import ascii_lowercase, ascii_uppercase

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset

import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import itertools

#A number of functions/classes are adopted from: https://github.com/jingraham/neurips19-graph-protein-design
Expand All @@ -30,12 +26,14 @@ def _S_to_seq(S, mask):
seq = ''.join([alphabet[c] for c, m in zip(S.tolist(), mask.tolist()) if m > 0])
return seq

def parse_PDB_biounits(x, atoms=['N','CA','C'], chain=None):
def parse_PDB_biounits(x, atoms: list[str] | None, chain: str | None = None):
'''
input: x = PDB filename
atoms = atoms to extract (optional)
output: (length, atoms, coords=(x,y,z)), sequence
'''
if atoms is None:
atoms = ["N", "CA", "C"]

alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
states = len(alpha_1)
Expand Down Expand Up @@ -115,54 +113,47 @@ def N_to_AA(x):
except TypeError:
return 'no_chain', 'no_chain'

def parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False):
c=0
def parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False) -> list[dict]:
pdb_dict_list = []
init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
extra_alphabet = [str(item) for item in list(np.arange(300))]
chain_alphabet = init_alphabet + extra_alphabet


if input_chain_list:
chain_alphabet = input_chain_list

chain_alphabet = input_chain_list
else:
chain_alphabet = list(ascii_uppercase) + list(ascii_lowercase) + [str(item) for item in range(300)]

if ca_only:
sidechain_atoms = ["CA"]
else:
sidechain_atoms = ["N", "CA", "C", "O"]

pdb_dict = {}
num_chains = 0
concat_seq = []
for letter in chain_alphabet:
xyz, seq = parse_PDB_biounits(path_to_pdb, atoms=sidechain_atoms, chain=letter)
if not isinstance(xyz, str):
concat_seq.append(seq[0])
pdb_dict[f"seq_chain_{letter}"] = seq[0]

biounit_names = [path_to_pdb]
for biounit in biounit_names:
my_dict = {}
s = 0
concat_seq = ''
concat_N = []
concat_CA = []
concat_C = []
concat_O = []
concat_mask = []
coords_dict = {}
for letter in chain_alphabet:
if ca_only:
sidechain_atoms = ['CA']
coords_dict_chain = {f"CA_chain_{letter}": xyz.tolist()}
else:
sidechain_atoms = ['N', 'CA', 'C', 'O']
xyz, seq = parse_PDB_biounits(biounit, atoms=sidechain_atoms, chain=letter)
if type(xyz) != str:
concat_seq += seq[0]
my_dict['seq_chain_'+letter]=seq[0]
coords_dict_chain = {}
if ca_only:
coords_dict_chain['CA_chain_'+letter]=xyz.tolist()
else:
coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist()
coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist()
coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist()
coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
my_dict['coords_chain_'+letter]=coords_dict_chain
s += 1
fi = biounit.rfind("/")
my_dict['name']=biounit[(fi+1):-4]
my_dict['num_of_chains'] = s
my_dict['seq'] = concat_seq
if s <= len(chain_alphabet):
pdb_dict_list.append(my_dict)
c+=1
coords_dict_chain = {
f"N_chain_{letter}": xyz[:, 0, :].tolist(),
f"CA_chain_{letter}": xyz[:, 1, :].tolist(),
f"C_chain_{letter}": xyz[:, 2, :].tolist(),
f"O_chain_{letter}": xyz[:, 3, :].tolist(),
}
pdb_dict[f"coords_chain_{letter}"] = coords_dict_chain
num_chains += 1

fi = path_to_pdb.rfind("/")
pdb_dict["name"] = path_to_pdb[(fi + 1) : -4]
pdb_dict["num_of_chains"] = num_chains
pdb_dict["seq"] = "".join(concat_seq)
if num_chains <= len(chain_alphabet):
pdb_dict_list.append(pdb_dict)

return pdb_dict_list


Expand Down