Skip to content

Commit

Permalink
new changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mfl15 committed Oct 23, 2023
1 parent d0e1c79 commit fc99cce
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 327 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/runTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
run: pytest tests/test_utils.py
- name: test-workflow
run: pytest tests/test_workflow.py
- name: unit-tests
run: pytest -vv tests/unittests.py
- name: integration-tests
run: pytest -vv tests/integration_tests.py
run: pytest -vv tests/integration_tests.py
- name: unit-tests
run: pytest -vv tests/unittests.py
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ sourmash sketch fromfile ref_paths.csv -p dna,k=31,scaled=1000,abund -o ref.sig.
python ../make_training_data_from_sketches.py --ref_file ref.sig.zip --ksize 31 --num_threads ${NUM_THREADS} --ani_thresh 0.95 --prefix 'demo_ani_thresh_0.95' --outdir ./ --force

# run YACHT algorithm to check the presence of reference genomes in the query sample (inference step)
python ../run_YACHT.py --json demo_ani_thresh_0.95_config.json --sample_file sample.sig.zip --significance 0.99 --num_threads ${NUM_THREADS} --min_coverage_list 1 0.6 0.2 0.1 --out ./result.xlsx
python ../run_YACHT.py --json demo_ani_thresh_0.95_config.json --sample_file sample.sig.zip --significance 0.99 --num_threads ${NUM_THREADS} --min_coverage_list 1 0.6 0.2 0.1 --out_filename result.xlsx

# convert result to CAMI profile format (Optional)
python ../srcs/standardize_yacht_output.py --yacht_output result.xlsx --sheet_name min_coverage0.2 --genome_to_taxid toy_genome_to_taxid.tsv --mode cami --sample_name 'MySample' --outfile_prefix cami_result --outdir ./
Expand Down Expand Up @@ -179,8 +179,8 @@ The most important parameter of this script is `--ani_thresh`: this is average n
| File (names starting with prefix) | Content |
| ------------------------------------- | ------------------------------------------------------------ |
| _config.json | A JSON file stores the required information needed to run the next YACHT algorithm |
| _manifest.tsv | A TSV file contains organisms and their relevant info after removing the similar ones |
| _removed_orgs_to_corr_orgas_mapping.tsv | A TSV file with two columns: removed organism names ('removed_org') and their similar genomes ('corr_orgs')|
| _manifest.tsv | A TSV file contains organisms and their relevant info after removing the similar ones.
| _rep_to_corr_orgas_mapping.tsv | A TSV file contains representative organisms and their similar organisms that have been removed |


</br>
Expand All @@ -190,7 +190,7 @@ The most important parameter of this script is `--ani_thresh`: this is average n
After this, you are ready to perform the hypothesis test for each organism in your reference database. This can be accomplished with something like:

```bash
python run_YACHT.py --json 'gtdb_ani_thresh_0.95_config.json' --sample_file 'sample.sig.zip' --num_threads 32 --keep_raw --significance 0.99 --min_coverage_list 1 0.5 0.1 0.05 0.01 --out ./result.xlsx
python run_YACHT.py --json 'gtdb_ani_thresh_0.95_config.json' --sample_file 'sample.sig.zip' --num_threads 32 --keep_raw --significance 0.99 --min_coverage_list 1 0.5 0.1 0.05 0.01 --outdir ./
```

#### Parameter
Expand All @@ -207,7 +207,8 @@ The `--min_coverage_list` parameter dictates a list of `min_coverage` which indi
| --keep_raw | keep the raw result (i.e. `min_coverage=1`) no matter if the user specifies it |
| --show_all | Show all organisms (no matter if present) |
| --min_coverage_list | a list of `min_coverage` values, see more detailed description above (default: 1, 0.5, 0.1, 0.05, 0.01) |
| --out | path to output excel result (default: './result.xlsx') |
| --out_filename | filename of output excel result (default: 'result.xlsx') |
| --outdir | the path to output directory where the results and intermediate files will be genreated |

#### Output

Expand Down
27 changes: 12 additions & 15 deletions make_training_data_from_sketches.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import srcs.utils as utils
from loguru import logger
import json
import shutil
logger.remove()
logger.add(sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO")

Expand Down Expand Up @@ -49,11 +48,6 @@
path_to_temp_dir = os.path.join(outdir, prefix+'_intermediate_files')
if os.path.exists(path_to_temp_dir) and not force:
raise ValueError(f"Temporary directory {path_to_temp_dir} already exists. Please remove it or given a new prefix name using parameter '--prefix'.")
else:
# remove the temporary directory if it exists
if os.path.exists(path_to_temp_dir):
logger.warning(f"Temporary directory {path_to_temp_dir} already exists. Removing it.")
shutil.rmtree(path_to_temp_dir)
os.makedirs(path_to_temp_dir, exist_ok=True)

# unzip the sourmash signature file to the temporary directory
Expand All @@ -73,12 +67,12 @@

# Find the close related genomes with ANI > ani_thresh from the reference database
logger.info("Find the close related genomes with ANI > ani_thresh from the reference database")
multisearch_result = utils.run_multisearch(num_threads, ani_thresh, ksize, scale, path_to_temp_dir)
sig_same_genoms_dict = utils.run_multisearch(num_threads, ani_thresh, ksize, scale, path_to_temp_dir)

# remove the close related organisms: any organisms with ANI > ani_thresh
# pick only the one with largest number of unique kmers from all the close related organisms
logger.info("Removing the close related organisms with ANI > ani_thresh")
remove_corr_df, manifest_df = utils.remove_corr_organisms_from_ref(sig_info_dict, multisearch_result)
rep_remove_dict, manifest_df = utils.remove_corr_organisms_from_ref(sig_info_dict, sig_same_genoms_dict)

# write out the manifest file
logger.info("Writing out the manifest file")
Expand All @@ -87,19 +81,22 @@

# write out a mapping dataframe from representative organism to the close related organisms
logger.info("Writing out a mapping dataframe from representative organism to the close related organisms")
if len(remove_corr_df) == 0:
logger.warning("No close related organisms found.")
remove_corr_df_indicator = ""
if len(rep_remove_dict) == 0:
logger.warning("No close related organisms found. No mapping dataframe is written.")
rep_remove_df = pd.DataFrame(columns=['rep_org', 'corr_orgs'])
rep_remove_df_path = os.path.join(outdir, f'{prefix}_rep_to_corr_orgas_mapping.tsv')
rep_remove_df.to_csv(rep_remove_df_path, sep='\t', index=None)
else:
remove_corr_df_path = os.path.join(outdir, f'{prefix}_removed_orgs_to_corr_orgas_mapping.tsv')
remove_corr_df.to_csv(remove_corr_df_path, sep='\t', index=None)
remove_corr_df_indicator = remove_corr_df_path
rep_remove_df = pd.DataFrame([(rep_org, ','.join(corr_org_list)) for rep_org, corr_org_list in rep_remove_dict.items()])
rep_remove_df.columns = ['rep_org', 'corr_orgs']
rep_remove_df_path = os.path.join(outdir, f'{prefix}_rep_to_corr_orgas_mapping.tsv')
rep_remove_df.to_csv(rep_remove_df_path, sep='\t', index=None)

# save the config file
logger.info("Saving the config file")
json_file_path = os.path.join(outdir, f'{prefix}_config.json')
json.dump({'manifest_file_path': manifest_file_path,
'remove_cor_df_path': remove_corr_df_indicator,
'rep_remove_df_path': rep_remove_df_path,
'intermediate_files_dir': path_to_temp_dir,
'scale': scale,
'ksize': ksize,
Expand Down
23 changes: 10 additions & 13 deletions run_YACHT.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import json
import warnings
import zipfile
from pathlib import Path
warnings.filterwarnings("ignore")
from tqdm import tqdm
from loguru import logger
Expand All @@ -33,7 +32,7 @@
'Each value should be between 0 and 1, with 0 being the most sensitive (and least '
'precise) and 1 being the most precise (and least sensitive).',
required=False, default=[1, 0.5, 0.1, 0.05, 0.01])
parser.add_argument('--out', type=str, help='path to output excel file', required=False, default=os.path.join(os.getcwd(), 'result.xlsx'))
parser.add_argument('--out_filename', help='Full path of output filename', required=False, default='result.xlsx')

# parse the arguments
args = parser.parse_args()
Expand All @@ -44,13 +43,7 @@
keep_raw = args.keep_raw # Keep raw results in output file.
show_all = args.show_all # Show all organisms (no matter if present) in output file.
min_coverage_list = args.min_coverage_list # a list of percentages of unique k-mers covered by reads in the sample.
out = str(Path(args.out).absolute()) # full path to output excel file
outdir = os.path.dirname(out) # path to output directory
out_filename = os.path.basename(out) # output filename

# check if the output filename is valid
if os.path.splitext(out_filename)[1] != '.xlsx':
raise ValueError(f'Output filename {out} is not a valid excel file. Please use .xlsx as the extension.')
out_filename = args.out_filename # output filename

# check if the json file exists
utils.check_file_existence(json_file_path, f'Config file {json_file_path} does not exist. '
Expand All @@ -64,8 +57,8 @@
ani_thresh = config['ani_thresh']

# Make sure the output can be written to
if not os.access(outdir, os.W_OK):
raise FileNotFoundError(f"Cannot write to the location: {outdir}.")
if not os.access(os.path.abspath(os.path.dirname(out_filename)), os.W_OK):
raise FileNotFoundError(f"Cannot write to the location: {os.path.abspath(os.path.dirname(out_filename))}.")

# check if min_coverage is between 0 and 1
for x in min_coverage_list:
Expand Down Expand Up @@ -105,6 +98,10 @@
# check that the sample scale factor is the same as the genome scale factor for all organisms
if scale != sample_sig_info[4]:
raise ValueError(f'Sample scale factor does not equal genome scale factor. Please check your input.')

# check if the output filename is valid
if not isinstance(out_filename, str) and out_filename != '':
out_filename = 'result.xlsx'

# compute hypothesis recovery
logger.info('Computing hypothesis recovery.')
Expand All @@ -129,9 +126,9 @@
manifest_list = temp_manifest_list

# save the results into Excel file
logger.info(f'Saving results to {outdir}.')
logger.info(f'Saving results to {os.path.dirname(out_filename)}.')
# save the results with different min_coverage
with pd.ExcelWriter(out, engine='openpyxl', mode='w') as writer:
with pd.ExcelWriter(out_filename, engine='openpyxl', mode='w') as writer:
# save the raw results (i.e., min_coverage=1.0)
if keep_raw:
temp_mainifest = manifest_list[0].copy()
Expand Down
103 changes: 62 additions & 41 deletions srcs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int,
:param ksize: int (size of kmer)
:param scale: int (scale factor)
:param path_to_temp_dir: string (path to the folder to store the intermediate files)
:return: a dataframe with symmetric pairwise multisearch result (query_name, match_name)
:return: a dictionary mapping signature name to a list of its close related genomes (ANI > ani_thresh)
"""
results = {}

# run the sourmash multisearch
# save signature files to a text file
Expand All @@ -95,7 +96,7 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int,
sig_files.to_csv(sig_files_path, header=False, index=False)

# convert ani threshold to containment threshold
containment_thresh = (ani_thresh ** ksize)
containment_thresh = 0.9*(ani_thresh ** ksize)
cmd = f"sourmash scripts multisearch {sig_files_path} {sig_files_path} -k {ksize} -s {scale} -c {num_threads} -t {containment_thresh} -o {os.path.join(path_to_temp_dir, 'training_multisearch_result.csv')}"
logger.info(f"Running sourmash multisearch with command: {cmd}")
exit_code = os.system(cmd)
Expand All @@ -104,62 +105,82 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int,

# read the multisearch result
multisearch_result = pd.read_csv(os.path.join(path_to_temp_dir, 'training_multisearch_result.csv'), sep=',', header=0)
multisearch_result = multisearch_result.drop_duplicates().reset_index(drop=True)
multisearch_result = multisearch_result.query('query_name != match_name').reset_index(drop=True)

# because the multisearch result is not symmetric, that is
# we have: A B score but not B A score
# we need to make it symmetric
A_TO_B = multisearch_result[['query_name','match_name']].drop_duplicates().reset_index(drop=True)
B_TO_A = A_TO_B[['match_name','query_name']].rename(columns={'match_name':'query_name','query_name':'match_name'})
multisearch_result = pd.concat([A_TO_B, B_TO_A]).drop_duplicates().reset_index(drop=True)

return multisearch_result
for query_name, match_name in tqdm(multisearch_result[['query_name', 'match_name']].to_numpy()):
if str(query_name) not in results:
results[str(query_name)] = [str(match_name)]
else:
results[str(query_name)].append(str(match_name))

return results

def remove_corr_organisms_from_ref(sig_info_dict: Dict[str, Tuple[str, float, int, int]], multisearch_result: pd.DataFrame) -> Tuple[Dict[str, List[str]], pd.DataFrame]:
def remove_corr_organisms_from_ref(sig_info_dict: Dict[str, Tuple[str, float, int, int]], sig_same_genoms_dict: Dict[str, List[str]]) -> Tuple[Dict[str, List[str]], pd.DataFrame]:
"""
Helper function that removes the close related organisms from the reference matrix.
:param sig_info_dict: a dictionary mapping all signature name from reference data to a tuple (md5sum, minhash mean abundance, minhash hashes length, minhash scaled)
:param multisearch_result: a dataframe with symmetric pairwise multisearch result (query_name, match_name)
:param sig_same_genoms_dict: a dictionary mapping signature name to a list of its close related genomes (ANI > ani_thresh)
:return
remove_corr_df: a dataframe with two columns: removed organism name and its close related organisms
rep_remove_dict: a dictionary with key as representative signature name and value as a list of signatures to be removed
manifest_df: a dataframe containing the processed reference signature information
"""
# extract organisms that have close related organisms and their number of unique kmers
# sort name in order to better check the removed organisms
corr_organisms = sorted([str(query_name) for query_name in multisearch_result['query_name'].unique()])
sizes = np.array([sig_info_dict[organism][2] for organism in corr_organisms])
# sort organisms by size in ascending order, so we keep the largest organism, discard the smallest
bysize = np.argsort(sizes)
corr_organisms_bysize = np.array(corr_organisms)[bysize].tolist()

# use dictionary to store the removed organisms and their close related organisms
# key: removed organism name
# value: a set of close related organisms
mapping = multisearch_result.groupby('query_name')['match_name'].agg(set).to_dict()

# remove the sorted organisms until all left genomes are distinct (e.g., ANI <= ani_thresh)
# for each genome with close related genomes, pick the one with largest number of unique kmers
rep_remove_dict = {}
temp_remove_set = set()
# loop through the organisms size in ascending order
for organism in tqdm(corr_organisms_bysize, desc='Removing close related organisms'):
## for a given organism check its close related organisms, see if there are any organisms left after removing those in the remove set
## if so, put this organism in the remove set
left_corr_orgs = mapping[organism].difference(temp_remove_set)
if len(left_corr_orgs) > 0:
temp_remove_set.add(organism)

# generate a dataframe with two columns: removed organism name and its close related organisms
logger.info('Generating a dataframe with two columns: removed organism name and its close related organisms.')
remove_corr_list = [(organism, ','.join(list(mapping[organism]))) for organism in tqdm(temp_remove_set)]
remove_corr_df = pd.DataFrame(remove_corr_list, columns=['removed_org', 'corr_orgs'])
manifest_df = []
for genome, same_genomes in tqdm(sig_same_genoms_dict.items()):
# skip if the genome has been removed
if genome in temp_remove_set:
continue
# keep same genome if it is not in the remove set
same_genomes = list(set(same_genomes).difference(temp_remove_set))
# get the number of unique kmers for each genome
unique_kmers = np.array([sig_info_dict[genome][2]] + [sig_info_dict[same_genome][2] for same_genome in same_genomes])
# get the index of the genome with largest number of unique kmers
rep_idx = np.argmax(unique_kmers)
# get the representative genome
rep_genome = genome if rep_idx == 0 else same_genomes[rep_idx-1]
# get the list of genomes to be removed
remove_genomes = same_genomes if rep_idx == 0 else [genome] + same_genomes[:rep_idx-1] + same_genomes[rep_idx:]
# update remove set
temp_remove_set.update(remove_genomes)
if len(remove_genomes) > 0:
rep_remove_dict[rep_genome] = remove_genomes

# remove the close related organisms from the reference genome list
manifest_df = []
for sig_name, (md5sum, minhash_mean_abundance, minhash_hashes_len, minhash_scaled) in tqdm(sig_info_dict.items()):
if sig_name not in temp_remove_set:
manifest_df.append((sig_name, md5sum, minhash_hashes_len, get_num_kmers(minhash_mean_abundance, minhash_hashes_len, minhash_scaled, False), minhash_scaled))
manifest_df = pd.DataFrame(manifest_df, columns=['organism_name', 'md5sum', 'num_unique_kmers_in_genome_sketch', 'num_total_kmers_in_genome_sketch', 'genome_scale_factor'])

return remove_corr_df, manifest_df
return rep_remove_dict, manifest_df

# def compute_sample_vector(sample_hashes, hash_to_idx):
# """
# Helper function that computes the sample vector for a given sample signature.
# :param sample_hashes: hashes in the sample signature
# :param hash_to_idx: dictionary mapping hashes to indices in the training dictionary
# :return: numpy array (sample vector)
# """
# # total number of hashes in the training dictionary
# hash_to_idx_keys = set(hash_to_idx.keys())

# # total number of hashes in the sample
# sample_hashes_keys = set(sample_hashes.keys())

# # initialize the sample vector
# sample_vector = np.zeros(len(hash_to_idx_keys))

# # get the hashes that are in both the sample and the training dictionary
# sample_intersect_training_hashes = hash_to_idx_keys.intersection(sample_hashes_keys)

# # fill in the sample vector
# for sh in tqdm(sample_intersect_training_hashes):
# sample_vector[hash_to_idx[sh]] = sample_hashes[sh]

# return sample_vector


class Prediction:
"""
Expand Down
Loading

0 comments on commit fc99cce

Please sign in to comment.