Skip to content

Commit

Permalink
Merge pull request #42 from DHekstra/main
Browse files Browse the repository at this point in the history
made stats functions consistent with BaseParser
  • Loading branch information
DHekstra authored Jul 25, 2023
2 parents 08eb192 + a9b4d82 commit 162a4ce
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 172 deletions.
79 changes: 46 additions & 33 deletions rsbooster/stats/ccanom.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,27 @@
import seaborn as sns


def parse_arguments():
    """Build the command-line parser for this script.

    Returns
    -------
    argparse.ArgumentParser
        The configured but *unparsed* parser; call ``.parse_args()`` on the
        result to obtain the argument namespace.
    """
    # The module docstring is reused verbatim as the help description, hence
    # the raw-text formatter.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter, description=__doc__
    )

    # Required arguments
    parser.add_argument(
        "mtz",
        nargs="+",
        help="MTZs containing crossvalidation data from careless",
    )

    # Optional arguments
    parser.add_argument(
        "-m",
        "--method",
        default="spearman",
        choices=["spearman", "pearson"],
        help="Method for computing correlation coefficient (spearman or pearson)",
    )

    return parser
from rsbooster.stats.parser import BaseParser
class ArgumentParser(BaseParser):
    """Command-line parser for this script, extending ``BaseParser``.

    Registers the positional ``mtz`` inputs and the ``-m/--method`` option in
    addition to whatever ``BaseParser`` provides (it lives in
    ``rsbooster.stats.parser`` and is not visible here).
    """

    def __init__(self):
        # The module docstring is handed to the base class as the description.
        super().__init__(description=__doc__)

        # Required arguments
        mtz_spec = {
            "nargs": "+",
            "help": "MTZs containing crossvalidation data from careless",
        }
        self.add_argument("mtz", **mtz_spec)

        # Choice of correlation coefficient
        method_spec = {
            "default": "spearman",
            "choices": ["spearman", "pearson"],
            "help": "Method for computing correlation coefficient (spearman or pearson)",
        }
        self.add_argument("-m", "--method", **method_spec)


def make_halves_ccanom(mtz, bins=10):
Expand All @@ -51,8 +51,11 @@ def make_halves_ccanom(mtz, bins=10):
def analyze_ccanom_mtz(mtzpath, bins=10, return_labels=True, method="spearman"):
"""Compute CCsym from 2-fold cross-validation"""

mtz = rs.read_mtz(mtzpath)

if type(mtzpath) is rs.dataset.DataSet:
mtz=mtzpath
else:
mtz = rs.read_mtz(mtzpath)

# Error handling -- make sure MTZ file is appropriate
if "half" not in mtz.columns:
raise ValueError("Please provide MTZs from careless crossvalidation")
Expand All @@ -75,11 +78,7 @@ def analyze_ccanom_mtz(mtzpath, bins=10, return_labels=True, method="spearman"):
return result


def main():

# Parse commandline arguments
args = parse_arguments().parse_args()

def run_analysis(args):
results = []
labels = None
for m in args.mtz:
Expand All @@ -96,18 +95,32 @@ def main():
results["CCanom"] = results[("DF1", "DF2")]
results.drop(columns=[("DF1", "DF2")], inplace=True)

print(results)
for k in ('bin', 'repeat'):
results[k] = results[k].to_numpy('int32')

if args.output is not None:
results.to_csv(args.output)
else:
print(results.to_string())

# print(results.info())

sns.lineplot(
data=results, x="bin", y="CCanom", hue="filename", ci="sd", palette="viridis"
data=results, x="bin", y="CCanom", hue="filename", errorbar="sd", palette="viridis"
)
plt.xticks(range(10), labels, rotation=45, ha="right", rotation_mode="anchor")
plt.ylabel(r"$CC_{anom}$ " + f"({args.method})")
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.grid()
plt.tight_layout()
plt.show()
if args.image is not None:
plt.savefig(args.image)

if args.show:
plt.show()

def parse_arguments():
    """Return a new, unparsed ``ArgumentParser`` for this script."""
    parser = ArgumentParser()
    return parser

if __name__ == "__main__":
main()
def main():
    """Entry point: parse the command line, then run the analysis."""
    cli_parser = parse_arguments()
    parsed_args = cli_parser.parse_args()
    run_analysis(parsed_args)
81 changes: 47 additions & 34 deletions rsbooster/stats/cchalf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,27 @@
import seaborn as sns


def parse_arguments():
    """Build the command-line parser for this script.

    Returns
    -------
    argparse.ArgumentParser
        The configured but *unparsed* parser; call ``.parse_args()`` on the
        result to obtain the argument namespace.
    """
    # The module docstring is reused verbatim as the help description, hence
    # the raw-text formatter.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter, description=__doc__
    )

    # Required arguments
    parser.add_argument(
        "mtz",
        nargs="+",
        help="MTZs containing crossvalidation data from careless",
    )

    # Optional arguments
    parser.add_argument(
        "-m",
        "--method",
        default="spearman",
        choices=["spearman", "pearson"],
        help="Method for computing correlation coefficient (spearman or pearson)",
    )

    return parser
from rsbooster.stats.parser import BaseParser
class ArgumentParser(BaseParser):
    """Command-line parser for this script, extending ``BaseParser``.

    Registers the positional ``mtz`` inputs and the ``-m/--method`` option in
    addition to whatever ``BaseParser`` provides (it lives in
    ``rsbooster.stats.parser`` and is not visible here).
    """

    def __init__(self):
        # The module docstring is handed to the base class as the description.
        super().__init__(description=__doc__)

        # Required arguments
        mtz_help = "MTZs containing crossvalidation data from careless"
        self.add_argument("mtz", nargs="+", help=mtz_help)

        # Choice of correlation coefficient
        method_help = (
            "Method for computing correlation coefficient (spearman or pearson)"
        )
        self.add_argument(
            "-m",
            "--method",
            default="spearman",
            choices=["spearman", "pearson"],
            help=method_help,
        )


def make_halves_cchalf(mtz, bins=10):
Expand All @@ -54,8 +53,11 @@ def make_halves_cchalf(mtz, bins=10):
def analyze_cchalf_mtz(mtzpath, bins=10, return_labels=True, method="spearman"):
"""Compute CChalf from 2-fold cross-validation"""

mtz = rs.read_mtz(mtzpath)

if type(mtzpath) is rs.dataset.DataSet:
mtz=mtzpath
else:
mtz = rs.read_mtz(mtzpath)

# Error handling -- make sure MTZ file is appropriate
if "half" not in mtz.columns:
raise ValueError("Please provide MTZs from careless crossvalidation")
Expand All @@ -73,11 +75,7 @@ def analyze_cchalf_mtz(mtzpath, bins=10, return_labels=True, method="spearman"):
return result


def main():

# Parse commandline arguments
args = parse_arguments().parse_args()

def run_analysis(args):
results = []
labels = None
for m in args.mtz:
Expand All @@ -94,18 +92,33 @@ def main():
results["CChalf"] = results[("F1", "F2")]
results.drop(columns=[("F1", "F2")], inplace=True)

print(results)
for k in ('bin', 'repeat'):
results[k] = results[k].to_numpy('int32')

if args.output is not None:
results.to_csv(args.output)
else:
print(results.to_string())

print(results.info())

sns.lineplot(
data=results, x="bin", y="CChalf", hue="filename", ci="sd", palette="viridis"
data=results, x="bin", y="CChalf", hue="filename", errorbar="sd", palette="viridis"
)
plt.xticks(range(10), labels, rotation=45, ha="right", rotation_mode="anchor")
plt.ylabel(r"$CC_{1/2}$ " + f"({args.method})")
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.grid()
plt.tight_layout()
plt.show()
if args.image is not None:
plt.savefig(args.image)

if args.show:
plt.show()

if __name__ == "__main__":
main()

def parse_arguments():
    """Return a new, unparsed ``ArgumentParser`` for this script."""
    parser = ArgumentParser()
    return parser

def main():
    """Entry point: parse the command line, then run the analysis."""
    cli_parser = parse_arguments()
    parsed_args = cli_parser.parse_args()
    run_analysis(parsed_args)
110 changes: 61 additions & 49 deletions rsbooster/stats/ccpred.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,51 @@
import seaborn as sns


def parse_arguments():
"""Parse commandline arguments"""
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter, description=__doc__
)

# Required arguments
parser.add_argument(
"-i",
"--inputmtzs",
nargs="+",
action="append",
required=True,
help="MTZs containing holdout prediction data from careless",
)

# Optional arguments
parser.add_argument(
"-m",
"--method",
default="spearman",
choices=["spearman", "pearson"],
help=("Method for computing correlation coefficient (spearman or pearson)"),
)
parser.add_argument(
"--mod2",
action="store_true",
help=("Use (id mod 2) to assign delays (use when employing spacegroup hack)"),
)
from rsbooster.stats.parser import BaseParser
class ArgumentParser(BaseParser):
def __init__(self):
super().__init__(
description=__doc__
)

return parser#.parse_args()
# Required arguments
self.add_argument(
"mtzs",
nargs="+",
help="MTZs containing prediction data from careless",
)

# Optional arguments
self.add_argument(
"-m",
"--method",
default="spearman",
choices=["spearman", "pearson"],
help=("Method for computing correlation coefficient (spearman or pearson)"),
)
self.add_argument(
"--mod2",
action="store_true",
help=("Use (id mod 2) to assign delays (use when employing spacegroup hack)"),
)
self.add_argument(
"--overall",
action="store_true",
default=False,
help=("Whether to report a single value for the entire dataset"),
)


def compute_ccpred(
mtzpath, overall=False, bins=10, return_labels=True, method="spearman", mod2=False
):
"""Compute CCsym from 2-fold cross-validation"""

mtz = rs.read_mtz(mtzpath)

if type(mtzpath) is rs.dataset.DataSet:
mtz=mtzpath
else:
mtz = rs.read_mtz(mtzpath)

if overall:
grouper = mtz.groupby(["test"])[["Iobs", "Ipred"]]
else:
Expand All @@ -79,23 +84,14 @@ def compute_ccpred(
return result


def main():

# Parse commandline arguments
args = parse_arguments().parse_args()
def run_analysis(args):

results = []
labels = None

if isinstance(args.inputmtzs[0], list) and len(args.inputmtzs) > 1:
overall = True
mtzs = [item for sublist in args.inputmtzs for item in sublist]
else:
overall = False
mtzs = [item for sublist in args.inputmtzs for item in sublist]

for m in mtzs:
result = compute_ccpred(m, overall=overall, method=args.method, mod2=args.mod2)
# mtzs -> args.mtzs!
for m in args.mtzs:
result = compute_ccpred(m, overall=args.overall, method=args.method, mod2=args.mod2)
if isinstance(result, tuple):
results.append(result[0])
labels = result[1]
Expand All @@ -107,8 +103,17 @@ def main():
results["CCpred"] = results[("Iobs", "Ipred")]
results.drop(columns=[("Iobs", "Ipred")], inplace=True)

print(results)
if overall:
# print(results.info())
for k in ('bin', 'test'):
results[k] = results[k].to_numpy('int32')

if args.output is not None:
results.to_csv(args.output)
else:
print(results.to_string())

print(results.info())
if args.overall:
g = sns.relplot(
data=results,
x="id",
Expand Down Expand Up @@ -140,8 +145,15 @@ def main():
ax.set_xticks(range(10))
ax.set_xticklabels(labels, rotation=45, ha="right", rotation_mode="anchor")
ax.grid(True)

if args.image is not None:
plt.savefig(args.image)

if args.show:
plt.show()

def parse_arguments():
    """Return a new, unparsed ``ArgumentParser`` for this script."""
    parser = ArgumentParser()
    return parser

if __name__ == "__main__":
main()
def main():
    """Entry point: parse the command line, then run the analysis."""
    cli_parser = parse_arguments()
    parsed_args = cli_parser.parse_args()
    run_analysis(parsed_args)
Loading

0 comments on commit 162a4ce

Please sign in to comment.