
Commit

update the functions for removing similar organisms
chunyuma committed Oct 19, 2023
1 parent 6db73a9 commit 75a7d93
Showing 2 changed files with 44 additions and 73 deletions.
21 changes: 9 additions & 12 deletions make_training_data_from_sketches.py
@@ -65,12 +65,12 @@

# Find the closely related genomes with ANI > ani_thresh from the reference database
logger.info("Finding closely related genomes with ANI > ani_thresh from the reference database")
sig_same_genoms_dict = utils.run_multisearch(num_threads, ani_thresh, ksize, scale, path_to_temp_dir)
multisearch_result = utils.run_multisearch(num_threads, ani_thresh, ksize, scale, path_to_temp_dir)

# remove the closely related organisms: any organisms with ANI > ani_thresh
# keep only the one with the largest number of unique kmers from each group of closely related organisms
logger.info("Removing closely related organisms with ANI > ani_thresh")
rep_remove_dict, manifest_df = utils.remove_corr_organisms_from_ref(sig_info_dict, sig_same_genoms_dict)
remove_corr_df, manifest_df = utils.remove_corr_organisms_from_ref(sig_info_dict, multisearch_result, ani_thresh, ksize)

# write out the manifest file
logger.info("Writing out the manifest file")
@@ -79,22 +79,19 @@

# write out a mapping dataframe from each removed organism to its closely related organisms
logger.info("Writing out a mapping dataframe from each removed organism to its closely related organisms")
if len(rep_remove_dict) == 0:
logger.warning("No close related organisms found. No mapping dataframe is written.")
rep_remove_df = pd.DataFrame(columns=['rep_org', 'corr_orgs'])
rep_remove_df_path = os.path.join(outdir, f'{prefix}_rep_to_corr_orgas_mapping.tsv')
rep_remove_df.to_csv(rep_remove_df_path, sep='\t', index=None)
if len(remove_corr_df) == 0:
logger.warning("No close related organisms found.")
remove_corr_df_indicator = ""
else:
rep_remove_df = pd.DataFrame([(rep_org, ','.join(corr_org_list)) for rep_org, corr_org_list in rep_remove_dict.items()])
rep_remove_df.columns = ['rep_org', 'corr_orgs']
rep_remove_df_path = os.path.join(outdir, f'{prefix}_rep_to_corr_orgas_mapping.tsv')
rep_remove_df.to_csv(rep_remove_df_path, sep='\t', index=None)
remove_corr_df_path = os.path.join(outdir, f'{prefix}_removed_orgs_to_corr_orgas_mapping.tsv')
remove_corr_df.to_csv(remove_corr_df_path, sep='\t', index=None)
remove_corr_df_indicator = remove_corr_df_path

# save the config file
logger.info("Saving the config file")
json_file_path = os.path.join(outdir, f'{prefix}_config.json')
json.dump({'manifest_file_path': manifest_file_path,
'rep_remove_df_path': rep_remove_df_path,
'remove_cor_df_path': remove_corr_df_indicator,
'intermediate_files_dir': path_to_temp_dir,
'scale': scale,
'ksize': ksize,
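For context, the config written by the block above records the mapping-file path only when closely related organisms were found; otherwise the empty indicator string is stored. A minimal sketch of the resulting JSON under assumed values (prefix "demo", outdir "outputs", and an assumed manifest filename):

import json
import os

outdir, prefix = 'outputs', 'demo'  # hypothetical values
os.makedirs(outdir, exist_ok=True)
config = {
    'manifest_file_path': os.path.join(outdir, f'{prefix}_processed_manifest.tsv'),  # assumed name
    'remove_cor_df_path': '',  # empty string: no closely related organisms were found
    'intermediate_files_dir': os.path.join(outdir, 'intermediate_files'),  # assumed location
    'scale': 1000,
    'ksize': 31,
}
with open(os.path.join(outdir, f'{prefix}_config.json'), 'w') as f:
    json.dump(config, f, indent=2)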
96 changes: 35 additions & 61 deletions srcs/utils.py
@@ -85,7 +85,7 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int,
:param ksize: int (size of kmer)
:param scale: int (scale factor)
:param path_to_temp_dir: string (path to the folder to store the intermediate files)
:return: a dictionary mapping signature name to a list of its close related genomes (ANI > ani_thresh)
:return: a dataframe with the symmetric pairwise multisearch results (query_name, match_name)
"""
results = {}

@@ -96,7 +96,7 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int,
sig_files.to_csv(sig_files_path, header=False, index=False)

# convert ani threshold to containment threshold
containment_thresh = 0.9*(ani_thresh ** ksize)
containment_thresh = (ani_thresh ** ksize)
cmd = f"sourmash scripts multisearch {sig_files_path} {sig_files_path} -k {ksize} -s {scale} -c {num_threads} -t {containment_thresh} -o {os.path.join(path_to_temp_dir, 'training_multisearch_result.csv')}"
logger.info(f"Running sourmash multisearch with command: {cmd}")
exit_code = os.system(cmd)
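# Worked example of the conversion above with illustrative values (not part
# of this commit): sourmash estimates ANI from containment roughly as
# ANI ~ containment**(1/ksize), so the cutoff inverts that relation; this
# commit drops the old 0.9 safety factor.
example_ani_thresh, example_ksize = 0.95, 31
example_containment = example_ani_thresh ** example_ksize
print(round(example_containment, 4))  # 0.2039 (the removed 0.9 factor would give ~0.1835)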
@@ -105,82 +105,56 @@

# read the multisearch result
multisearch_result = pd.read_csv(os.path.join(path_to_temp_dir, 'training_multisearch_result.csv'), sep=',', header=0)
multisearch_result = multisearch_result.drop_duplicates().reset_index(drop=True)
multisearch_result = multisearch_result.query('query_name != match_name').reset_index(drop=True)

# because the multisearch result is not symmetric, that is,
# it may contain the pair (A, B, score) but not (B, A, score),
# we need to make it symmetric
A_TO_B = multisearch_result[['query_name','match_name']].drop_duplicates().reset_index(drop=True)
B_TO_A = A_TO_B[['match_name','query_name']].rename(columns={'match_name':'query_name','query_name':'match_name'})
multisearch_result = pd.concat([A_TO_B, B_TO_A]).drop_duplicates().reset_index(drop=True)

for query_name, match_name in tqdm(multisearch_result[['query_name', 'match_name']].to_numpy()):
if str(query_name) not in results:
results[str(query_name)] = [str(match_name)]
else:
results[str(query_name)].append(str(match_name))

return results
return multisearch_result
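# A minimal sketch of the symmetrization above on toy data (hypothetical
# names, not part of this commit): the raw multisearch output may contain
# the pair (A, B) but not (B, A); concatenating a column-swapped copy and
# dropping duplicates makes the relation symmetric.
import pandas as pd
raw = pd.DataFrame({'query_name': ['A', 'B'], 'match_name': ['B', 'C']})
swapped = raw.rename(columns={'query_name': 'match_name', 'match_name': 'query_name'})
symmetric = pd.concat([raw, swapped]).drop_duplicates().reset_index(drop=True)
print(symmetric)  # rows: (A,B), (B,C), (B,A), (C,B)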

def remove_corr_organisms_from_ref(sig_info_dict: Dict[str, Tuple[str, float, int, int]], sig_same_genoms_dict: Dict[str, List[str]]) -> Tuple[Dict[str, List[str]], pd.DataFrame]:
def remove_corr_organisms_from_ref(sig_info_dict: Dict[str, Tuple[str, float, int, int]], multisearch_result: pd.DataFrame, ani_thresh: float, ksize: int) -> Tuple[Dict[str, List[str]], pd.DataFrame]:
"""
Helper function that removes closely related organisms from the reference matrix.
:param sig_info_dict: a dictionary mapping all signature name from reference data to a tuple (md5sum, minhash mean abundance, minhash hashes length, minhash scaled)
:param sig_same_genoms_dict: a dictionary mapping signature name to a list of its close related genomes (ANI > ani_thresh)
:param multisearch_result: a dataframe with the symmetric pairwise multisearch results (query_name, match_name)
:param ani_thresh: float (ANI threshold above which two organisms are considered closely related)
:param ksize: int (size of kmer)
:return
rep_remove_dict: a dictionary with key as representative signature name and value as a list of signatures to be removed
remove_corr_df: a dataframe with two columns: the removed organism name and its closely related organisms
manifest_df: a dataframe containing the processed reference signature information
"""
# for each genome with close related genomes, pick the one with largest number of unique kmers
rep_remove_dict = {}
# extract the organisms that have closely related organisms, along with their numbers of unique kmers
corr_organisms = [query_name for query_name in multisearch_result['query_name'].unique()]
sizes = np.array([sig_info_dict[organism][2] for organism in corr_organisms])
# sort the organisms by size in ascending order so that the largest of each group of closely related organisms is kept and the smaller ones are discarded
bysize = np.argsort(sizes)
corr_organisms_bysize = np.array(corr_organisms)[bysize].tolist()

# remove organisms in ascending size order until all remaining genomes are distinct (i.e., ANI <= ani_thresh)
temp_remove_set = set()
manifest_df = []
for genome, same_genomes in tqdm(sig_same_genoms_dict.items()):
# skip if the genome has been removed
if genome in temp_remove_set:
continue
# keep same genome if it is not in the remove set
same_genomes = list(set(same_genomes).difference(temp_remove_set))
# get the number of unique kmers for each genome
unique_kmers = np.array([sig_info_dict[genome][2]] + [sig_info_dict[same_genome][2] for same_genome in same_genomes])
# get the index of the genome with largest number of unique kmers
rep_idx = np.argmax(unique_kmers)
# get the representative genome
rep_genome = genome if rep_idx == 0 else same_genomes[rep_idx-1]
# get the list of genomes to be removed
remove_genomes = same_genomes if rep_idx == 0 else [genome] + same_genomes[:rep_idx-1] + same_genomes[rep_idx:]
# update remove set
temp_remove_set.update(remove_genomes)
if len(remove_genomes) > 0:
rep_remove_dict[rep_genome] = remove_genomes
# loop through the organisms by size in ascending order
for organism in tqdm(corr_organisms_bysize, desc='Removing close related organisms'):
## for a given organism, check which of its closely related organisms are left after excluding those already in the remove set
## if any are left, put this organism in the remove set (a larger relative is still present)
left_corr_orgs = set(multisearch_result.query(f'query_name == "{organism}"')['match_name']).difference(temp_remove_set)
if len(left_corr_orgs) > 0:
temp_remove_set.add(organism)

# generate a dataframe with two columns: the removed organism name and its closely related organisms
logger.info('Generating a dataframe mapping each removed organism to its closely related organisms.')
remove_corr_list = [(organism, ','.join(list(set(multisearch_result.query(f'query_name == "{organism}"')['match_name'])))) for organism in tqdm(temp_remove_set)]
remove_corr_df = pd.DataFrame(remove_corr_list, columns=['removed_org', 'corr_orgs'])

# remove the closely related organisms from the reference genome list
manifest_df = []
for sig_name, (md5sum, minhash_mean_abundance, minhash_hashes_len, minhash_scaled) in tqdm(sig_info_dict.items()):
if sig_name not in temp_remove_set:
manifest_df.append((sig_name, md5sum, minhash_hashes_len, get_num_kmers(minhash_mean_abundance, minhash_hashes_len, minhash_scaled, False), minhash_scaled))
manifest_df = pd.DataFrame(manifest_df, columns=['organism_name', 'md5sum', 'num_unique_kmers_in_genome_sketch', 'num_total_kmers_in_genome_sketch', 'genome_scale_factor'])

return rep_remove_dict, manifest_df

# def compute_sample_vector(sample_hashes, hash_to_idx):
# """
# Helper function that computes the sample vector for a given sample signature.
# :param sample_hashes: hashes in the sample signature
# :param hash_to_idx: dictionary mapping hashes to indices in the training dictionary
# :return: numpy array (sample vector)
# """
# # total number of hashes in the training dictionary
# hash_to_idx_keys = set(hash_to_idx.keys())

# # total number of hashes in the sample
# sample_hashes_keys = set(sample_hashes.keys())

# # initialize the sample vector
# sample_vector = np.zeros(len(hash_to_idx_keys))

# # get the hashes that are in both the sample and the training dictionary
# sample_intersect_training_hashes = hash_to_idx_keys.intersection(sample_hashes_keys)

# # fill in the sample vector
# for sh in tqdm(sample_intersect_training_hashes):
# sample_vector[hash_to_idx[sh]] = sample_hashes[sh]

# return sample_vector

return remove_corr_df, manifest_df
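# Toy walk-through of the greedy filter above (hypothetical organisms and
# unique-kmer counts, not part of this commit). Scanning in ascending size
# order, an organism is removed whenever a closely related organism is still
# present, so exactly the largest member of each related group survives.
related = {'A': {'B', 'C'}, 'B': {'A', 'C'}, 'C': {'A', 'B'}, 'D': {'E'}, 'E': {'D'}}
sizes = {'A': 10, 'B': 50, 'C': 30, 'D': 5, 'E': 7}
removed = set()
for org in sorted(related, key=sizes.get):  # visit order: D, E, A, C, B
    if related[org] - removed:              # a (larger) relative remains
        removed.add(org)
print(sorted(removed))  # ['A', 'C', 'D'] -> B and E are kept as representatives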

class Prediction:
"""
