Skip to content

Commit

Permalink
add docstring to GeneratorBase
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Sep 7, 2024
1 parent 9cccdf1 commit ed610bf
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions src/astroNN/nn/utilities/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,58 @@

class GeneratorBase(keras.utils.PyDataset):
"""
| Top-level class of astroNN data pipeline to generate data for NNs.
You need to implement the ``__getitem__`` in the generator sub-class
:History: 2019-Feb-17 - Updated - Henry Leung (University of Toronto)
Top-level data generator class to generate batches.
Subclass this class to create a custom data generator; subclasses need to implement the ``__getitem__`` method.
Parameters
----------
batch_size: int
batch size
shuffle: bool
shuffle the data or not after each epoch
steps_per_epoch: int
steps per epoch
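data: dict
data dictionary
manual_reset: bool
whether idx needs to be reset manually when on_epoch_end() cannot be reached (e.g. a validation generator)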
np_rng: numpy.random.Generator
numpy random generator
History
-------
2019-Feb-17 - Updated - Henry Leung (University of Toronto)
2024-Sept-6 - Updated - Henry Leung (University of Toronto)
"""

def __init__(self, batch_size, shuffle, steps_per_epoch, data, manual_reset):
super().__init__()
def __init__(self, data, *, batch_size=32, shuffle=True, manual_reset=False, steps_per_epoch=None, np_rng=None, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size
self.data = data
self.shuffle = shuffle
# whether idx needs to be reset manually when on_epoch_end() cannot be reached, e.g. for val_generator
self.manual_reset = manual_reset

self.steps_per_epoch = steps_per_epoch
if steps_per_epoch is None: # all data should share the same length
self.steps_per_epoch = int(np.ceil(len(data[list(data.keys())[0]]) / batch_size))
else:
self.steps_per_epoch = steps_per_epoch

if np_rng is None:
self.np_rng = np.random.default_rng()
else:
self.np_rng = np_rng

def __len__(self):
return self.steps_per_epoch

def _get_exploration_order(self, idx_list):
"""
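:param idx_list: indices to be explored
:return: a shuffled copy of ``idx_list`` if shuffling is enabled, otherwise ``idx_list`` unchanged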
"""
# shuffle (if applicable) and find exploration order
if self.shuffle is True:
if self.shuffle:
idx_list = np.copy(idx_list)
np.random.shuffle(idx_list)
self.np_rng.shuffle(idx_list)

return idx_list

def sparsify(self, y):
"""Returns labels in binary NumPy array"""
# n_classes = # Enter number of classes
# return np.array([[1 if y[i] == j else 0 for j in range(n_classes)]
# for i in range(y.shape[0])])
pass

def input_d_checking(self, inputs, idx_list_temp):
x_dict = {}
float_dtype = keras.backend.floatx()
Expand Down

0 comments on commit ed610bf

Please sign in to comment.