Skip to content

Commit

Permalink
Add DataDriftTrigger: supports one Evidently metric (#409)
Browse files Browse the repository at this point in the history
This is a clean version of PR#367.
1. Add DataDriftTrigger class to supervisor. Supports one configurable
Evidently metric. Launches drift detection every N new data points. Data
used in detection are data trained in the previous trigger and all the
untriggered new data.
2. Update Trigger interface. `Trigger.inform()` returns a Generator
instead of List.
3. Add a generic ModelDownloader in supervisor.
4. Add example pipelines using DataDriftTrigger.
5. Add Evidently to pylint known third party.
6. Change ModelDownloader to embedding encoder utils. The downloader
sets up and returns the model. The DataDriftTrigger owns the model.

Future
1. Support multiple configurable Evidently metric. #416
2. Support Alibi-Detect. #414 
3. Support custom embedding encoder. #417
4. Support different windowing for detection data, e.g. compare with all
previously trained data. #418
5. Common DataLoaderInfo #415
  • Loading branch information
jenny011 authored and robinholzi committed May 18, 2024
1 parent e176778 commit 4a34044
Show file tree
Hide file tree
Showing 25 changed files with 1,709 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
pipeline:
name: ArXiv dataset Test Pipeline
description: Example pipeline
version: 1.0.0
model:
id: ArticleNet
config:
num_classes: 172
model_storage:
full_model_strategy:
name: "PyTorchFullModel"
training:
gpus: 1
device: "cuda:0"
dataloader_workers: 2
use_previous_model: True
initial_model: random
batch_size: 96
optimizers:
- name: "default"
algorithm: "SGD"
source: "PyTorch"
param_groups:
- module: "model"
config:
lr: 0.00002
momentum: 0.9
weight_decay: 0.01
optimization_criterion:
name: "CrossEntropyLoss"
checkpointing:
activated: False
selection_strategy:
name: NewDataStrategy
maximum_keys_in_memory: 10000
config:
storage_backend: "database"
limit: -1
reset_after_trigger: True
seed: 42
epochs_per_trigger: 1
data:
dataset_id: arxiv_train
bytes_parser_function: |
def bytes_parser_function(data: bytes) -> str:
return str(data, "utf8")
tokenizer: DistilBertTokenizerTransform

trigger:
id: DataDriftTrigger
trigger_config:
data_points_for_detection: 100000
sample_size: 5000

evaluation:
device: "cuda:0"
result_writers: ["json"]
datasets:
- dataset_id: arxiv_test
bytes_parser_function: |
def bytes_parser_function(data: bytes) -> str:
return str(data, "utf8")
tokenizer: DistilBertTokenizerTransform
batch_size: 96
dataloader_workers: 2
metrics:
- name: "Accuracy"
evaluation_transformer_function: |
import torch
def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:
return torch.argmax(model_output, dim=-1)
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
pipeline:
name: Huffpost dataset Test Pipeline
description: Example pipeline
version: 1.0.0
model:
id: ArticleNet
config:
num_classes: 55
model_storage:
full_model_strategy:
name: "PyTorchFullModel"
training:
gpus: 1
device: "cuda:0"
dataloader_workers: 2
use_previous_model: True
initial_model: random
batch_size: 64
optimizers:
- name: "default"
algorithm: "SGD"
source: "PyTorch"
param_groups:
- module: "model"
config:
lr: 0.00002
momentum: 0.9
weight_decay: 0.01
optimization_criterion:
name: "CrossEntropyLoss"
checkpointing:
activated: False
selection_strategy:
name: NewDataStrategy
maximum_keys_in_memory: 1000
config:
storage_backend: "database"
limit: -1
reset_after_trigger: True
seed: 42
epochs_per_trigger: 1
data:
dataset_id: huffpost_train
bytes_parser_function: |
def bytes_parser_function(data: bytes) -> str:
return str(data, "utf8")
tokenizer: DistilBertTokenizerTransform

trigger:
id: DataDriftTrigger
trigger_config:
data_points_for_detection: 5000
metric_name: mmd
metric_config:
bootstrap: False

evaluation:
device: "cuda:0"
result_writers: ["json"]
datasets:
- dataset_id: huffpost_test
bytes_parser_function: |
def bytes_parser_function(data: bytes) -> str:
return str(data, "utf8")
tokenizer: DistilBertTokenizerTransform
batch_size: 64
dataloader_workers: 2
metrics:
- name: "Accuracy"
evaluation_transformer_function: |
import torch
def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:
return torch.argmax(model_output, dim=-1)
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
pipeline:
name: Yearbook Test Pipeline
description: Example pipeline
version: 1.0.0
model:
id: YearbookNet
config:
num_input_channels: 1
num_classes: 2
model_storage:
full_model_strategy:
name: "PyTorchFullModel"
training:
gpus: 1
device: "cuda:0"
dataloader_workers: 2
use_previous_model: True
initial_model: random
batch_size: 64
optimizers:
- name: "default"
algorithm: "SGD"
source: "PyTorch"
param_groups:
- module: "model"
config:
lr: 0.001
momentum: 0.9
optimization_criterion:
name: "CrossEntropyLoss"
checkpointing:
activated: False
selection_strategy:
name: NewDataStrategy
maximum_keys_in_memory: 1000
config:
storage_backend: "database"
limit: -1
reset_after_trigger: True
seed: 42
epochs_per_trigger: 1
data:
dataset_id: yearbook_train
transformations: []
bytes_parser_function: |
import torch
import numpy as np
def bytes_parser_function(data: bytes) -> torch.Tensor:
return torch.from_numpy(np.frombuffer(data, dtype=np.float32)).reshape(1, 32, 32)
trigger:
id: DataDriftTrigger
trigger_config:
data_points_for_detection: 1000
metric_name: model
metric_config:
threshold: 0.7

evaluation:
device: "cuda:0"
result_writers: ["json"]
datasets:
- dataset_id: yearbook_test
bytes_parser_function: |
import torch
import numpy as np
def bytes_parser_function(data: bytes) -> torch.Tensor:
return torch.from_numpy(np.frombuffer(data, dtype=np.float32)).reshape(1, 32, 32)
batch_size: 64
dataloader_workers: 2
metrics:
- name: "Accuracy"
evaluation_transformer_function: |
import torch
def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:
return torch.argmax(model_output, dim=-1)
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- grpcio>=1.63
- protobuf==5.26.*
- types-protobuf==5.26.*
- evidently
- jsonschema
- psycopg2
- sqlalchemy>=2.0
Expand Down
2 changes: 0 additions & 2 deletions modyn/config/schema/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ class _BaseSelectionStrategyConfig(BaseModel):


class FreshnessSamplingStrategyConfig(_BaseSelectionStrategyConfig):

unused_data_ratio: float = Field(
0.0,
description=(
Expand All @@ -175,7 +174,6 @@ class FreshnessSamplingStrategyConfig(_BaseSelectionStrategyConfig):


class NewDataSelectionStrategyConfig(_BaseSelectionStrategyConfig):

limit_reset: LimitResetStrategy = Field(
description=(
"Strategy to follow for respecting the limit in case of reset. Only used when reset_after_trigger == true."
Expand Down
32 changes: 22 additions & 10 deletions modyn/supervisor/internal/pipeline_executor/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
import traceback
from time import sleep
from typing import Any, Optional
from typing import Any, Generator, Optional

from modyn.common.benchmark import Stopwatch
from modyn.supervisor.internal.evaluation_result_writer import LogResultWriter
Expand Down Expand Up @@ -113,6 +113,9 @@ def _setup_trigger(self) -> None:

trigger_module = dynamic_module_import("modyn.supervisor.internal.triggers")
self.trigger: Trigger = getattr(trigger_module, trigger_id)(trigger_config)
self.trigger.init_trigger(self.pipeline_id, self.pipeline_config, self.modyn_config, self.eval_directory)
if self.previous_model_id is not None:
self.trigger.inform_previous_model(self.previous_model_id)

assert self.trigger is not None, "Error during trigger initialization"

Expand Down Expand Up @@ -177,15 +180,15 @@ def _handle_new_data(self, new_data: list[tuple[int, int, int]]) -> bool:

def _handle_new_data_batch(self, batch: list[tuple[int, int, int]]) -> bool:
self._sw.start("trigger_inform", overwrite=True)
triggering_indices = self.trigger.inform(batch)
num_triggers = len(triggering_indices)
self.pipeline_log["supervisor"]["num_triggers"] += len(triggering_indices)
triggering_indices: Generator[int, None, None] = self.trigger.inform(batch)
num_triggers = self._handle_triggers_within_batch(batch, triggering_indices)

logger.info(f"There are {num_triggers} triggers in this batch.")
self.pipeline_log["supervisor"]["num_triggers"] += num_triggers
self.pipeline_log["supervisor"]["trigger_batch_times"].append(
{"batch_size": len(batch), "time": self._sw.stop("trigger_inform"), "num_triggers": num_triggers}
)

logger.info(f"There are {num_triggers} triggers in this batch.")
self._handle_triggers_within_batch(batch, triggering_indices)
return num_triggers > 0

def _run_training(self, trigger_id: int) -> None:
Expand Down Expand Up @@ -223,6 +226,7 @@ def _run_training(self, trigger_id: int) -> None:
# We store the trained model for evaluation in any case.
self._sw.start("store_trained_model", overwrite=True)
model_id = self.grpc.store_trained_model(self.current_training_id)
self.trigger.inform_previous_model(model_id)
self.pipeline_log["supervisor"]["triggers"][trigger_id]["store_trained_model_time"] = self._sw.stop()

# Only if the pipeline actually wants to continue the training on it, we set previous model.
Expand Down Expand Up @@ -270,12 +274,17 @@ def _get_trigger_timespan(

return first_timestamp, last_timestamp

def _handle_triggers_within_batch(self, batch: list[tuple[int, int, int]], triggering_indices: list[int]) -> None:
def _handle_triggers_within_batch(
self, batch: list[tuple[int, int, int]], triggering_indices: Generator[int, None, None]
) -> int:
previous_trigger_idx = 0
logger.info("Handling triggers within batch.")
self._update_pipeline_stage_and_enqueue_msg(PipelineStage.HANDLE_TRIGGERS_WITHIN_BATCH, MsgType.GENERAL)

triggering_idx_list = []

for i, triggering_idx in enumerate(triggering_indices):
triggering_idx_list.append(triggering_idx)
self._update_pipeline_stage_and_enqueue_msg(PipelineStage.INFORM_SELECTOR_AND_TRIGGER, MsgType.GENERAL)
triggering_data = batch[previous_trigger_idx : triggering_idx + 1]
previous_trigger_idx = triggering_idx + 1
Expand All @@ -294,6 +303,7 @@ def _handle_triggers_within_batch(self, batch: list[tuple[int, int, int]], trigg

num_samples_in_trigger = self.grpc.get_number_of_samples(self.pipeline_id, trigger_id)
if num_samples_in_trigger > 0:
self.trigger.inform_previous_trigger_and_data_points(trigger_id, num_samples_in_trigger)
first_timestamp, last_timestamp = self._get_trigger_timespan(i == 0, triggering_data)
self.pipeline_log["supervisor"]["triggers"][trigger_id]["first_timestamp"] = first_timestamp
self.pipeline_log["supervisor"]["triggers"][trigger_id]["last_timestamp"] = last_timestamp
Expand All @@ -309,13 +319,13 @@ def _handle_triggers_within_batch(self, batch: list[tuple[int, int, int]], trigg

self.num_triggers = self.num_triggers + 1
if self.maximum_triggers is not None and self.num_triggers >= self.maximum_triggers:
break
return len(triggering_idx_list)

# we have to inform the Selector about the remaining data in this batch.
if len(triggering_indices) == 0:
if len(triggering_idx_list) == 0:
remaining_data = batch
else:
remaining_data = batch[triggering_indices[-1] + 1 :]
remaining_data = batch[triggering_idx_list[-1] + 1 :]

logger.info(f"There are {len(remaining_data)} data points remaining after the trigger.")
if len(remaining_data) > 0:
Expand All @@ -335,6 +345,8 @@ def _handle_triggers_within_batch(self, batch: list[tuple[int, int, int]], trigg
else:
self.remaining_data_range = None

return len(triggering_idx_list)

def _init_evaluation_writer(self, name: str, trigger_id: int) -> LogResultWriter:
return self.supervisor_supported_eval_result_writers[name](self.pipeline_id, trigger_id, self.eval_directory)

Expand Down
1 change: 1 addition & 0 deletions modyn/supervisor/internal/triggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os

from .amounttrigger import DataAmountTrigger # noqa: F401
from .datadrifttrigger import DataDriftTrigger # noqa: F401
from .timetrigger import TimeTrigger # noqa: F401
from .trigger import Trigger # noqa: F401

Expand Down
6 changes: 4 additions & 2 deletions modyn/supervisor/internal/triggers/amounttrigger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Generator

from modyn.supervisor.internal.triggers.trigger import Trigger


Expand All @@ -14,12 +16,12 @@ def __init__(self, trigger_config: dict):

super().__init__(trigger_config)

def inform(self, new_data: list[tuple[int, int, int]]) -> list[int]:
def inform(self, new_data: list[tuple[int, int, int]]) -> Generator[int, None, None]:
assert self.remaining_data_points < self.data_points_for_trigger, "Inconsistent remaining datapoints"

first_idx = self.data_points_for_trigger - self.remaining_data_points - 1
triggering_indices = list(range(first_idx, len(new_data), self.data_points_for_trigger))

self.remaining_data_points = (self.remaining_data_points + len(new_data)) % self.data_points_for_trigger

return triggering_indices
yield from triggering_indices
Loading

0 comments on commit 4a34044

Please sign in to comment.