Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Addition of OpenMMDL Analysis Pytests #114

Open
wants to merge 2 commits into
base: Class_Pytest_addition
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions openmmdl/tests/openmmdl_analysis/test_barcode_generation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import re
from pathlib import Path
import os
import matplotlib.pyplot as plt
import pytest
Expand Down Expand Up @@ -39,3 +40,74 @@ def test_waterids_barcode_generator(sample_dataframe_barcode_generation):
# Test the expected waterid barcode for the sample dataframe and interaction
expected_waterid_barcode = [0, 104, 105]
assert waterid_barcode == expected_waterid_barcode

def test_plot_barcodes(tmp_path):
df_all = pd.DataFrame()
barcode_plotter = BarcodePlotter(df_all)

# Test case 1: No barcode
save_path = tmp_path / "no_barcodes.png"
barcode_plotter.plot_barcodes({}, save_path)
assert not save_path.exists()

# Test case 2: Single barcode
barcode_data = {
'Barcode1': np.array([1, 1, 0, 1, 0]),
}
save_path = tmp_path / "single_barcode.png"
barcode_plotter.plot_barcodes(barcode_data, save_path)
assert save_path.exists()

# Test case 3: Multiple barcodes
barcode_data = {
"Barcode 1": np.array([1, 0, 1, 0, 1, 0]),
"Barcode 2": np.array([0, 1, 0, 1, 0, 1]),
}
save_path = tmp_path / "multiple_barcodes.png"
barcode_plotter.plot_barcodes(barcode_data, save_path)
assert save_path.exists()

def test_plot_waterbridge_piechart(sample_dataframe_barcode_generation):
barcode_plotter = BarcodePlotter(sample_dataframe_barcode_generation)
waterbridge_barcodes = [np.array([1, 0, 1, 0]), np.array([0, 0, 0, 1])]
waterbridge_interactions = ['Interaction1', 'Interaction2']
fig_type = 'png'

# Run the plotting function
barcode_plotter.plot_waterbridge_piechart(waterbridge_barcodes, waterbridge_interactions, fig_type)

# Check if files are saved in the hardcoded directory
output_dir = os.path.join("Barcodes", "Waterbridge_Piecharts")
for interaction in waterbridge_interactions:
outname_png = os.path.join(output_dir, f"{interaction}.{fig_type}")
assert os.path.exists(outname_png), f"File not found: {outname_png}"

def test_plot_barcodes_grouped():
df_all = pd.DataFrame({
'FRAME': [0, 1, 2],
'atom1_atom2_interaction': [1, 0, 1],
'atom3_atom4_interaction': [0, 1, 1],
})

interactions = ['atom1_atom2_interaction', 'atom3_atom4_interaction']
interaction_type = 'interaction'
fig_type = 'png'
barcode_plotter = BarcodePlotter(df_all)

barcode_plotter.plot_barcodes_grouped(interactions, interaction_type, fig_type)

# Expected file paths
atom2_dir = Path("Barcodes/atom2")
atom4_dir = Path("Barcodes/atom4")
total_path = Path(f"Barcodes/{interaction_type}_interactions.{fig_type}")

# Validate that directories and files were created
assert atom2_dir.exists(), f"Directory not found: {atom2_dir}"
assert atom4_dir.exists(), f"Directory not found: {atom4_dir}"

atom2_file = atom2_dir / f"atom2_{interaction_type}_barcodes.{fig_type}"
atom4_file = atom4_dir / f"atom4_{interaction_type}_barcodes.{fig_type}"

assert atom2_file.exists(), f"File not found: {atom2_file}"
assert atom4_file.exists(), f"File not found: {atom4_file}"
assert total_path.exists(), f"File not found: {total_path}"
160 changes: 160 additions & 0 deletions openmmdl/tests/openmmdl_analysis/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import pytest
import os
import MDAnalysis as mda
import mdtraj as md
from openmmdl.openmmdl_analysis.preprocessing import Preprocessing # Adjust the import as needed


@pytest.fixture
def temp_pdb_file(tmpdir):
"""Fixture to create a temporary PDB file for input."""
input_pdb = tmpdir.join("input.pdb")

# Create synthetic PDB contents with various residue names, including water and unknown
input_pdb_content = """
ATOM 1 O WAT A 1 44.338 72.730 39.423 1.00 0.00 O
ATOM 2 O WAT A 2 31.578 38.467 49.764 1.00 0.00 O
ATOM 3 O WAT A 3 38.597 28.466 49.556 1.00 0.00 O
ATOM 4 O WAT A 4 32.842 57.728 36.084 1.00 0.00 O
ATOM 5 O WAT A 5 21.918 43.049 43.250 1.00 0.00 O
ATOM 6 O WAT A 6 24.740 38.865 51.435 1.00 0.00 O
ATOM 7 O WAT A 7 29.581 48.257 50.944 1.00 0.00 O
ATOM 8 O WAT A 8 33.474 49.644 32.100 1.00 0.00 O
ATOM 9 O WAT A 9 50.718 46.583 51.141 1.00 0.00 O
ATOM 10 O * A 10 35.529 47.920 25.052 1.00 0.00 O
ATOM 11 O * A 11 42.799 36.680 51.573 1.00 0.00 O
ATOM 12 O * A 12 44.499 25.396 38.351 1.00 0.00 O
ATOM 13 O * A 13 36.184 75.754 32.555 1.00 0.00 O
ATOM 14 O * A 14 28.279 70.291 27.450 1.00 0.00 O
"""

# Write the contents to the file
input_pdb.write(input_pdb_content)

return str(input_pdb)

@pytest.fixture
def temp_pdb_lig_file(tmpdir):
"""Fixture to create a temporary PDB file with a ligand."""
input_pdb = tmpdir.join("input.pdb")

# Create synthetic PDB contents with a ligand (e.g., a simple water molecule)
input_pdb_content = """
ATOM 1 N LIG A 1 10.104 13.524 8.240 1.00 20.00 N
ATOM 2 C LIG A 1 11.104 14.524 8.240 1.00 20.00 C
ATOM 3 O LIG A 1 12.104 15.524 8.240 1.00 20.00 O
ATOM 4 H LIG A 1 13.104 16.524 8.240 1.00 20.00 H
ATOM 5 H LIG A 1 14.104 17.524 8.240 1.00 20.00 H
ATOM 6 C PRO A 2 15.104 18.524 8.240 1.00 20.00 C
ATOM 7 N PRO A 2 16.104 19.524 8.240 1.00 20.00 N
ATOM 8 O PRO A 2 17.104 20.524 8.240 1.00 20.00 O
"""

# Write the contents to the file
input_pdb.write(input_pdb_content)

return str(input_pdb)


def test_process_pdb_file(temp_pdb_file):
input_pdb = temp_pdb_file

# Create a Preprocessing instance
preprocessing = Preprocessing()

# Call the process_pdb_file method
preprocessing.process_pdb_file(input_pdb)

# Load the modified PDB file to check residue names
u = mda.Universe(input_pdb)
print(u)
ag = u.select_atoms("all")
for a in ag:
print(a)

# Check if the residues have been correctly renamed
water_residues = [atom.residue.resname for atom in u.atoms if atom.residue.resname in ["HOH", "UNK"]]
print(water_residues)

# Check that the water residues were renamed correctly
assert water_residues.count("HOH") == 9, "Expected 9 water residues (HOH)."
assert water_residues.count("UNK") == 5, "Expected 5 unknown residues (UNK)."

# Check the original residues were renamed (water and unknown)
for atom in u.atoms:
if atom.residue.resname in ["SPC", "TIP3", "TIP4", "WAT", "T3P", "T4P", "T5P"]:
assert atom.residue.resname == "HOH", f"Expected 'HOH' but got {atom.residue.resname}"
elif atom.residue.resname == "*":
assert atom.residue.resname == "UNK", f"Expected 'UNK' but got {atom.residue.resname}"

# Check if the output file exists and is not empty
assert os.path.exists(input_pdb)
assert os.path.getsize(input_pdb) > 0

def test_increase_ring_indices():
# Create a Preprocessing instance
preprocessing = Preprocessing()

# Define test inputs
ring = [1, 2, 3, 4, 5] # Example atom indices in a ring
lig_index = 10 # Ligand atom index to be added

# Call the method
result = preprocessing.increase_ring_indices(ring, lig_index)

# Define the expected result
expected_result = [11, 12, 13, 14, 15] # The ring indices after adding lig_index

# Assert the result is as expected
assert result == expected_result, f"Expected {expected_result}, but got {result}"

def test_extract_and_save_ligand_as_sdf(temp_pdb_lig_file, tmpdir):
input_pdb = temp_pdb_lig_file
output_sdf = tmpdir.join("output.sdf")
target_resname = "LIG" # The ligand residue name in the PDB file

# Create a Preprocessing instance
preprocessing = Preprocessing()

# Call the method to extract the ligand and save it as SDF
preprocessing.extract_and_save_ligand_as_sdf(input_pdb, str(output_sdf), target_resname)

# Check if the output SDF file exists and is not empty
assert os.path.exists(output_sdf)
assert os.path.getsize(output_sdf) > 0, "The output SDF file is empty."

# Check if the temporary PDB file (lig.pdb) was removed
assert not os.path.exists("lig.pdb"), "Temporary PDB file 'lig.pdb' was not removed."


def test_renumber_atoms_in_residues(temp_pdb_lig_file, tmpdir):
input_pdb = temp_pdb_lig_file
output_pdb = tmpdir.join("output.pdb")
lig_name = "LIG" # The ligand residue name in the PDB file

# Create a Preprocessing instance
preprocessing = Preprocessing()

# Call the method to renumber atoms in the ligand
preprocessing.renumber_atoms_in_residues(input_pdb, str(output_pdb), lig_name)

# Read the output PDB file
with open(output_pdb, "r") as f:
output_lines = f.readlines()

# Check that the atoms of the ligand have been renumbered correctly
renumbered_atoms = []
for line in output_lines:
if line.startswith("ATOM"):
atom_name = line[12:16].strip()
residue_name = line[17:20].strip()

# Only check atoms of the ligand
if residue_name == lig_name:
renumbered_atoms.append(atom_name)

# The atom names in the ligand should follow the pattern: N1, C1, O1, H1, H2
expected_renumbered_atoms = ['N1', 'C1', 'O1', 'H1', 'H2']

# Assert that the renumbered atom names match the expected output
assert renumbered_atoms == expected_renumbered_atoms, f"Expected {expected_renumbered_atoms}, but got {renumbered_atoms}"
Loading