
Merge pull request #32 from gbouras13/dev
v0.1.4
gbouras13 authored Mar 26, 2024
2 parents 1766024 + 3e78d33 commit ac2716e
Showing 18 changed files with 959 additions and 60 deletions.
7 changes: 7 additions & 0 deletions HISTORY.md
@@ -1,5 +1,12 @@
# History

0.1.4 (2024-03-26)
------------------

* Fixes #31, an issue with older Pharokka genbank input (prior to v1.5.0) that lacked the 'transl_table' field
* All Pharokka genbank input prior to v1.5.0 will be assigned transl_table 11 (these versions predate the addition of pyrodigal-gv)
* Fixes a genbank parsing bug that would occur if the ID/locus tag of a feature in the input genbank was longer than 54 characters

0.1.3 (2024-03-19)
------------------

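The first two changelog entries describe a fallback for genbank files written by Pharokka releases before v1.5.0, which lack the 'transl_table' qualifier. A minimal sketch of that behaviour, assuming Biopython-style qualifier dictionaries (the shipped fix in src/phold/io/handle_genbank.py further down wraps the lookup in try/except instead):

```python
# Illustrative only: default to translation table 11 when the qualifier is
# missing (pre-v1.5.0 Pharokka GenBank files predate pyrodigal-gv).
def get_transl_table(cds_feature) -> str:
    values = cds_feature.qualifiers.get("transl_table", ["11"])
    # Pharokka GenBank qualifiers parse as lists; FASTA-derived ones do not
    return values[0] if isinstance(values, list) else values
```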
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ requires = ["setuptools>=61.0", "wheel>=0.37.1"]
[project]
# https://packaging.python.org/en/latest/specifications/declaring-project-metadata/
name = "phold"
version = "0.1.3" # change VERSION too
version = "0.1.4" # change VERSION too
description = "Phage Annotations using Protein Structures"
readme = "README.md"
requires-python = ">=3.8, <3.12"
39 changes: 19 additions & 20 deletions src/phold/__init__.py
@@ -8,24 +8,18 @@
from loguru import logger
from pycirclize.parser import Genbank


from phold.plot.plot import create_circos_plot

from phold.databases.db import install_database, validate_db
from phold.features.create_foldseek_db import generate_foldseek_db_from_aa_3di
from phold.features.predict_3Di import get_T5_model
from phold.features.query_remote_3Di import query_remote_3di
from phold.plot.plot import create_circos_plot
from phold.subcommands.compare import subcommand_compare
from phold.subcommands.predict import subcommand_predict
from phold.utils.constants import DB_DIR
from phold.utils.util import (
begin_phold,
clean_up_temporary_files,
end_phold,
get_version,
print_citation,
)
from phold.utils.validation import check_dependencies, instantiate_dirs, validate_input
from phold.utils.util import (begin_phold, clean_up_temporary_files, end_phold,
get_version, print_citation)
from phold.utils.validation import (check_dependencies, instantiate_dirs,
validate_input)

log_fmt = (
"[<green>{time:YYYY-MM-DD HH:mm:ss}</green>] <level>{level: <8}</level> | "
@@ -454,7 +448,7 @@ def predict(
@click.option(
"--filter_pdbs",
is_flag=True,
help="Flag that creates a copy of the PDBs with matching record IDs found in the GenBank. Helpful if you have a directory with lots of PDBs and want to annotate only e.g. 1 phage.",
help="Flag that creates a copy of the .pdb files with matching record IDs found in the input GenBank file. Helpful if you have a directory with lots of .pdb files and want to annotate only e.g. 1 phage.",
)
@common_options
@compare_options
@@ -693,9 +687,14 @@ def proteins_predict(
)
@click.option(
"--pdb_dir",
help="Path to directory with pdbs. The FASTA headers need to match names of the pdb files",
help="Path to directory with .pdb files. The FASTA headers need to match names of the .pdb files",
type=click.Path(),
)
@click.option(
"--filter_pdbs",
is_flag=True,
help="Flag that creates a copy of the .pdb files with matching record IDs found in the input. Helpful if you have a directory with lots of .pdb files and want to annotate only some.",
)
@common_options
@compare_options
def proteins_compare(
@@ -711,6 +710,7 @@ def proteins_compare(
predictions_dir,
pdb,
pdb_dir,
filter_pdbs,
keep_tmp_files,
split,
split_threshold,
@@ -740,6 +740,7 @@ def proteins_compare(
"--predictions_dir": predictions_dir,
"--pdb": pdb,
"--pdb_dir": pdb_dir,
"--filter_pdbs": filter_pdbs,
"--keep_tmp_files": keep_tmp_files,
"--split": split,
"--split_threshold": split_threshold,
@@ -795,7 +796,7 @@ def proteins_compare(
pdb,
pdb_dir,
logdir,
filter_pdbs=False,
filter_pdbs,
split=split,
split_threshold=split_threshold,
remote_flag=False,
@@ -891,35 +892,33 @@ def remote(

fasta_aa: Path = Path(output) / f"{prefix}_aa.fasta"


# makes the nested dictionary {contig_id:{cds_id: cds_feature}}

for record_id, record in gb_dict.items():
cds_dict[record_id] = {}

for cds_feature in record.features:
if cds_feature.type == "CDS":
if fasta_flag is False:
cds_feature.qualifiers["translation"] = cds_feature.qualifiers["translation"][0]
cds_feature.qualifiers["translation"] = cds_feature.qualifiers[
"translation"
][0]
cds_dict[record_id][cds_feature.qualifiers["ID"][0]] = cds_feature
else:
cds_dict[record_id][cds_feature.qualifiers["ID"]] = cds_feature


## write the CDS to file
# FASTA -> takes the whole thing
# Pharokka GBK -> requires just the first entry, the GBK is parsed as a list

with open(fasta_aa, "w+") as out_f:
for contig_id, rest in cds_dict.items():

aa_contig_dict = cds_dict[contig_id]
# writes the CDS to file
for seq_id, cds_feature in aa_contig_dict.items():
out_f.write(f">{contig_id}:{seq_id}\n")
out_f.write(f"{cds_feature.qualifiers['translation']}\n")


############
# prostt5 remote
############
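The `--filter_pdbs` flag added to `proteins_compare` above copies, per its help text, only the .pdb files whose names match record IDs found in the input before the Foldseek database is built. A rough sketch of that idea with hypothetical names (not the actual implementation):

```python
# Hedged sketch of what --filter_pdbs describes: copy only the .pdb files whose
# basenames match CDS IDs present in the input (names here are illustrative).
import shutil
from pathlib import Path

def filter_pdbs_by_ids(pdb_dir: Path, filtered_dir: Path, cds_ids: set) -> None:
    filtered_dir.mkdir(parents=True, exist_ok=True)
    for pdb_file in pdb_dir.glob("*.pdb"):
        if pdb_file.stem in cds_ids:  # names are expected as '{cds_id}.pdb'
            shutil.copy2(pdb_file, filtered_dir / pdb_file.name)
```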
6 changes: 4 additions & 2 deletions src/phold/features/create_foldseek_db.py
@@ -176,9 +176,11 @@ def generate_foldseek_db_from_pdbs(
no_pdb_cds_ids = []

for id in sequences_aa.keys():
cds_id = id.split(":")[1]
# record_id = id.split(":")[0]
# in case the header has a colon in it - this will cause a bug if so
cds_id = id.split(":")[1:]
cds_id = ":".join(cds_id).strip()

# record_id = id.split(":")[0]
# this is potentially an issue if a contig has > 9999 AAs
# need to fix with Pharokka possibly. Unlikely to occur but might!
# enforce names as '{cds_id}.pdb'
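The change above handles FASTA headers that themselves contain a colon: splitting on the first colon and rejoining the remainder keeps the full CDS ID instead of truncating it. A minimal illustration:

```python
# Old vs new behaviour when the header contains an extra colon.
header = "contig_1:phage_cds:001"
old_cds_id = header.split(":")[1]                     # 'phage_cds' (truncated)
new_cds_id = ":".join(header.split(":")[1:]).strip()  # 'phage_cds:001'
```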
4 changes: 3 additions & 1 deletion src/phold/features/predict_3Di.py
@@ -112,7 +112,9 @@ def get_T5_model(
device = torch.device("cpu")
dev_name = "cpu"
if cpu is not True:
logger.warning("No available GPU was found, but --cpu was not specified")
logger.warning(
"No available GPU was found, but --cpu was not specified"
)
logger.warning("ProstT5 will be run with CPU only")

# logger device only if the function is called
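The reformatted warning above fires when no GPU is found but `--cpu` was not requested. A simplified sketch of the surrounding device selection (assumed structure; the real `get_T5_model` also loads the ProstT5 weights):

```python
# Simplified, assumed sketch of the device fallback around the warning above.
import torch
from loguru import logger

def select_device(cpu: bool) -> torch.device:
    if not cpu and torch.cuda.is_available():
        return torch.device("cuda:0")
    if not cpu:
        logger.warning("No available GPU was found, but --cpu was not specified")
        logger.warning("ProstT5 will be run with CPU only")
    return torch.device("cpu")
```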
9 changes: 6 additions & 3 deletions src/phold/features/predict_3Di_finetune.py
@@ -24,10 +24,13 @@
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForTokenClassification, T5EncoderModel, T5Tokenizer
from transformers import (DataCollatorForTokenClassification, T5EncoderModel,
T5Tokenizer)
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers.models.t5.modeling_t5 import (T5Config, T5PreTrainedModel,
T5Stack)
from transformers.utils.model_parallel_utils import (assert_device_map,
get_device_map)

from phold.features.predict_3Di import write_predictions
from phold.utils.constants import FINETUNE_DIR
4 changes: 3 additions & 1 deletion src/phold/features/query_remote_3Di.py
@@ -9,7 +9,9 @@
from loguru import logger


def query_remote_3di(cds_dict: Dict[str, dict], fasta_3di: Path, fasta_flag: bool) -> None:
def query_remote_3di(
cds_dict: Dict[str, dict], fasta_3di: Path, fasta_flag: bool
) -> None:
"""
Query remote Foldseek ProstT5 server for 3Di predictions of amino acid sequences and write to file.
14 changes: 11 additions & 3 deletions src/phold/io/handle_genbank.py
@@ -239,9 +239,17 @@ def write_genbank(
else:
# because for some reason when parsing the pharokka genbank, it is a list, fasta it is not
if fasta_flag is True:
transl_table = cds_feature.qualifiers["transl_table"]
try:
transl_table = cds_feature.qualifiers["transl_table"]
except:
# for older pharokka input before v1.5.0
transl_table = "11"
else:
transl_table = cds_feature.qualifiers["transl_table"][0]
try:
transl_table = cds_feature.qualifiers["transl_table"][0]
except:
# for older pharokka input before v1.5.0
transl_table = "11"

# to reverse the start and end coordinates for output tsv + fix genbank 0 index start relative to pharokka
if cds_feature.location.strand == -1: # neg strand
@@ -259,7 +267,7 @@

if fasta_flag is True:
cds_id = cds_feature.qualifiers["ID"]
else: # because for some reason when parsing the pharokka genbank, it is a list
else: # because for some reason when parsing the pharokka genbank, it is a list
cds_id = cds_feature.qualifiers["ID"][0]

cds_info = {
12 changes: 6 additions & 6 deletions src/phold/plot/plot.py
@@ -1,16 +1,16 @@
from pathlib import Path
from typing import List, Dict
from typing import Dict, List

from loguru import logger
from pycirclize import Circos
from pycirclize.parser import Genbank
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np
from Bio import SeqUtils
from Bio.Seq import Seq
from Bio.SeqFeature import SeqFeature
from loguru import logger
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from pycirclize import Circos
from pycirclize.parser import Genbank


def create_circos_plot(
21 changes: 16 additions & 5 deletions src/phold/results/topfunction.py
@@ -83,8 +83,14 @@ def get_topfunctions(
"envhog_" + foldseek_df.loc[mask, "tophit_protein"]
)

foldseek_df["phrog"] = foldseek_df["phrog"].astype("str")
# strip off efam
mask = foldseek_df["phrog"].str.startswith("efam_")
foldseek_df.loc[mask, "phrog"] = foldseek_df.loc[mask, "phrog"].str.replace(
"efam_", ""
)
# no need to add it on to protein - already done

foldseek_df["phrog"] = foldseek_df["phrog"].astype("str")
# read in the mapping tsv
phrog_annot_mapping_tsv: Path = Path(database) / "phold_annots.tsv"
phrog_mapping_df = pd.read_csv(phrog_annot_mapping_tsv, sep="\t")
@@ -167,10 +173,15 @@ def weighted_function(group: pd.DataFrame) -> pd.DataFrame:
value / total_functional_bitscore, 3
)

top_bitscore_function = max(
weighted_counts_normalised, key=weighted_counts_normalised.get
)
top_bitscore_perc = max(weighted_counts_normalised.values())
# error where weighted_counts_normalised was empty for maxseqs = 10000
if weighted_counts_normalised:
top_bitscore_function = max(
weighted_counts_normalised, key=weighted_counts_normalised.get
)
top_bitscore_perc = max(weighted_counts_normalised.values())
else:
top_bitscore_function = "unknown function"
top_bitscore_perc = 0

d = {
"function_with_highest_bitscore_proportion": [top_bitscore_function],
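The guard added above is needed because `max()` over an empty mapping raises `ValueError`; as the comment notes, with maxseqs = 10000 a CDS could end up with no functional hits, leaving `weighted_counts_normalised` empty. A small illustration of the failure mode and the fallback:

```python
# Why the guard is needed: max() on an empty dict raises ValueError.
weighted_counts_normalised = {}          # e.g. a CDS with no functional hits
try:
    top = max(weighted_counts_normalised, key=weighted_counts_normalised.get)
except ValueError:
    top = "unknown function"             # fall back, as the fix above does
```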
11 changes: 5 additions & 6 deletions src/phold/subcommands/compare.py
@@ -10,14 +10,13 @@
from loguru import logger

from phold.features.create_foldseek_db import (
generate_foldseek_db_from_aa_3di,
generate_foldseek_db_from_pdbs,
)
generate_foldseek_db_from_aa_3di, generate_foldseek_db_from_pdbs)
from phold.features.run_foldseek import create_result_tsv, run_foldseek_search
from phold.features.split_3Di import split_3di_fasta_by_prob
from phold.io.handle_genbank import write_genbank
from phold.io.sub_db_outputs import create_sub_db_outputs
from phold.results.topfunction import calculate_topfunctions_results, get_topfunctions
from phold.results.topfunction import (calculate_topfunctions_results,
get_topfunctions)


def subcommand_compare(
@@ -90,15 +89,15 @@ def subcommand_compare(
if fasta_flag is False:
if cds_feature.type == "CDS":
# update DNA, RNA and nucleotide metabolism from pharokka as it is broken as of 1.6.1
if "DNA" in cds_feature.qualifiers["function"][0] :
if "DNA" in cds_feature.qualifiers["function"][0]:
cds_feature.qualifiers["function"][
0
] = "DNA, RNA and nucleotide metabolism"
cds_feature.qualifiers["function"] = [
cds_feature.qualifiers["function"][0]
] # Keep only the first element
# moron, auxiliary metabolic gene and host takeover as it is broken as of 1.6.1
if "moron" in cds_feature.qualifiers["function"][0] :
if "moron" in cds_feature.qualifiers["function"][0]:
cds_feature.qualifiers["function"][
0
] = "moron, auxiliary metabolic gene and host takeover"
15 changes: 12 additions & 3 deletions src/phold/subcommands/predict.py
@@ -65,9 +65,18 @@ def subcommand_predict(
cds_feature.qualifiers["translation"] = cds_feature.qualifiers[
"translation"
][0]
cds_dict[record_id][
cds_feature.qualifiers["ID"][0]
] = cds_feature

# for really long CDS IDs (over 54 chars), a space will be introduced
# this is because the ID will go over a second line
# weird bug noticed it on the Mgnify contigs annotated with Pharokka

cds_id = cds_feature.qualifiers["ID"][0]
if len(cds_id) >= 54:
# Remove all spaces from the string
cds_id = cds_id.replace(" ", "")

cds_dict[record_id][cds_id] = cds_feature

else:
cds_dict[record_id][cds_feature.qualifiers["ID"]] = cds_feature

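The comments above explain the motivation: per the fix, a CDS ID of 54 characters or more wraps onto a second line when the GenBank file is written, so a stray space appears inside the ID when the file is read back. A small illustration of the cleanup (the example ID is made up):

```python
# Illustrative cleanup: strip the space that line wrapping can introduce into
# very long CDS IDs (>= 54 characters) read back from a GenBank file.
cds_id = "MGYG000000001_contig_42_extremely_long_locus_ tag_00042"
if len(cds_id) >= 54:
    cds_id = cds_id.replace(" ", "")
```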
2 changes: 1 addition & 1 deletion src/phold/utils/VERSION
@@ -1 +1 @@
0.1.3
0.1.4
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -3,8 +3,8 @@
"""


def pytest_addoption(parser):
parser.addoption("--gpu_available", action="store_true")
parser.addoption("--run_remote", action="store_true")
parser.addoption("--threads", action="store", default=1)
