"""Xray_Reborn dataset."""
import tensorflow_datasets as tfds
from pathlib import Path
import re
import random
# TODO(Xray_Reborn): Markdown description that will appear on the catalog page.
_DESCRIPTION = """
Description is **formatted** as markdown.
It should also contain any processing which has been applied (if any),
(e.g. corrupted example skipped, images cropped,...):
"""

_CITATION = """
@article{kermany2018identifying,
  title={Identifying medical diagnoses and treatable diseases by image-based deep learning},
  author={Kermany, Daniel S and Goldbaum, Michael and Cai, Wenjia and Valentim, Carolina CS and Liang, Huiying and Baxter, Sally L and McKeown, Alex and Yang, Ge and Wu, Xiaokang and Yan, Fangbing and others},
  journal={Cell},
  volume={172},
  number={5},
  pages={1122--1131},
  year={2018},
  publisher={Elsevier}
}
"""

# Split-definition files: each contains one example key per line.
ANNOTATIONS_FNAME = 'annotations.txt'
TEST_FNAME = 'test.txt'
VALID_FNAME = 'valid.txt'

DATA_DIR = '/content/XRay_'

CLASSES = ['NORMAL', 'PNEUMONIA']
PNEUMONIAS = ['VIRUS', 'BACTERIA']
CLASSES_EXT = CLASSES + PNEUMONIAS

# Image filenames are expected to look like '<CLASS>-<rest>.jpeg'.
FNAME_PAT = re.compile(r'^(.+?)-(.+)\.jpeg$')

RAND_SEED = 42
NUM_ORIG_TRAIN = 5232
NUM_ORIG_TEST = 623
NUM_ORIG_TOTAL = NUM_ORIG_TRAIN + NUM_ORIG_TEST
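
# Illustration (hypothetical filename, for documentation only): for a file at
# '<DATA_DIR>/train/PNEUMONIA/BACTERIA-person1_1.jpeg',
#   FNAME_PAT.match('BACTERIA-person1_1.jpeg') gives
#   group(1) == 'BACTERIA' and group(2) == 'person1_1',
# and the example key built in the generators below is
#   'train-PNEUMONIA-BACTERIA-person1_1'.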


class XrayReborn(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for Xray_Reborn dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(shape=(None, None, 3), encoding_format='jpeg'),
            'label': tfds.features.ClassLabel(names=CLASSES),
            'fname': tfds.features.Text(),
        }),
        # If there's a common (input, target) tuple from the features, specify
        # them here. They'll be used if `as_supervised=True` in
        # `builder.as_dataset`.
        supervised_keys=('image', 'label'),  # Set to `None` to disable.
        homepage='https://dataset-homepage/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    # The data is expected to already be on local disk; the download manager
    # is not used.
    # path = dl_manager.download_and_extract('file:///Users/wing/ixig/FourthBrain/Midterm/Data.zip')
    path = Path(DATA_DIR)
    train_paths = list((path / 'train').glob('*/*.jpeg'))
    test_paths = list((path / 'test').glob('*/*.jpeg'))

    # Pool the original train/test images and reshuffle them deterministically
    # before re-splitting.
    combined_paths = train_paths + test_paths
    random.seed(RAND_SEED)
    random.shuffle(combined_paths)
    train_paths = combined_paths[:len(train_paths)]
    test_paths = combined_paths[-len(test_paths):]
    assert len(train_paths) + len(test_paths) == len(combined_paths)
    # print('\n>>>', len(train_paths), len(test_paths), len(combined_paths))
    # raise Exception

    # Each split is defined by a text file listing the example keys it holds.
    return {
        'test': self._generate_act(combined_paths, TEST_FNAME),
        'validation': self._generate_act(combined_paths, VALID_FNAME),
        'train_act': self._generate_act(combined_paths, ANNOTATIONS_FNAME),
        # 'train': self._generate_examples(train_paths, 0.0, 0.9),
        # 'train_1pc': self._generate_examples(train_paths, 0.0, 0.01),
        # 'train_2pc': self._generate_examples(train_paths, 0.0, 0.02),
        # 'train_5pc': self._generate_examples(train_paths, 0.0, 0.05),
        # 'train_10pc': self._generate_examples(train_paths, 0.0, 0.10),
    }

  def _generate_act(self, paths, fname):
    """Yields the examples whose keys are listed in the split file `fname`."""
    examples = set()
    with open(fname) as f:
      for line in f:
        example = line.strip()
        if example:
          examples.add(example)
    assert len(examples)
    print(f'{fname}: {len(examples)}')

    for path in paths:
      matches = FNAME_PAT.match(path.name)
      klass = matches.group(1)
      assert klass in CLASSES_EXT
      # Collapse VIRUS/BACTERIA into the coarse PNEUMONIA label.
      klass = 'PNEUMONIA' if klass in PNEUMONIAS else 'NORMAL'
      key = '-'.join([path.parent.parent.name, path.parent.name,
                      matches.group(1), matches.group(2)])
      # print(key)
      if key not in examples:
        continue
      yield key, {
          'image': path,
          'label': klass,
          'fname': key,
      }

  def _generate_examples(self, paths, start, end):
    """Yields examples from the fractional slice [start, end) of `paths`."""
    total = len(paths)
    start_idx = int(start * total)
    end_idx = int(end * total)
    for path in paths[start_idx:end_idx]:
      matches = FNAME_PAT.match(path.name)
      klass = matches.group(1)
      assert klass in CLASSES_EXT
      # Collapse VIRUS/BACTERIA into the coarse PNEUMONIA label.
      klass = 'PNEUMONIA' if klass in PNEUMONIAS else 'NORMAL'
      key = '-'.join([path.parent.parent.name, path.parent.name,
                      matches.group(1), matches.group(2)])
      # print(key)
      yield key, {
          'image': path,
          'label': klass,
          'fname': key,
      }
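

if __name__ == '__main__':
  # Minimal usage sketch, assuming DATA_DIR and the split files
  # annotations.txt / valid.txt / test.txt are present locally.
  # `download_and_prepare` and `as_dataset` are the standard
  # `tfds.core.DatasetBuilder` entry points.
  builder = XrayReborn()
  builder.download_and_prepare()
  ds = builder.as_dataset(split='train_act', as_supervised=True)
  for image, label in ds.take(3):
    print(image.shape, label.numpy())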