Skip to content

Commit

Permalink
update polypred to take a weighted sum of PRS instead of a weighted s…
Browse files Browse the repository at this point in the history
…um of betas, which loses accuracy for some reason we don't understand involving plink
  • Loading branch information
omerwe committed May 29, 2024
1 parent 7f1493d commit 357add2
Showing 1 changed file with 172 additions and 103 deletions.
275 changes: 172 additions & 103 deletions polypred.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def compute_prs_for_file(args,
plink_cmd += ' --bfile %s --score %s sum'%(plink_file_prefix, betas_file)
else:
raise ValueError('neither --bed nor --pgen specified')
if args.center:
plink_cmd += ' center'
if ranges_file is not None:
scores_file = os.path.join(temp_dir, next(tempfile._get_candidate_names()))
df_betas[['SNP_bim', 'score']].drop_duplicates('SNP_bim').to_csv(scores_file, sep='\t', header=False, index=False)
Expand Down Expand Up @@ -175,6 +177,7 @@ def load_betas_files(betas_file, verbose=True):

#rename columns if needed
df_betas.rename(columns={'sid':'SNP', 'nt1':'A1', 'nt2':'A2', 'BETA_MEAN':'BETA', 'ldpred_inf_beta':'BETA', 'chrom':'CHR', 'Chrom':'CHR', 'pos':'BP'}, inplace=True, errors='ignore')

if not is_numeric_dtype(df_betas['CHR']):
if df_betas['CHR'].str.startswith('chrom_').all():
df_betas['CHR'] = df_betas['CHR'].str[6:].astype(np.int64)
Expand Down Expand Up @@ -263,124 +266,191 @@ def computs_prs_all_files(args, betas_file, disable_jackknife=False, keep_file=N



def estimate_mixing_weights(args):
def compute_prs(args):

#if we need to perform predictions, make sure the mixweights file is found
if args.predict and args.betas.count(',') > 0:
mixweights_file = args.mixweights_prefix +'.mixweights'
if not os.path.exists(mixweights_file):
raise ValueError('mixweights file %s not found'%(mixweights_file))

#read phenotypes
df_pheno = pd.read_csv(args.pheno, names=['FID', 'IID', 'PHENO'], index_col='IID', delim_whitespace=True)

#make sure that we didn't include a header line
try:
float(df_pheno['PHENO'].iloc[0])
except:
df_pheno = df_pheno.iloc[1:]
df_pheno['PHENO'] = df_pheno['PHENO'].astype(np.float64)
if np.any(df_pheno.index.duplicated()):
raise ValueError('duplicate ids found in %s'%(args.pheno))

#compute a PRS for each beta file
beta_files = args.betas.split(',')
df_prs_sum_list = []
df_prs_list = []
for betas_file in beta_files:
df_prs_sum = computs_prs_all_files(args, betas_file, disable_jackknife=True, keep_file=args.pheno)
df_prs_sum_list.append(df_prs_sum[['SCORESUM']])
for df_prs_sum in df_prs_sum_list:
assert np.all(df_prs_sum.index == df_prs_sum_list[0].index)
df_prs_sum_all = pd.concat(df_prs_sum_list, axis=1)

#sync df_pheno and df_prs_sum_all
df_prs_sum_all.index = df_prs_sum_all.index.astype(str)
df_pheno.index = df_pheno.index.astype(str)
index_shared = df_prs_sum_all.index.intersection(df_pheno.index)
assert len(index_shared)>0
if len(index_shared) < df_prs_sum_all.shape[0]:
df_prs_sum_all = df_prs_sum_all.loc[index_shared]
if df_pheno.shape[0] != df_prs_sum_all.shape[0] or np.any(df_prs_sum_all.index != df_pheno.index):
df_pheno = df_pheno.loc[df_prs_sum_all.index]
df_prs = computs_prs_all_files(args, betas_file, disable_jackknife=not args.predict, keep_file=args.pheno)
df_prs_list.append(df_prs)
for df_prs in df_prs_list:
assert np.all(df_prs.index == df_prs_list[0].index)
df_prs_all = pd.concat(df_prs_list, axis=1)


#compute mixing weights if needed
if args.estimate_mixweights:

#read phenotypes
df_pheno = pd.read_csv(args.pheno, names=['FID', 'IID', 'PHENO'], index_col='IID', delim_whitespace=True)

#make sure that we didn't include a header line
try:
float(df_pheno['PHENO'].iloc[0])
except:
df_pheno = df_pheno.iloc[1:]
df_pheno['PHENO'] = df_pheno['PHENO'].astype(np.float64)
if np.any(df_pheno.index.duplicated()):
raise ValueError('duplicate ids found in %s'%(args.pheno))

#sync df_pheno and df_prs_all
df_prs_all.index = df_prs_all.index.astype(str)
df_pheno.index = df_pheno.index.astype(str)
index_shared = df_prs_all.index.intersection(df_pheno.index)
assert len(index_shared)>0
if len(index_shared) < df_prs_all.shape[0]:
df_prs_all = df_prs_all.loc[index_shared]
if df_pheno.shape[0] != df_prs_all.shape[0] or np.any(df_prs_all.index != df_pheno.index):
df_pheno = df_pheno.loc[df_prs_all.index]

#extract just the SCORESUM columns
df_prs_sum_all = df_prs_all['SCORESUM'].copy()

#flip PRS that are negatively correlated with the phenotype
is_flipped = np.zeros(df_prs_sum_all.shape[1], dtype=bool)
linreg_univariate = LinearRegression()
for c_i in range(df_prs_sum_all.shape[1]):
linreg_univariate.fit(df_prs_sum_all.iloc[:, [c_i]], df_pheno['PHENO'])
is_flipped[c_i] = linreg_univariate.coef_[0] < 0
df_prs_sum_all.loc[:, is_flipped] *= -1

#compute mixing weights
linreg = LinearRegression(positive = not args.allow_neg_mixweights)
linreg.fit(df_prs_sum_all, df_pheno['PHENO'])
mix_weights, intercept = linreg.coef_, linreg.intercept_
r2_score = metrics.r2_score(df_pheno['PHENO'], linreg.predict(df_prs_sum_all))
logging.info('In-sample R2: %0.3f'%(r2_score))

#create and print df_coef, and save it to disk
df_coef = pd.Series(mix_weights, index=beta_files)
df_coef.loc['intercept'] = intercept
mix_weights_file = args.output_prefix+'.mixweights'
df_coef.to_frame(name='mix_weight').to_csv(mix_weights_file, sep='\t')
logging.info('Writing mixing weights to %s'%(mix_weights_file))

#compute weighted betas
df_betas_weighted = None
for is_flipped_beta, betas_file, mix_weight in zip(is_flipped, beta_files, mix_weights):
df_betas = load_betas_files(betas_file)
df_betas = df_betas[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']]
df_betas['BETA'] *= mix_weight
if is_flipped_beta: df_betas['BETA'] = -df_betas['BETA']
if df_betas_weighted is None:
df_betas_weighted = df_betas
continue
#flip PRS that are negatively correlated with the phenotype
is_flipped = np.zeros(df_prs_sum_all.shape[1], dtype=bool)
linreg_univariate = LinearRegression()
for c_i in range(df_prs_sum_all.shape[1]):
linreg_univariate.fit(df_prs_sum_all.iloc[:, [c_i]], df_pheno['PHENO'])
is_flipped[c_i] = linreg_univariate.coef_[0] < 0
df_prs_sum_all.loc[:, is_flipped] *= -1

#estimate mixing weights
linreg = LinearRegression(positive = not args.allow_neg_mixweights)
linreg.fit(df_prs_sum_all, df_pheno['PHENO'])
mix_weights, intercept = linreg.coef_, linreg.intercept_
r2_score = metrics.r2_score(df_pheno['PHENO'], linreg.predict(df_prs_sum_all))
logging.info('In-sample R2: %0.3f'%(r2_score))

index_shared = df_betas.index.intersection(df_betas_weighted.index)
df_betas['BETA2'] = df_betas['BETA']
df_new = df_betas_weighted.loc[index_shared].merge(df_betas.loc[index_shared, ['BETA2']], left_index=True, right_index=True)
df_new['BETA'] += df_new['BETA2']
del df_new['BETA2']
del df_betas['BETA2']
df_list = [df_new, df_betas.loc[~df_betas.index.isin(index_shared)], df_betas_weighted.loc[~df_betas_weighted.index.isin(index_shared)]]
df_betas_weighted = pd.concat(df_list, axis=0)
df_betas_weighted.sort_values(['CHR', 'BP', 'A1'], inplace=True)

#save output to file
df_betas_weighted[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']].to_csv(args.output_prefix+'.betas', sep='\t', index=False, float_format='%0.6e')
logging.info('Saving weighted betas to %s'%(args.output_prefix+'.betas'))
#create and print df_coef, and save it to disk
df_coef = pd.Series(mix_weights, index=beta_files)
df_coef.loc[is_flipped] *= -1
df_coef.loc['intercept'] = intercept
mix_weights_file = args.output_prefix+'.mixweights'
df_coef.to_frame(name='mix_weight').to_csv(mix_weights_file, sep='\t')
logging.info('Writing mixing weights to %s'%(mix_weights_file))

#flip the PRS back
df_prs_sum_all.loc[:, is_flipped] *= -1


def compute_prs(args):
#perform predictions
if args.predict:

#extract just the SCORESUM columns
df_prs_sum_all = df_prs_all['SCORESUM']

#just take the PRS if there's only a single beta
if args.betas.count(',') == 0:
assert (df_prs_all.columns=='SCORESUM').sum() == 1
s_combined_prs = df_prs_sum_all

#if there's more than one beta, take the linear combination
else:
mixweights_file = args.mixweights_prefix +'.mixweights'
s_mixweights = pd.read_csv(mixweights_file, delim_whitespace=True, squeeze=True)
if np.any(s_mixweights.index[:-1] != args.betas.split(',')):
raise ValueError('The provided betas file do not match the mix weights file')
assert s_mixweights.index[-1] == 'intercept'
s_combined_prs = df_prs_sum_all.dot(s_mixweights.iloc[:-1].values) + s_mixweights.loc['intercept']

#save the PRS to disk
df_prs_sum = s_combined_prs.reset_index(drop=False)
df_prs_sum.columns = ['IID', 'PRS']
df_prs_sum['FID'] = df_prs_sum['IID']
df_prs_sum = df_prs_sum[['FID', 'IID', 'PRS']]
df_prs_sum.to_csv(args.output_prefix+'.prs', sep='\t', index=False, float_format='%0.5f')

#handle jackknife
set_jk_columns = set([c for c in df_prs_all.columns if '.jk' in c])
df_prs_sum_jk = pd.DataFrame(index=df_prs_all.index, columns=set_jk_columns)
if df_prs_sum_jk.shape[1] > 1:
for jk_column in set_jk_columns:
if args.betas.count(',') == 0:
assert (df_prs_all.columns==jk_column).sum() == 1
df_prs_sum_jk[jk_column] = df_prs_all[jk_column]
else:
#import ipdb; ipdb.set_trace()
df_prs_sum_jk[jk_column] = df_prs_all[jk_column].dot(s_mixweights.iloc[:-1].values) + s_mixweights.loc['intercept']

df_prs_sum_jk.reset_index().to_csv(args.output_prefix+'.prs_jk', sep='\t', index=False, float_format='%0.5f')

logging.info('Saving PRS to %s'%(args.output_prefix+'.prs'))




# #compute weighted betas
# df_betas_weighted = None
# for is_flipped_beta, betas_file, mix_weight in zip(is_flipped, beta_files, mix_weights):
# df_betas = load_betas_files(betas_file)
# df_betas = df_betas[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']]
# df_betas['BETA'] *= mix_weight
# if is_flipped_beta: df_betas['BETA'] = -df_betas['BETA']
# if df_betas_weighted is None:
# df_betas_weighted = df_betas
# continue

# index_shared = df_betas.index.intersection(df_betas_weighted.index)
# df_betas['BETA2'] = df_betas['BETA']
# df_new = df_betas_weighted.loc[index_shared].merge(df_betas.loc[index_shared, ['BETA2']], left_index=True, right_index=True)
# df_new['BETA'] += df_new['BETA2']
# del df_new['BETA2']
# del df_betas['BETA2']
# df_list = [df_new, df_betas.loc[~df_betas.index.isin(index_shared)], df_betas_weighted.loc[~df_betas_weighted.index.isin(index_shared)]]
# df_betas_weighted = pd.concat(df_list, axis=0)
# df_betas_weighted.sort_values(['CHR', 'BP', 'A1'], inplace=True)

# #save weighted betas to file
# df_betas_weighted[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']].to_csv(args.output_prefix+'.betas', sep='\t', index=False, float_format='%0.6e')
# logging.info('Saving weighted betas to %s'%(args.output_prefix+'.betas'))

if args.betas.count(',') > 0:
raise ValueError('--predict can only be used with a single betas file')
df_prs_sum = computs_prs_all_files(args, args.betas, disable_jackknife=False, keep_file=args.keep)
df_prs_sum.reset_index(inplace=True, drop=False)
df_prs_sum.columns = df_prs_sum.columns.str.replace('SCORESUM', 'PRS')
df_prs_sum_main = df_prs_sum[['FID', 'IID', 'PRS']]
df_prs_sum_jk = df_prs_sum[['FID', 'IID'] + [c for c in df_prs_sum.columns if c.startswith('PRS.')]]

df_prs_sum_main.to_csv(args.output_prefix+'.prs', sep='\t', index=False, float_format='%0.5f')
if df_prs_sum_jk.shape[1]>1:
df_prs_sum_jk.to_csv(args.output_prefix+'.prs_jk', sep='\t', index=False, float_format='%0.5f')
logging.info('Saving PRS to %s'%(args.output_prefix+'.prs'))

# def compute_prs(args):

# if args.betas.count(',') > 0:
# raise ValueError('--predict can only be used with a single betas file')
# df_prs_sum = computs_prs_all_files(args, args.betas, disable_jackknife=False, keep_file=args.keep)
# df_prs_sum.reset_index(inplace=True, drop=False)
# df_prs_sum.columns = df_prs_sum.columns.str.replace('SCORESUM', 'PRS')
# df_prs_sum_main = df_prs_sum[['FID', 'IID', 'PRS']]
# df_prs_sum_jk = df_prs_sum[['FID', 'IID'] + [c for c in df_prs_sum.columns if c.startswith('PRS.')]]

# df_prs_sum_main.to_csv(args.output_prefix+'.prs', sep='\t', index=False, float_format='%0.5f')
# if df_prs_sum_jk.shape[1]>1:
# df_prs_sum_jk.to_csv(args.output_prefix+'.prs_jk', sep='\t', index=False, float_format='%0.5f')
# logging.info('Saving PRS to %s'%(args.output_prefix+'.prs'))


def check_args(args):
if int(args.predict) + int(args.combine_betas) != 1:
raise ValueError('you must specify either --predict or --combine-betas (but not both)')
if int(args.predict) + int(args.estimate_mixweights) != 1:
raise ValueError('you must specify either --predict or --estimate-mixweights (but not both)')
if args.plink_exe is None and args.plink2_exe is None:
raise ValueError('you must specify either --plink-exe or --plink2-exe')
if args.plink_exe is not None and not os.path.exists(args.plink_exe):
raise ValueError('%s not found'%(args.plink_exe))
if args.plink2_exe is not None and not os.path.exists(args.plink2_exe):
raise ValueError('%s not found'%(args.plink2_exe))
if args.combine_betas:
if args.estimate_mixweights:
if args.keep is not None:
raise ValueError('you cannot provide both --combine-betas and --keep')
raise ValueError('you cannot provide both --estimate-mixweights and --keep')
if args.pheno is None:
raise ValueError('you must provide --pheno if you specify --combine-betas')
raise ValueError('you must provide --pheno if you specify --estimate-mixweights')
if args.betas.count(',')==0:
raise ValueError('you must provide multiple files in --betas if you specify --combine-betas')
raise ValueError('you must provide multiple files in --betas if you specify --estimate-mixweights')
if args.predict:
if args.mixweights_prefix is None and args.betas.count(',') > 0:
raise ValueError('you must provide --mixweights-prefix together with --predict if you have more than one beta file')
if args.num_jk<0:
raise ValueError('--num-jk must be >=0')
if args.pheno is not None and args.predict:
raise ValueError('--pheno can only be used with --combine-betas')
raise ValueError('--pheno can only be used with --estimate-mixweights')

if len(list(args.files)) == 0:
raise ValueError('no input files specified')
Expand All @@ -391,9 +461,10 @@ def check_args(args):
parser = argparse.ArgumentParser()

parser.add_argument('--betas', required=True, help='files with SNP effect sizes (comma separated). A1 is the effect allele.')
parser.add_argument('--mixweights-prefix', help='Prefix of files with mixing weights (required if you use --predict with more than one betas file')
parser.add_argument('--output-prefix', required=True, help='Prefix of output file')

parser.add_argument('--combine-betas', default=False, action='store_true', help='If specified, PolyPred will estimate mixing weights')
parser.add_argument('--estimate-mixweights', default=False, action='store_true', help='If specified, PolyPred will estimate mixing weights')
parser.add_argument('--allow-neg-mixweights', default=False, action='store_true', help='If specified, PolyPred will not enforce non-negative mixing weights')
parser.add_argument('--predict', default=False, action='store_true', help='If specified, PolyPred will compute PRS')
parser.add_argument('--pheno', default=None, help='Phenotype file (required for estimating mixing weights)')
Expand All @@ -403,6 +474,7 @@ def check_args(args):
parser.add_argument('--extract', default=None, help='A text file with rsids of SNPs to use (one per line)')
parser.add_argument('--keep', default=None, help='A text file with ids of individuals to use (two columns per line, each containing FID,IID)')
parser.add_argument('--num-jk', type=int, default=200, help='number of genomic jackknife blocks')
parser.add_argument('--center', default=False, action='store_true', help='If specified, the PRS will be centered')

parser.add_argument('--memory', type=int, default=2, help='Maximum memory usage (in GB)')
parser.add_argument('--threads', type=int, default=1, help='Number of CPU threads')
Expand All @@ -420,21 +492,18 @@ def check_args(args):

#check that the output directory exists
if len(os.path.dirname(args.output_prefix))>0 and not os.path.exists(os.path.dirname(args.output_prefix)):
raise ValueError('output directory %s doesn\'t exist'%(os.path.dirname(args.output_prefix)))
raise ValueError('output directory %s doesn\'t exist'%(os.path.dirname(args.output_prefix)))



#configure logger
configure_logger(args.output_prefix)

#check arguments
check_args(args)

#estimate mixing weights if needed
if args.combine_betas:
estimate_mixing_weights(args)

#compute PRS if needed
if args.predict:
compute_prs(args)
#Estimate mixiwing weights and/or compute PRS
compute_prs(args)

print()

Expand Down

0 comments on commit 357add2

Please sign in to comment.