Update data pipeline to support post batch transforms
PiperOrigin-RevId: 696673931
The kauldron Authors committed Nov 14, 2024
1 parent 1f99873 commit 460895a
Showing 3 changed files with 30 additions and 9 deletions.
11 changes: 9 additions & 2 deletions kauldron/data/py/base.py
@@ -40,8 +40,11 @@ class PyGrainPipeline(pipelines.Pipeline):
   See doc:

   Attributes:
-    transforms: A list of transformations to apply to the dataset. Each
-      transformation should be either a `grain.MapTransform` or a
+    transforms: A list of transformations to apply to the dataset before
+      batching. Each transformation should be either a `grain.MapTransform` or a
       `grain.RandomMapTransform`.
+    post_batch_transforms: A list of transformations to apply after batching.
+      Each transformation should be either a `grain.MapTransform` or a
+      `grain.RandomMapTransform`.
     num_epochs: Number of epoch. If missing, iterate indefinitely (number of
       iteration is given by `cfg.num_training_steps`)
@@ -62,6 +65,9 @@ class PyGrainPipeline(pipelines.Pipeline):
   transforms: tr_normalize.Transformations = dataclasses.field(
       default_factory=tuple
   )
+  post_batch_transforms: tr_normalize.Transformations = dataclasses.field(
+      default_factory=tuple
+  )

   # Params only relevant for the root top-level dataset (when dataset mixture)
   num_epochs: Optional[int] = None
@@ -113,6 +119,7 @@ def _root_ds(self) -> grain.IterDataset:
     # batching.
     if self.batch_size:
       ds = ds.batch(self.batch_size, drop_remainder=self.batch_drop_remainder)
+      ds = transform_utils.apply_transforms(ds, self.post_batch_transforms)

     # Distribute the execution across multiple worker processes.
     num_workers = _get_num_workers(self.num_workers)
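After this change, anything listed in `post_batch_transforms` runs on whole batches: its `map()` receives features that carry a leading batch dimension. A minimal sketch of the resulting usage, assuming a hypothetical `kd.data.py.Tfds` source and a made-up `RescaleImage` transform (neither name is confirmed by this commit):

import dataclasses

import grain.python as grain
import numpy as np
from kauldron import kd


@dataclasses.dataclass(frozen=True)
class RescaleImage(grain.MapTransform):
  """Hypothetical post-batch transform; operates on already-batched arrays."""

  def map(self, features):
    # After `ds.batch(...)`, features["image"] has a leading batch dimension,
    # e.g. shape (batch_size, h, w, c).
    features["image"] = features["image"].astype(np.float32) / 255.0
    return features


ds = kd.data.py.Tfds(  # assumed PyGrainPipeline subclass; name illustrative
    name="mnist",
    split="train",
    batch_size=32,
    transforms=[],  # per-element, runs before batching
    post_batch_transforms=[RescaleImage()],  # runs after batching
)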
12 changes: 7 additions & 5 deletions kauldron/data/py/transform_utils.py
@@ -14,12 +14,14 @@

 """Utils for using Kauldron transforms with PyGrain."""

-from typing import Any, Callable, Mapping
+from typing import Any, Callable, Mapping, TypeVar

 import grain.python as grain
 from kauldron.data.transforms import abc as tr_abc
 from kauldron.data.transforms import normalize as tr_normalize

+_T = TypeVar("_T", grain.MapDataset, grain.IterDataset)
+

 class PyGrainMapAdapter(tr_normalize.TransformAdapter, grain.MapTransform):
   """Adapter from `kd.data.MapTransform` to pygrain."""
@@ -60,8 +62,8 @@ def _adapt_for_pygrain(


 def apply_transforms(
-    ds: grain.MapDataset, transforms: tr_normalize.Transformations
-) -> grain.MapDataset:
+    ds: _T, transforms: tr_normalize.Transformations
+) -> _T:
   """Apply the transformations to the dataset."""
   if isinstance(transforms, Mapping):
     transforms = transforms.values()
@@ -72,8 +74,8 @@ def apply_transforms(


 def _apply_transform(
-    ds: grain.MapDataset, tr: grain.Transformation
-) -> grain.MapDataset:
+    ds: _T, tr: grain.Transformation
+) -> _T:
   """Apply a list of single transformation."""
   match tr:
     case grain.MapTransform():
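The switch from `grain.MapDataset` to the constrained `_T` is what lets `apply_transforms` accept the `grain.IterDataset` produced by `.batch(...)` in `PyGrainPipeline._root_ds` as well as the pre-batch `grain.MapDataset`, while preserving the input type in the signature: a checker infers `MapDataset -> MapDataset` and `IterDataset -> IterDataset`, never a mix. A self-contained sketch of the same constrained-TypeVar pattern, using plain types for illustration:

from typing import TypeVar

# _S may only bind to one of the listed constraint types.
_S = TypeVar("_S", int, str)


def double(value: _S) -> _S:
  # Whichever constraint type the caller passes, the checker infers
  # the identical type for the return value.
  return value + value


n = double(21)    # inferred as int
s = double("ab")  # inferred as str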
16 changes: 14 additions & 2 deletions kauldron/data/tf/base.py
@@ -58,9 +58,12 @@ class TFDataPipeline(pipelines.Pipeline, abc.ABC):
   returns the `tf.data.Dataset` for the current process.

   Attributes:
-    transforms: A list of `grain.Transformation` to apply to the dataset. Can be
-      a dict to allow easier CLI / sweep access (
+    transforms: A list of `grain.Transformation` to apply to the dataset before
+      batching. Can be a dict to allow easier CLI / sweep access (
       `--cfg.train_ds.transforms.img_scale.in_vrange=(-1,1)`)
+    post_batch_transforms: A list of `grain.Transformation` to apply to the
+      dataset after batching. Can be a dict to allow easier CLI / sweep access
+      (`--cfg.train_ds.post_batch_transforms.img_scale.in_vrange=(-1,1)`)
     tf_data_options: An optional tf.data.Options instance to be applied to the
       dataset.
     prefetch_size: Number of batches to prefetch for this dataset. Defaults to
@@ -79,6 +82,7 @@ class TFDataPipeline(pipelines.Pipeline, abc.ABC):
   # TODO(epot): Users should also be able to specify drop_reminder or mask
   batch_drop_remainder: bool = True
   transforms: _Transforms = dataclasses.field(default_factory=tuple)
+  post_batch_transforms: _Transforms = dataclasses.field(default_factory=tuple)

   # Those fields are only applied once at the top level
   tf_data_options: Optional[tf.data.Options] = None
@@ -203,13 +207,21 @@ def _apply_transforms(self, ds: tf.data.Dataset) -> tf.data.Dataset:
       transforms.extend(self.transforms.values())
     else:
       transforms.extend(self.transforms)
+
+    post_batch_transforms = []
+    if isinstance(self.post_batch_transforms, Mapping):
+      post_batch_transforms.extend(self.post_batch_transforms.values())
+    else:
+      post_batch_transforms.extend(self.post_batch_transforms)
+
     if self.batch_size:
       transforms.append(
           grain.TfBatch(
               batch_size=self.host_batch_size,
               drop_remainder=self.batch_drop_remainder,
           )
       )
+    transforms.extend(post_batch_transforms)
     ds = tr_utils.apply_transformations(ds, transforms)
     return ds
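Note the ordering in this hunk: `grain.TfBatch` is appended first (when `batch_size` is set) and the post-batch transforms are extended onto the list after it, so by the time `tr_utils.apply_transformations` runs they see batched tensors; if `batch_size` is unset they still run, just on unbatched elements. A sketch of what a config using the new field might look like, in the dict form the docstring mentions for CLI / sweep access; `kd.data.Tfds`, `Elements`, and `ValueRange` are assumed kauldron names used for illustration, not part of this commit:

from kauldron import kd

cfg.train_ds = kd.data.Tfds(  # assumed TFDataPipeline subclass
    name="mnist",
    split="train",
    batch_size=256,
    transforms={
        # Per-element, runs before `grain.TfBatch`.
        "elements": kd.data.Elements(keep=["image"]),
    },
    post_batch_transforms={
        # Runs on batched tensors; the dict key makes it sweepable from the
        # CLI: `--cfg.train_ds.post_batch_transforms.img_scale.in_vrange=(0,255)`.
        "img_scale": kd.data.ValueRange(
            key="image", in_vrange=(0, 255), vrange=(-1, 1)
        ),
    },
)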
