[WIP] Domain adaptation #108

Open
wants to merge 170 commits into base: development

Commits (170, showing changes from all commits)
20d6375
Update .gitignore
HendrikSchmidt Jan 6, 2023
11b8f6f
Merge branch 'development' into domain-adaptation
HendrikSchmidt Jan 6, 2023
8bbfff4
introduce target split
HendrikSchmidt Jan 6, 2023
e83317f
Merge branch 'hotfix-train-val' into domain-adaptation
HendrikSchmidt Jan 6, 2023
d62d524
refactor fold size
HendrikSchmidt Jan 7, 2023
be991ff
Update train.py
HendrikSchmidt Jan 7, 2023
dcf1a6d
make cpu and gin flag general, rename gin flag
HendrikSchmidt Jan 7, 2023
ef4e687
Update preprocess.py
HendrikSchmidt Jan 7, 2023
63c06e2
add basis for domain adaptation
HendrikSchmidt Jan 7, 2023
4f6006b
Merge branch 'hotfix-train-val' into domain-adaptation
HendrikSchmidt Jan 10, 2023
83782b5
refactor folds for targets
HendrikSchmidt Jan 10, 2023
997819c
Merge branch 'development' into domain-adaptation
HendrikSchmidt Jan 11, 2023
f31d9ac
Merge branch 'development' into domain-adaptation
HendrikSchmidt Jan 11, 2023
43ebd5f
add evaluation function to test whole dataset
HendrikSchmidt Jan 11, 2023
405129a
update predict function for booster
HendrikSchmidt Jan 11, 2023
7a6b01c
update domain adaptation script
HendrikSchmidt Jan 11, 2023
7515597
remove weight tuning, adapt for LR
HendrikSchmidt Jan 11, 2023
0789bcf
iterate over target sizes
HendrikSchmidt Jan 11, 2023
aa3345a
aggregate and average DA metrics
HendrikSchmidt Jan 11, 2023
3da7f6b
Update domain_adaptation.py
HendrikSchmidt Jan 11, 2023
6f6a82e
use data_dir for da
HendrikSchmidt Jan 11, 2023
ccb8bbe
disable confusion matrix
HendrikSchmidt Jan 12, 2023
1b6319f
Update Transformer.gin
HendrikSchmidt Jan 12, 2023
bb297d8
rename encoder to model in wrapper
HendrikSchmidt Jan 12, 2023
119ad35
fix model path
HendrikSchmidt Jan 12, 2023
7ffbb4b
Update Transformer.gin
HendrikSchmidt Jan 12, 2023
12036b5
load correct wrapper
HendrikSchmidt Jan 12, 2023
40642ea
initialize wrapper without model
HendrikSchmidt Jan 12, 2023
c1e3222
instantiate encoder in wrapper
HendrikSchmidt Jan 12, 2023
9866215
revert instantiation
HendrikSchmidt Jan 12, 2023
4cc5fca
load model configs
HendrikSchmidt Jan 12, 2023
aefb165
update lgbm config
HendrikSchmidt Jan 12, 2023
85543e2
Update domain_adaptation.py
HendrikSchmidt Jan 12, 2023
0030a69
Update domain_adaptation.py
HendrikSchmidt Jan 12, 2023
2afc544
include model in log_dir
HendrikSchmidt Jan 12, 2023
1954954
reduce hyperparameter training for shallow models
HendrikSchmidt Jan 12, 2023
761e354
reset gin config for repeated HP tuning
HendrikSchmidt Jan 12, 2023
f7c01fe
remove duplicate output transform
HendrikSchmidt Jan 12, 2023
685ee3e
Update domain_adaptation.py
HendrikSchmidt Jan 12, 2023
3f68754
Update domain_adaptation.py
HendrikSchmidt Jan 12, 2023
efe115f
move metrics calculation to wrapper
HendrikSchmidt Jan 12, 2023
c3299fb
fix function call
HendrikSchmidt Jan 12, 2023
294ef74
increase batch_size for test
HendrikSchmidt Jan 12, 2023
5f36045
Update wrappers.py
HendrikSchmidt Jan 12, 2023
d57b375
Update Transformer.gin
HendrikSchmidt Jan 12, 2023
442a3c4
Update wrappers.py
HendrikSchmidt Jan 12, 2023
0967ff5
Update wrappers.py
HendrikSchmidt Jan 12, 2023
de88877
Update wrappers.py
HendrikSchmidt Jan 12, 2023
2ecda40
Update wrappers.py
HendrikSchmidt Jan 12, 2023
6c4cd96
Update wrappers.py
HendrikSchmidt Jan 12, 2023
9fae10e
Update wrappers.py
HendrikSchmidt Jan 12, 2023
5d6cb2e
Update wrappers.py
HendrikSchmidt Jan 12, 2023
57db01b
Update domain_adaptation.py
HendrikSchmidt Jan 12, 2023
b93b402
Update wrappers.py
HendrikSchmidt Jan 12, 2023
0cd0b13
Update wrappers.py
HendrikSchmidt Jan 12, 2023
623a80d
Update wrappers.py
HendrikSchmidt Jan 12, 2023
702c917
change metric calculation
HendrikSchmidt Jan 12, 2023
570f6ab
Update Transformer.gin
HendrikSchmidt Jan 12, 2023
162adad
Update domain_adaptation.py
HendrikSchmidt Jan 12, 2023
c44a1b0
Update wrappers.py
HendrikSchmidt Jan 12, 2023
ea23e1b
add dg baseline
HendrikSchmidt Jan 12, 2023
8fda309
Merge branch 'development' into domain-adaptation
HendrikSchmidt Jan 12, 2023
d56d98b
fix json encoder
HendrikSchmidt Jan 12, 2023
e9b5c79
Update Transformer.gin
HendrikSchmidt Jan 12, 2023
2ca631e
only execute da for one dataset at a time
HendrikSchmidt Jan 12, 2023
4585257
remove run dir
HendrikSchmidt Jan 12, 2023
fc0d41e
Create da_to_csv.py
HendrikSchmidt Jan 13, 2023
440aa15
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
7a2583b
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
11d80e2
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
18fe642
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
3df1fcd
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
1ce0f17
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
2e7ac54
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
fd80822
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
5b9ba03
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
22b7f61
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
095c79b
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
88a79e0
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
c8396b0
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
6e505bf
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
ea062c9
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
f66c787
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
6aeec62
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
a3e68f6
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
87b7fbe
Update da_to_csv.py
HendrikSchmidt Jan 13, 2023
2277198
remove evaluate and test_all
HendrikSchmidt Jan 14, 2023
ccc0d34
Update LogisticRegression.gin
HendrikSchmidt Jan 14, 2023
395841f
Update preprocess.py
HendrikSchmidt Jan 14, 2023
e1b3fe1
Update preprocess.py
HendrikSchmidt Jan 14, 2023
94969b7
Update da_to_csv.py
HendrikSchmidt Jan 14, 2023
c647be9
fix comments
HendrikSchmidt Jan 14, 2023
7d54d7e
test different weights
HendrikSchmidt Jan 14, 2023
c6a54ce
only plot avg
HendrikSchmidt Jan 14, 2023
ac1a7bb
test other weights
HendrikSchmidt Jan 14, 2023
da3fce1
Update domain_adaptation.py
HendrikSchmidt Jan 14, 2023
c0f8c39
Update da_to_csv.py
HendrikSchmidt Jan 14, 2023
7d91bac
Update da_to_csv.py
HendrikSchmidt Jan 14, 2023
4c03885
auc and loss based weigth functions
HendrikSchmidt Jan 14, 2023
464ad05
Update domain_adaptation.py
HendrikSchmidt Jan 14, 2023
45932cf
cache predictions
HendrikSchmidt Jan 15, 2023
be5ad30
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
5c8265a
Update preprocess.py
HendrikSchmidt Jan 15, 2023
a834354
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
5d6e273
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
5ba9428
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
a8b2021
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
0c97842
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
95addb7
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
bbb6939
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
da34511
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
bf6c6b9
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
06e8e1a
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
0df060c
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
52f6090
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
a9c83c1
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
3db23a4
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
88d5ce5
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
ac2eac1
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
fcc811f
test target with predictions
HendrikSchmidt Jan 15, 2023
f20715c
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
f2b2ac5
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
fdfca07
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
89a5a6c
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
8cf0207
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
d89edbf
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
918890a
test cc with preds
HendrikSchmidt Jan 15, 2023
871d562
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
cc83759
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
89d12a3
Update domain_adaptation.py
HendrikSchmidt Jan 15, 2023
6014ea4
boil down to relevant appraoches
HendrikSchmidt Jan 16, 2023
a6b31d4
format
HendrikSchmidt Jan 16, 2023
272747c
Update da_to_csv.py
HendrikSchmidt Jan 16, 2023
5d2d20f
Update domain_adaptation.py
HendrikSchmidt Jan 16, 2023
d896d56
fix da for miiv
HendrikSchmidt Jan 16, 2023
101f262
Update da_to_csv.py
HendrikSchmidt Jan 16, 2023
ac098ce
Update domain_adaptation.py
HendrikSchmidt Jan 17, 2023
db0393f
Update da_to_csv.py
HendrikSchmidt Jan 17, 2023
871810e
Update domain_adaptation.py
HendrikSchmidt Jan 17, 2023
f10e0bb
fix weight for combined
HendrikSchmidt Jan 17, 2023
3c83511
include max prediction
HendrikSchmidt Jan 17, 2023
a6a4551
Update domain_adaptation.py
HendrikSchmidt Jan 17, 2023
2c59078
format
HendrikSchmidt Jan 18, 2023
479246a
Update domain_adaptation.py
HendrikSchmidt Jan 18, 2023
0c2af6d
changes for sepsis
HendrikSchmidt Jan 18, 2023
4cff9eb
Update domain_adaptation.py
HendrikSchmidt Jan 18, 2023
89eead0
Update da_to_csv.py
HendrikSchmidt Jan 18, 2023
00986de
Update da_to_csv.py
HendrikSchmidt Jan 18, 2023
6aa0229
Update domain_adaptation.py
HendrikSchmidt Jan 18, 2023
4e7174c
Update domain_adaptation.py
HendrikSchmidt Jan 18, 2023
7f84b2a
correct logging for loss weighted
HendrikSchmidt Jan 19, 2023
a30a56f
only use source weights
HendrikSchmidt Jan 19, 2023
16fba8a
use debug to set source datasets
HendrikSchmidt Jan 19, 2023
4d40c7b
Update run.py
HendrikSchmidt Jan 19, 2023
b351a70
Update domain_adaptation.py
HendrikSchmidt Jan 19, 2023
c881d87
fix loss_weighted
HendrikSchmidt Jan 19, 2023
467bd74
Update domain_adaptation.py
HendrikSchmidt Jan 19, 2023
9f6f4a6
Update domain_adaptation.py
HendrikSchmidt Jan 20, 2023
b543be6
Update da_to_csv.py
HendrikSchmidt Jan 20, 2023
8bc39b8
Update da_to_csv.py
HendrikSchmidt Jan 20, 2023
45454a6
Update domain_adaptation.py
HendrikSchmidt Jan 21, 2023
e4d739f
Update domain_adaptation.py
HendrikSchmidt Jan 21, 2023
870214a
Update da_to_csv.py
HendrikSchmidt Jan 22, 2023
e069f6b
rename script
HendrikSchmidt Jan 23, 2023
c0c733b
Create sepsis_to_csv.py
HendrikSchmidt Jan 23, 2023
c1e555a
Update sepsis_to_csv.py
HendrikSchmidt Jan 23, 2023
6a7300d
tables to latex
HendrikSchmidt Jan 25, 2023
3b548ab
Merge branch 'development' into domain-adaptation
HendrikSchmidt Aug 8, 2023
60ae45f
import domain adaptation
HendrikSchmidt Aug 8, 2023
b0583c7
make train run
HendrikSchmidt Aug 8, 2023

392 changes: 392 additions & 0 deletions icu_benchmarks/models/domain_adaptation.py

Large diffs are not rendered by default.
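The 392-line icu_benchmarks/models/domain_adaptation.py is not rendered in this view. Judging from the commit messages and from the strategy names used in the result scripts below (convex_combination_without_target, max_prediction, target_weight_*, loss_weighted, bayes_opt), the script combines the predictions of models trained on each source dataset when evaluating on a small labelled target split. A minimal sketch of two such combination rules, using invented function and variable names rather than the PR's actual API:

    import numpy as np


    def convex_combination(source_preds, weights):
        """Weighted average of the per-source predicted probabilities (weights normalised to sum to 1)."""
        total = sum(weights.values())
        return sum(weights[name] * preds for name, preds in source_preds.items()) / total


    def max_prediction(source_preds):
        """Most confident prediction across the source models for every sample."""
        return np.max(np.stack(list(source_preds.values())), axis=0)


    # Hypothetical usage: three source models scored on the same two target stays.
    preds = {
        "aumc": np.array([0.2, 0.7]),
        "hirid": np.array([0.4, 0.6]),
        "miiv": np.array([0.1, 0.9]),
    }
    print(convex_combination(preds, {"aumc": 1 / 3, "hirid": 1 / 3, "miiv": 1 / 3}))  # [0.233... 0.733...]
    print(max_prediction(preds))  # [0.4 0.9]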

10 changes: 10 additions & 0 deletions icu_benchmarks/run.py
@@ -22,6 +22,7 @@
    setup_logging,
)
from icu_benchmarks.contants import RunMode
from icu_benchmarks.models import domain_adaptation


@gin.configurable("Run")
@@ -116,6 +117,15 @@ def main(my_args=tuple(sys.argv[1:])):
        run_dir = create_run_dir(log_dir)
        source_dir = args.source_dir
        gin.parse_config_file(source_dir / "train_config.gin")
    if args.command == "da":
        gin_config_files = (
            [Path(f"configs/experiments/{args.experiment}.gin")]
            if args.experiment
            else [Path(f"configs/models/{model}.gin"), Path(f"configs/tasks/{task}.gin")]
        )
        gin.parse_config_files_and_bindings(gin_config_files, args.gin_bindings, finalize_config=False)
        domain_adaptation(name, args.data_dir, args.log_dir, args.seed, args.task_name, model, debug=args.debug)
        return
    else:
        # Train
        checkpoint = log_dir / args.checkpoint if args.checkpoint else None
3 changes: 3 additions & 0 deletions icu_benchmarks/run_utils.py
@@ -90,6 +90,9 @@ def build_parser() -> ArgumentParser:
    evaluate.add_argument("-sn", "--source-name", required=True, type=Path, help="Name of the source dataset.")
    evaluate.add_argument("--source-dir", required=True, type=Path, help="Directory containing gin and model weights.")

    # DOMAIN ADAPTATION ARGUMENTS
    prep_and_train = subparsers.add_parser("da", help="Run DA experiment.", parents=[parent_parser])

    return parser


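The new da subparser registers no arguments of its own; through parents=[parent_parser] it inherits the shared command-line flags (data directory, model, seed, and so on) used by the other subcommands. A self-contained argparse sketch of that pattern, with illustrative flag names that are not copied from run_utils.py:

    from argparse import ArgumentParser

    # Shared flags live on a parent parser built with add_help=False so subparsers can reuse them.
    parent_parser = ArgumentParser(add_help=False)
    parent_parser.add_argument("-d", "--data-dir", help="Path to the preprocessed data.")
    parent_parser.add_argument("--seed", type=int, default=1111)

    parser = ArgumentParser(description="Benchmark entry point (illustrative).")
    subparsers = parser.add_subparsers(dest="command", required=True)
    subparsers.add_parser("train", help="Train a model.", parents=[parent_parser])
    subparsers.add_parser("da", help="Run DA experiment.", parents=[parent_parser])

    args = parser.parse_args(["da", "--data-dir", "data/mortality24", "--seed", "2222"])
    print(args.command, args.data_dir, args.seed)  # -> da data/mortality24 2222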
69 changes: 69 additions & 0 deletions scripts/results/da_results_to_latex.py
@@ -0,0 +1,69 @@
import csv

rawNamesMap = {
    "target": "Target",
    "aumc": "AUMCdb",
    "eicu": "eICU",
    "hirid": "HiRID",
    "miiv": "MIMIC-IV",
    "convex_combination_without_target": "Convex UDA",
    "max_prediction": "Max Pooling",
    "target_weight_0.5": "Weighted $\\alpha=1/3$",
    "target_weight_2": "Weighted $\\alpha=2/3$",
    "loss_weighted": "Weighted Loss",
    "bayes_opt": "Weighted Bayes",
    "target_with_predictions": "Prediction-Feature",
    "cc_with_preds": "Combined",
}


def csv_to_dict(file_name):
    # Group rows by (target, target_size); within each group, index the remaining columns by model.
    with open(file_name, 'r') as file:
        reader = csv.DictReader(file)
        data = [row for row in reader]
    tables = {}
    for row in data:
        row_without_target = {key: value for key, value in row.items() if key not in ('target', 'target_size', 'model')}
        tables.setdefault((row['target'], row['target_size']), {})[row['model']] = row_without_target
    return tables


def dict_to_latex(combination, data, metric):
    table = '\\begin{table}[h]\n'
    table += '\\centering\n'
    table += '\\footnotesize\n'
    table += '\\caption{{Sepsis prediction on {0} with target size {1}, {2} with standard deviation.}}\n'.format(
        rawNamesMap[combination[0]], combination[1], "AUROC" if metric == "auc" else "AUPRC"
    )
    headers = ['Model']
    for model, scores in data.items():
        headers += [model]

    table += '\\begin{tabular}{l|' + ''.join(['c'] * (len(headers) - 1)) + '}\n'
    table += '\\textbf{' + '} & \\textbf{'.join(headers) + '}\\\\\n'
    table += '\\hline\n'

    # Iterate over the *_avg columns (taken from the last model seen above); one table row per source.
    for score_name, score in data[model].items():
        if "_avg" in score_name:
            raw_name = score_name.split("_avg")[0]
            if raw_name == combination[0] or raw_name not in rawNamesMap:
                continue
            clean_name = rawNamesMap[raw_name]
            values = [clean_name]
            for model in headers[1:]:
                scores = data[model]
                avg = "{:.2f}".format(float(scores[score_name]))
                std = "{:.2f}".format(float(scores[f"{raw_name}_std"]))
                values.append(f"${avg} \\pm {std}$")
            table += ' & '.join(values) + '\\\\\n'

    table += '\\end{tabular}\n'
    table += '\\end{table}\n'
    return table


if __name__ == '__main__':
    for metric in ["auc", "pr"]:
        file_name = f'../yaib_logs/sep_{metric}.csv'
        data = csv_to_dict(file_name)
        for key, row in data.items():
            table = dict_to_latex(key, row, metric)
            print(table)
            print('\n' * 5)
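For orientation, csv_to_dict expects the per-metric CSVs written by the *_to_csv.py scripts further down: one row per model, target and target size, with <source>_avg and <source>_std columns. A hypothetical two-row input and the grouping it produces (all values invented):

    # Hypothetical ../yaib_logs/sep_auc.csv contents:
    #   model,target,target_size,target_avg,target_std,aumc_avg,aumc_std
    #   LGBMClassifier,hirid,500,71.20,1.10,68.40,2.30
    #   LogisticRegression,hirid,500,69.80,0.90,66.10,1.70
    #
    # csv_to_dict then yields one table per (target, target_size), keyed by model:
    tables = {
        ("hirid", "500"): {
            "LGBMClassifier": {"target_avg": "71.20", "target_std": "1.10", "aumc_avg": "68.40", "aumc_std": "2.30"},
            "LogisticRegression": {"target_avg": "69.80", "target_std": "0.90", "aumc_avg": "66.10", "aumc_std": "1.70"},
        },
    }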

69 changes: 69 additions & 0 deletions scripts/results/da_results_to_latex_sep.py
@@ -0,0 +1,69 @@
import csv

rawNamesMap = {
    "target": "Target",
    "aumc": "AUMCdb",
    "eicu": "eICU",
    "hirid": "HiRID",
    "miiv": "MIMIC-IV",
    "convex_combination_without_target": "Convex UDA",
    "max_prediction": "Max Pooling",
    "target_weight_0.5": "Weighted $\\alpha=1/3$",
    "target_weight_2": "Weighted $\\alpha=2/3$",
    "loss_weighted": "Weighted Loss",
    "bayes_opt": "Weighted Bayes",
    "target_with_predictions": "Prediction-Feature",
    "cc_with_preds": "Combined",
}


def csv_to_dict(file_name):
    # Group rows by target dataset; within each group, index the remaining columns by target size.
    with open(file_name, 'r') as file:
        reader = csv.DictReader(file)
        data = [row for row in reader]
    tables = {}
    for row in data:
        row_without_target = {key: value for key, value in row.items() if key not in ('target', 'target_size', 'model')}
        tables.setdefault(row['target'], {})[row['target_size']] = row_without_target
    return tables


def dict_to_latex(combination, data, metric):
    table = '\\begin{table}[h]\n'
    table += '\\centering\n'
    table += '\\footnotesize\n'
    table += '\\caption{{Sepsis prediction on {0} with LGBM, {1} with standard deviation.}}\n'.format(
        rawNamesMap[combination], "AUROC" if metric == "auc" else "AUPRC"
    )
    headers = ['Target Size']
    for target_size, scores in data.items():
        headers += [target_size]

    table += '\\begin{tabular}{l|' + ''.join(['c'] * (len(headers) - 1)) + '}\n'
    table += '\\textbf{' + '} & \\textbf{'.join(headers) + '}\\\\\n'
    table += '\\hline\n'

    # One table row per source; skip the column that corresponds to the target dataset itself.
    for score_name, score in data[target_size].items():
        if "_avg" in score_name:
            raw_name = score_name.split("_avg")[0]
            if raw_name == combination or raw_name not in rawNamesMap:
                continue
            clean_name = rawNamesMap[raw_name]
            values = [clean_name]
            for target_size in headers[1:]:
                scores = data[target_size]
                avg = "{:.2f}".format(float(scores[score_name]))
                std = "{:.2f}".format(float(scores[f"{raw_name}_std"]))
                values.append(f"${avg} \\pm {std}$")
            table += ' & '.join(values) + '\\\\\n'

    table += '\\end{tabular}\n'
    table += '\\end{table}\n'
    return table


if __name__ == '__main__':
    for metric in ["auc", "pr"]:
        file_name = f'../yaib_logs/sep_{metric}.csv'
        data = csv_to_dict(file_name)
        for key, row in data.items():
            table = dict_to_latex(key, row, metric)
            print(table)
            print('\n' * 5)

53 changes: 53 additions & 0 deletions scripts/results/mortality_to_csv.py
@@ -0,0 +1,53 @@
import json
from pathlib import Path
import csv

models_dir = Path("../DA_new")
for metric in ["AUC", "PR"]:
    for endpoint in models_dir.iterdir():
        if endpoint.is_dir():
            with open(models_dir / f"{endpoint.name}_{metric}_results.csv", "w") as csv_file:
                writer = csv.writer(csv_file)
                info = ["model", "target", "target_size"]
                source_names = [
                    "target",
                    "aumc",
                    "eicu",
                    "hirid",
                    "miiv",
                    "convex_combination_without_target",
                    "max_prediction",
                    "target_weight_0.5",
                    "target_weight_1",
                    "target_weight_2",
                    "loss_weighted",
                    "squared_loss_weighted",
                    "bayes_opt",
                    "target_with_predictions",
                    "cc_with_preds",
                ]
                stats_basis = ["avg", "std"]
                stats = ["avg", "std"]
                # combine fieldnames and stats
                full_fields = [f"{source}_{stat}" for source in source_names for stat in stats]
                writer = csv.DictWriter(csv_file, fieldnames=info + full_fields)

                writer.writeheader()
                for model in endpoint.iterdir():
                    for target in ["aumc", "eicu", "hirid", "miiv"]:
                        target_sizes = [500, 1000, 2000]
                        for target_size in target_sizes:
                            target_str = f"target_{target_size}"
                            if (model / target / target_str).exists():
                                with open(model / target / target_str / "averaged_source_metrics.json", "r") as f:
                                    results = json.load(f)

                                row_data = {"model": model.name, "target": target, "target_size": target_size}
                                for stat in stats_basis:
                                    for source, source_metrics in results.items():
                                        if stat == "CI_0.95":
                                            row_data[f"{source}_{stat}_min"] = source_metrics[metric][0][stat][0] * 100
                                            row_data[f"{source}_{stat}_max"] = source_metrics[metric][0][stat][1] * 100
                                        else:
                                            row_data[f"{source}_{stat}"] = source_metrics[metric][0][stat] * 100
                                writer.writerow(row_data)
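
The indexing source_metrics[metric][0][stat] above implies the following shape for each averaged_source_metrics.json file; the numbers below are invented for illustration:

    # Assumed structure of averaged_source_metrics.json (values invented):
    results = {
        "target": {"AUC": [{"avg": 0.712, "std": 0.011}], "PR": [{"avg": 0.388, "std": 0.021}]},
        "aumc": {"AUC": [{"avg": 0.684, "std": 0.023}], "PR": [{"avg": 0.351, "std": 0.030}]},
    }
    # mortality_to_csv.py reads results[source][metric][0][stat] and scales it by 100 for the CSV.
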
52 changes: 52 additions & 0 deletions scripts/results/sepsis_to_csv.py
@@ -0,0 +1,52 @@
import json
from pathlib import Path
import csv

models_dir = Path("../DA_sep_new")
for metric in ["AUC", "PR"]:
    for endpoint in models_dir.iterdir():
        if endpoint.is_dir():
            with open(models_dir / f"{endpoint.name}_{metric}_results.csv", "w") as csv_file:
                writer = csv.writer(csv_file)
                info = ["model", "target", "target_size"]
                source_names = [
                    "target",
                    "aumc",
                    "hirid",
                    "miiv",
                    "convex_combination_without_target",
                    "max_prediction",
                    "target_weight_0.5",
                    "target_weight_1",
                    "target_weight_2",
                    "loss_weighted",
                    "squared_loss_weighted",
                    "bayes_opt",
                    "target_with_predictions",
                    "cc_with_preds",
                ]
                stats_basis = ["avg", "std"]
                stats = ["avg", "std"]
                # combine fieldnames and stats
                full_fields = [f"{source}_{stat}" for source in source_names for stat in stats]
                writer = csv.DictWriter(csv_file, fieldnames=info + full_fields)

                writer.writeheader()
                for model in endpoint.iterdir():
                    for target in ["aumc", "hirid", "miiv"]:
                        target_sizes = [500, 1000, 2000]
                        for target_size in target_sizes:
                            target_str = f"target_{target_size}"
                            if (model / target / target_str).exists():
                                with open(model / target / target_str / "averaged_source_metrics.json", "r") as f:
                                    results = json.load(f)

                                row_data = {"model": model.name, "target": target, "target_size": target_size}
                                for stat in stats_basis:
                                    for source, source_metrics in results.items():
                                        if stat == "CI_0.95":
                                            row_data[f"{source}_{stat}_min"] = source_metrics[metric][0][stat][0] * 100
                                            row_data[f"{source}_{stat}_max"] = source_metrics[metric][0][stat][1] * 100
                                        else:
                                            row_data[f"{source}_{stat}"] = source_metrics[metric][0][stat] * 100
                                writer.writerow(row_data)