Skip to content

Commit

Permalink
Use explicit pydantic model in selector business logic (#454)
Browse files Browse the repository at this point in the history
Sorry for this bulky PR; this is really not my style :(

RHO-LOSS needs more sophisticated parameters to the downsampling config.
We currently use `dict` in business code to store the config, which is
hard to reason about, error-prone, and too cumbersome to extend for
RHO-LOSS.

Therefore, in this PR I changed all occurrences of `dict` to the
appropriate pydantic models, and also simplified some selection config
settings.

After this PR, we will have replaced almost every occurrence of using
`dict` as the config with a pydantic model!
  • Loading branch information
XianzheMa authored May 31, 2024
1 parent 6c3d63f commit 028a646
Show file tree
Hide file tree
Showing 59 changed files with 783 additions and 779 deletions.
2 changes: 1 addition & 1 deletion benchmark/mnist/mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ training:
maximum_keys_in_memory: 1000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
data:
dataset_id: mnist
transformations: ["transforms.ToTensor()",
Expand Down
2 changes: 1 addition & 1 deletion benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ training:
maximum_keys_in_memory: 10000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
data:
dataset_id: arxiv
bytes_parser_function: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ training:
maximum_keys_in_memory: 10000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
seed: 42
epochs_per_trigger: 1
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ training:
maximum_keys_in_memory: 1000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
seed: 42
epochs_per_trigger: 1
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ training:
maximum_keys_in_memory: 1000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
seed: 42
epochs_per_trigger: 1
data:
Expand Down
2 changes: 1 addition & 1 deletion benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ training:
maximum_keys_in_memory: 1000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
data:
dataset_id: fmow
transformations: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ training:
maximum_keys_in_memory: 1000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
data:
dataset_id: huffpost
bytes_parser_function: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ training:
maximum_keys_in_memory: 1000
storage_backend: "database"
limit: -1
reset_after_trigger: True
tail_triggers: 0
data:
dataset_id: yearbook
transformations: []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Metric,
ModelConfig,
ModynPipelineConfig,
NewDataSelectionStrategy,
NewDataStrategyConfig,
OptimizationCriterion,
OptimizerConfig,
OptimizerParamGroup,
Expand Down Expand Up @@ -41,8 +41,8 @@ def gen_pipeline_config(name: str, trigger: TriggerConfig, eval_strategy: EvalSt
],
optimization_criterion=OptimizationCriterion(name="CrossEntropyLoss"),
checkpointing=CheckpointingConfig(activated=False),
selection_strategy=NewDataSelectionStrategy(
maximum_keys_in_memory=1000, storage_backend="database", limit=-1, reset_after_trigger=True
selection_strategy=NewDataStrategyConfig(
maximum_keys_in_memory=1000, storage_backend="database", limit=-1, tail_triggers=0
),
),
data=DataConfig(
Expand Down
6 changes: 3 additions & 3 deletions integrationtests/online_dataset/test_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
init_metadata_db,
register_pipeline,
)
from modyn.config.schema.pipeline import NewDataSelectionStrategy
from modyn.config.schema.pipeline import NewDataStrategyConfig
from modyn.selector.internal.grpc.generated.selector_pb2 import DataInformRequest
from modyn.selector.internal.grpc.generated.selector_pb2_grpc import SelectorStub
from modyn.storage.internal.grpc.generated.storage_pb2 import (
Expand Down Expand Up @@ -205,8 +205,8 @@ def prepare_selector(num_dataworkers: int, keys: list[int]) -> Tuple[int, int]:
# We test the NewData strategy for finetuning on the new data, i.e., we reset without limit
# We also enforce high partitioning (maximum_keys_in_memory == 2) to ensure that works

strategy_config = NewDataSelectionStrategy(
maximum_keys_in_memory=2, limit=-1, reset_after_trigger=True, storage_backend="database"
strategy_config = NewDataStrategyConfig(
maximum_keys_in_memory=2, limit=-1, tail_triggers=0, storage_backend="database"
)
pipeline_config = get_minimal_pipeline_config(max(num_dataworkers, 1), strategy_config.model_dump(by_alias=True))
init_metadata_db(get_modyn_config())
Expand Down
36 changes: 18 additions & 18 deletions integrationtests/selector/integrationtest_selector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import grpc
from integrationtests.utils import get_minimal_pipeline_config, get_modyn_config, init_metadata_db, register_pipeline
from modyn.config.schema.pipeline import (
CoresetSelectionStrategy,
DownsamplingConfig,
NewDataSelectionStrategy,
CoresetStrategyConfig,
NewDataStrategyConfig,
PresamplingConfig,
)
from modyn.config.schema.sampling.downsampling_config import LossDownsamplingConfig
from modyn.selector.internal.grpc.generated.selector_pb2 import (
DataInformRequest,
GetAvailableLabelsRequest,
Expand All @@ -19,21 +19,21 @@
# TODO(54): Write more integration tests for different strategies.


def get_coreset_strategy_config() -> CoresetSelectionStrategy:
return CoresetSelectionStrategy(
def get_coreset_strategy_config() -> CoresetStrategyConfig:
return CoresetStrategyConfig(
maximum_keys_in_memory=250,
storage_backend="database", # TODO(#324): Support local backend
limit=-1,
reset_after_trigger=True,
tail_triggers=0,
presampling_config=PresamplingConfig(strategy="LabelBalanced", ratio=50),
)


def get_newdata_strategy_config() -> NewDataSelectionStrategy:
return NewDataSelectionStrategy(
def get_newdata_strategy_config() -> NewDataStrategyConfig:
return NewDataStrategyConfig(
maximum_keys_in_memory=2,
limit=-1,
reset_after_trigger=True,
tail_triggers=0,
storage_backend="database",
)

Expand Down Expand Up @@ -122,7 +122,7 @@ def test_label_balanced_force_same_size():

strategy_config = get_coreset_strategy_config()
strategy_config.maximum_keys_in_memory = 100
strategy_config.reset_after_trigger = False
strategy_config.tail_triggers = None
strategy_config.presampling_config.force_column_balancing = True
strategy_config.presampling_config.ratio = 90

Expand Down Expand Up @@ -197,7 +197,7 @@ def test_label_balanced_force_all_samples():

strategy_config = get_coreset_strategy_config()
strategy_config.maximum_keys_in_memory = 100
strategy_config.reset_after_trigger = False
strategy_config.tail_triggers = None
strategy_config.presampling_config.force_required_target_size = True
strategy_config.presampling_config.ratio = 90

Expand Down Expand Up @@ -405,17 +405,17 @@ def test_newdata() -> None:
assert len(total_samples) == 6


def test_abstract_downsampler(reset_after_trigger) -> None:
def test_abstract_downsampler(reset_after_trigger: bool) -> None:
selector_channel = connect_to_selector_servicer()
selector = SelectorStub(selector_channel)

# sampling every datapoint
strategy_config = get_coreset_strategy_config()
strategy_config.maximum_keys_in_memory = 50000
strategy_config.reset_after_trigger = reset_after_trigger
strategy_config.tail_triggers = 0 if reset_after_trigger else None
strategy_config.presampling_config.ratio = 20
strategy_config.presampling_config.strategy = "Random"
strategy_config.downsampling_config = DownsamplingConfig(strategy="Loss", ratio=10, sample_then_batch=False)
strategy_config.downsampling_config = LossDownsamplingConfig(ratio=10, sample_then_batch=False)

pipeline_config = get_minimal_pipeline_config(2, strategy_config.model_dump(by_alias=True))
pipeline_id = register_pipeline(pipeline_config, get_modyn_config())
Expand Down Expand Up @@ -561,7 +561,7 @@ def test_empty_triggers() -> None:

# TODO(MaxiBoether): use local strategy here as well after implementing it
strategy_config = get_newdata_strategy_config()
strategy_config.reset_after_trigger = False
strategy_config.tail_triggers = None

pipeline_config = get_minimal_pipeline_config(2, strategy_config.model_dump(by_alias=True))
pipeline_id = register_pipeline(pipeline_config, get_modyn_config())
Expand Down Expand Up @@ -727,7 +727,7 @@ def test_many_samples_evenly_distributed():
# TODO(MaxiBoether): use local strategy here as well after implementing it
strategy_config = get_newdata_strategy_config()
strategy_config.maximum_keys_in_memory = 5000
strategy_config.reset_after_trigger = False
strategy_config.tail_triggers = None

pipeline_config = get_minimal_pipeline_config(2, strategy_config.model_dump(by_alias=True))
pipeline_id = register_pipeline(pipeline_config, get_modyn_config())
Expand Down Expand Up @@ -795,7 +795,7 @@ def test_many_samples_unevenly_distributed():
# TODO(MaxiBoether): use local strategy here as well after implementing it
strategy_config = get_newdata_strategy_config()
strategy_config.maximum_keys_in_memory = 4999
strategy_config.reset_after_trigger = False
strategy_config.tail_triggers = None

pipeline_config = get_minimal_pipeline_config(2, strategy_config.model_dump(by_alias=True))
pipeline_id = register_pipeline(pipeline_config, get_modyn_config())
Expand Down Expand Up @@ -864,7 +864,7 @@ def test_get_available_labels(reset_after_trigger: bool):
# TODO(MaxiBoether): use local strategy here as well after implementing it
strategy_config = get_newdata_strategy_config()
strategy_config.maximum_keys_in_memory = 2
strategy_config.reset_after_trigger = reset_after_trigger
strategy_config.tail_triggers = 0 if reset_after_trigger else None

pipeline_config = get_minimal_pipeline_config(2, strategy_config.model_dump(by_alias=True))
pipeline_id = register_pipeline(pipeline_config, get_modyn_config())
Expand Down
15 changes: 10 additions & 5 deletions modyn/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from pathlib import Path

import yaml
from modyn.config.schema.sampling.downsampling_config import ( # noqa: F401
MultiDownsamplingConfig,
SingleDownsamplingConfig,
)

from .schema.config import (
BinaryFileByteOrder,
Expand All @@ -24,19 +28,20 @@
TensorboardConfig,
TrainingServerConfig,
)
from .schema.pipeline import (
from .schema.pipeline import ( # noqa: F401
CheckpointingConfig,
CoresetStrategyConfig,
DataConfig,
DownsamplingConfig,
EvalDataConfig,
EvaluationConfig,
FreshnessSamplingStrategyConfig,
FullModelStrategy,
IncrementalModelStrategy,
LrSchedulerConfig,
Metric,
ModelConfig,
ModynPipelineConfig,
MultiDownsamplingConfig,
NewDataStrategyConfig,
OptimizationCriterion,
OptimizerConfig,
OptimizerParamGroup,
Expand Down Expand Up @@ -79,8 +84,8 @@
"IncrementalModelStrategy",
"PipelineModelStorageConfig",
"PresamplingConfig",
"DownsamplingConfig",
"MultiDownsamplingConfig",
"SingleDownsamplingConfig",
"FreshnessSamplingStrategyConfig",
"SelectionStrategy",
"CheckpointingConfig",
"OptimizerParamGroup",
Expand Down
2 changes: 1 addition & 1 deletion modyn/config/examples/example-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ training:
name: NewDataStrategy
maximum_keys_in_memory: 500000
limit: -1
reset_after_trigger: True
tail_triggers: 0
storage_backend: database
processor_type: basic_processor_strategy
data:
Expand Down
8 changes: 8 additions & 0 deletions modyn/config/schema/modyn_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

from pydantic import BaseModel


class ModynBaseModel(BaseModel):
    """Shared base class for Modyn's pydantic configuration models.

    Sets ``extra = "forbid"`` so that unknown fields raise a validation
    error instead of being silently ignored — this catches typos in
    configuration keys at load time.
    """

    # NOTE(review): the nested ``class Config`` is the pydantic v1 style;
    # pydantic v2 deprecates it in favor of
    # ``model_config = ConfigDict(extra="forbid")`` — confirm which
    # pydantic version the project targets before migrating.
    class Config:
        extra = "forbid"
Loading

0 comments on commit 028a646

Please sign in to comment.