diff --git a/examples/instrumented_loop_example.py b/examples/instrumented_loop_example.py index 9176300e..b3e343bb 100644 --- a/examples/instrumented_loop_example.py +++ b/examples/instrumented_loop_example.py @@ -1,12 +1,18 @@ +import random from time import sleep -from flowcept import Flowcept, flowcept_loop +from flowcept import Flowcept, FlowceptLoop + +iterations = 3 -epochs = range(1, 3) with Flowcept(): - for _ in flowcept_loop(items=epochs, loop_name="epochs", item_name='epoch'): + loop = FlowceptLoop(iterations) + for item in loop: + loss = random.random() sleep(0.05) + print(item, loss) + # The following is optional, in case you want to capture values generated inside the loop. + loop.end_iter({"item": item, "loss": loss}) docs = Flowcept.db.query(filter={"workflow_id": Flowcept.current_workflow_id}) -print(len(docs)) -assert len(docs) == 3 # 1 (parent_task) + 2 (sub_tasks) +assert len(docs) == iterations + 1 # The whole loop itself is a task diff --git a/pyproject.toml b/pyproject.toml index dbb6338c..5224a6ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ line-length = 100 [tool.ruff.lint] extend-select = ["E501", "D"] -ignore = ["D200", "D212"] +ignore = ["D200", "D212", "D105", "D401"] [tool.ruff.lint.pydocstyle] convention = "numpy" diff --git a/src/flowcept/__init__.py b/src/flowcept/__init__.py index 5455f15f..2d6b273b 100644 --- a/src/flowcept/__init__.py +++ b/src/flowcept/__init__.py @@ -2,14 +2,33 @@ from flowcept.configs import SETTINGS_PATH from flowcept.version import __version__ -from flowcept.flowcept_api.flowcept_controller import Flowcept -from flowcept.instrumentation.decorators.flowcept_task import flowcept_task, flowcept_loop + from flowcept.commons.flowcept_dataclasses.workflow_object import ( WorkflowObject, ) def __getattr__(name): + if name == "Flowcept": + from flowcept.flowcept_api.flowcept_controller import Flowcept + + return Flowcept + + elif name == "flowcept_task": + from flowcept.instrumentation.decorators.flowcept_task import flowcept_task + + return flowcept_task + + elif name == "FlowceptLoop": + from flowcept.instrumentation.decorators.flowcept_loop import FlowceptLoop + + return FlowceptLoop + + elif name == "telemetry_flowcept_task": + from flowcept.instrumentation.decorators.flowcept_task import telemetry_flowcept_task + + return telemetry_flowcept_task + if name == "MLFlowInterceptor": from flowcept.flowceptor.adapters.mlflow.mlflow_interceptor import ( MLFlowInterceptor, @@ -54,10 +73,11 @@ def __getattr__(name): "TensorboardInterceptor", "ZambezeInterceptor", "TaskQueryAPI", - "WorkflowObject", "flowcept_task", - "flowcept_loop", + "FlowceptLoop", + "telemetry_flowcept_task", "Flowcept", + "WorkflowObject", "__version__", "SETTINGS_PATH", ] diff --git a/src/flowcept/commons/flowcept_dataclasses/task_object.py b/src/flowcept/commons/flowcept_dataclasses/task_object.py index f3de966c..b88fbdb0 100644 --- a/src/flowcept/commons/flowcept_dataclasses/task_object.py +++ b/src/flowcept/commons/flowcept_dataclasses/task_object.py @@ -1,11 +1,11 @@ """Task object module.""" -from enum import Enum from typing import Dict, AnyStr, Any, Union, List import msgpack import flowcept from flowcept.commons.flowcept_dataclasses.telemetry import Telemetry +from flowcept.commons.vocabulary import Status from flowcept.configs import ( HOSTNAME, PRIVATE_IP, @@ -16,25 +16,6 @@ ) -class Status(str, Enum): - """Status class. - - Inheriting from str here for JSON serialization. - """ - - SUBMITTED = "SUBMITTED" - WAITING = "WAITING" - RUNNING = "RUNNING" - FINISHED = "FINISHED" - ERROR = "ERROR" - UNKNOWN = "UNKNOWN" - - @staticmethod - def get_finished_statuses(): - """Get finished status.""" - return [Status.FINISHED, Status.ERROR] - - class TaskObject: """Task class.""" diff --git a/src/flowcept/commons/query_utils.py b/src/flowcept/commons/query_utils.py index 1b3f6eea..097e61de 100644 --- a/src/flowcept/commons/query_utils.py +++ b/src/flowcept/commons/query_utils.py @@ -6,7 +6,7 @@ import pandas as pd -from flowcept.commons.flowcept_dataclasses.task_object import Status +from flowcept.commons.vocabulary import Status def get_doc_status(row): diff --git a/src/flowcept/commons/utils.py b/src/flowcept/commons/utils.py index 7730fb4f..d37ffb62 100644 --- a/src/flowcept/commons/utils.py +++ b/src/flowcept/commons/utils.py @@ -13,7 +13,7 @@ from flowcept import configs from flowcept.commons.flowcept_logger import FlowceptLogger from flowcept.configs import PERF_LOG -from flowcept.commons.flowcept_dataclasses.task_object import Status +from flowcept.commons.vocabulary import Status def get_utc_now() -> float: diff --git a/src/flowcept/commons/vocabulary.py b/src/flowcept/commons/vocabulary.py index 3611836f..8c7010b7 100644 --- a/src/flowcept/commons/vocabulary.py +++ b/src/flowcept/commons/vocabulary.py @@ -1,5 +1,7 @@ """Vocab module.""" +from enum import Enum + class Vocabulary: """Vocab class.""" @@ -14,3 +16,22 @@ class Settings: MLFLOW_KIND = "mlflow" TENSORBOARD_KIND = "tensorboard" DASK_KIND = "dask" + + +class Status(str, Enum): + """Status class. + + Inheriting from str here for JSON serialization. + """ + + SUBMITTED = "SUBMITTED" + WAITING = "WAITING" + RUNNING = "RUNNING" + FINISHED = "FINISHED" + ERROR = "ERROR" + UNKNOWN = "UNKNOWN" + + @staticmethod + def get_finished_statuses(): + """Get finished status.""" + return [Status.FINISHED, Status.ERROR] diff --git a/src/flowcept/flowceptor/adapters/dask/dask_interceptor.py b/src/flowcept/flowceptor/adapters/dask/dask_interceptor.py index 6b3ba13e..0bc87658 100644 --- a/src/flowcept/flowceptor/adapters/dask/dask_interceptor.py +++ b/src/flowcept/flowceptor/adapters/dask/dask_interceptor.py @@ -5,8 +5,8 @@ from flowcept import WorkflowObject from flowcept.commons.flowcept_dataclasses.task_object import ( TaskObject, - Status, ) +from flowcept.commons.vocabulary import Status from flowcept.flowceptor.adapters.base_interceptor import ( BaseInterceptor, ) diff --git a/src/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py b/src/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py index f55c9e9a..b0abdbbc 100644 --- a/src/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py +++ b/src/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py @@ -8,8 +8,8 @@ from flowcept.commons.flowcept_dataclasses.task_object import ( TaskObject, - Status, ) +from flowcept.commons.vocabulary import Status from flowcept.commons.utils import get_utc_now from flowcept.flowceptor.adapters.interceptor_state_manager import ( InterceptorStateManager, diff --git a/src/flowcept/flowceptor/consumers/document_inserter.py b/src/flowcept/flowceptor/consumers/document_inserter.py index 06d14670..81d4df71 100644 --- a/src/flowcept/flowceptor/consumers/document_inserter.py +++ b/src/flowcept/flowceptor/consumers/document_inserter.py @@ -213,7 +213,7 @@ def _message_handler(self, msg_obj: dict): self._handle_workflow_message(msg_obj) return True elif msg_type is None: - self.logger.warning(f"Message without type???\n {msg_obj}") + self.logger.error(f"Message without type??? --> {msg_obj}") return True else: self.logger.error("Unexpected message type") diff --git a/src/flowcept/instrumentation/decorators/flowcept_loop.py b/src/flowcept/instrumentation/decorators/flowcept_loop.py new file mode 100644 index 00000000..7bd05f11 --- /dev/null +++ b/src/flowcept/instrumentation/decorators/flowcept_loop.py @@ -0,0 +1,153 @@ +"""FlowCept Loop module.""" + +import typing +import uuid +from time import time + +from flowcept import Flowcept +from flowcept.commons.flowcept_logger import FlowceptLogger +from flowcept.commons.vocabulary import Status +from flowcept.flowceptor.adapters.instrumentation_interceptor import InstrumentationInterceptor + + +class FlowceptLoop: + """ + A utility class to wrap and instrument iterable loops for telemetry and tracking. + + The `FlowceptLoop` class supports iterating over a collection of items or a numeric range + while capturing metadata for each iteration and for the loop as a whole. This is particularly + useful in scenarios where tracking and instrumentation of loop executions is required. + + Parameters + ---------- + items : typing.Union[typing.Sized, int] + The items to iterate over. Must either be an iterable with a `__len__` method or an integer + representing the range of iteration. + loop_name : str, optional + A descriptive name for the loop (default is "loop"). + item_name : str, optional + The name used for each item in the telemetry (default is "item"). + parent_task_id : str, optional + The ID of the parent task associated with the loop, if applicable (default is None). + workflow_id : str, optional + The workflow ID to associate with this loop. If not provided, it will be generated or + inferred from the current workflow context. + + Raises + ------ + Exception + If `items` is not an iterable with a `__len__` method or an integer. + + Notes + ----- + This class integrates with the `Flowcept` system for telemetry and tracking, ensuring + detailed monitoring of loops and their iterations. It is designed for cases where + capturing granular runtime behavior of loops is critical. + """ + + def __init__( + self, + items: typing.Union[typing.Sized, int], + loop_name="loop", + item_name="item", + parent_task_id=None, + workflow_id=None, + ): + self._next_counter = 0 + self.logger = FlowceptLogger() + if hasattr(items, "__len__"): + self._iterable = items + self._max = len(self._iterable) + elif isinstance(items, int): + self._iterable = range(items) + self._max = len(self._iterable) + else: + raise Exception("You must use an iterable has at least a __len__ method defined.") + + self._interceptor = InstrumentationInterceptor.get_instance() + self._iterator = iter(self._iterable) + self._last_iteration_task = None + self._current_iteration_task = {} + self._loop_name = loop_name + self._item_name = item_name + self._parent_task_id = parent_task_id + self._workflow_id = workflow_id or Flowcept.current_workflow_id or str(uuid.uuid4()) + + def __iter__(self): + return self + + def _begin_loop(self): + self.logger.debug("Capturing loop init.") + self._whole_loop_task = { + "started_at": (started_at := time()), + "task_id": str(started_at), + "type": "task", + "activity_id": self._loop_name, + "workflow_id": self._workflow_id, + } + if self._parent_task_id: + self._whole_loop_task["parent_task_id"] = self._parent_task_id + self._interceptor.intercept(self._whole_loop_task) + self._capture_iteration_bounds() + + def _end_loop(self): + self._capture_iteration_bounds() + self.logger.debug("Capturing loop end.") + self._end_iteration_task(self._last_iteration_task) + self._whole_loop_task["status"] = Status.FINISHED.value + self._whole_loop_task["ended_at"] = time() + self._interceptor.intercept(self._whole_loop_task) + + def __next__(self): + # Basic idea: the beginning of the current iteration is the end of the last + self._current_item = next(self._iterator) + + if self._next_counter == 0: + self._begin_loop() + elif self._next_counter == self._max - 1: + self._end_loop() + elif self._next_counter < self._max - 1: + self._capture_iteration_bounds() + + self._next_counter += 1 + return self._current_item + + def _capture_iteration_bounds(self): + if self._last_iteration_task is not None: + self.logger.debug(f"Capturing the end of iteration {self._next_counter-1}.") + self._end_iteration_task(self._last_iteration_task) + + self.logger.debug(f"Capturing the init of iteration {self._next_counter}.") + self._current_iteration_task = self._begin_iteration_task(self._current_item) + self._last_iteration_task = self._current_iteration_task + + def _begin_iteration_task(self, item): + iteration_task = { + "workflow_id": self._workflow_id, + "activity_id": self._loop_name + "_iteration", + "used": {"i": self._next_counter, self._item_name: item}, + "parent_task_id": self._whole_loop_task["task_id"], + "started_at": time(), + "telemetry_at_start": self._interceptor.telemetry_capture.capture().to_dict(), + "type": "task", + } + return iteration_task + + def _end_iteration_task(self, iteration_task): + iteration_task["status"] = "FINISHED" + self._interceptor.intercept(self._last_iteration_task) + + def end_iter(self, generated_value: typing.Dict): + """ + Finalizes the current iteration by associating generated values with the iteration metadata. + + This method updates the metadata of the current iteration to include the values generated + during the iteration, ensuring they are properly logged and tracked. + + Parameters + ---------- + generated_value : dict + A dictionary containing the generated values for the current iteration. These values + will be stored in the `generated` field of the iteration's metadata. + """ + self._current_iteration_task["generated"] = generated_value diff --git a/src/flowcept/instrumentation/decorators/flowcept_task.py b/src/flowcept/instrumentation/decorators/flowcept_task.py index a7180c23..13677378 100644 --- a/src/flowcept/instrumentation/decorators/flowcept_task.py +++ b/src/flowcept/instrumentation/decorators/flowcept_task.py @@ -2,11 +2,10 @@ from time import time from functools import wraps -from flowcept import Flowcept from flowcept.commons.flowcept_dataclasses.task_object import ( TaskObject, - Status, ) +from flowcept.commons.vocabulary import Status from flowcept.commons.flowcept_logger import FlowceptLogger from flowcept.commons.utils import replace_non_serializable @@ -14,6 +13,7 @@ REPLACE_NON_JSON_SERIALIZABLE, INSTRUMENTATION_ENABLED, ) +from flowcept.flowcept_api.flowcept_controller import Flowcept from flowcept.flowceptor.adapters.instrumentation_interceptor import InstrumentationInterceptor @@ -35,7 +35,8 @@ def default_args_handler(task_message: TaskObject, *args, **kwargs): def telemetry_flowcept_task(func=None): """Get telemetry task.""" - interceptor = InstrumentationInterceptor.get_instance() + if INSTRUMENTATION_ENABLED: + interceptor = InstrumentationInterceptor.get_instance() def decorator(func): @wraps(func) @@ -44,14 +45,14 @@ def wrapper(*args, **kwargs): task_obj["type"] = "task" task_obj["started_at"] = time() task_obj["activity_id"] = func.__qualname__ - task_obj["task_id"] = str(id(task_obj)) - task_obj["workflow_id"] = kwargs.pop("workflow_id") + task_obj["task_id"] = str(task_obj["started_at"]) + task_obj["workflow_id"] = kwargs.pop("workflow_id", Flowcept.current_workflow_id) task_obj["used"] = kwargs tel = interceptor.telemetry_capture.capture() if tel is not None: task_obj["telemetry_at_start"] = tel.to_dict() try: - result = func(*args, **kwargs) + result = func(task_id=task_obj["task_id"], *args, **kwargs) task_obj["status"] = Status.FINISHED.value except Exception as e: task_obj["status"] = Status.ERROR.value @@ -75,7 +76,8 @@ def wrapper(*args, **kwargs): def lightweight_flowcept_task(func=None): """Get lightweight task.""" - interceptor = InstrumentationInterceptor.get_instance() + if INSTRUMENTATION_ENABLED: + interceptor = InstrumentationInterceptor.get_instance() def decorator(func): @wraps(func) @@ -101,8 +103,9 @@ def wrapper(*args, **kwargs): def flowcept_task(func=None, **decorator_kwargs): """Get flowcept task.""" - interceptor = InstrumentationInterceptor.get_instance() - logger = FlowceptLogger() + if INSTRUMENTATION_ENABLED: + interceptor = InstrumentationInterceptor.get_instance() + logger = FlowceptLogger() def decorator(func): @wraps(func) @@ -144,46 +147,3 @@ def wrapper(*args, **kwargs): return decorator else: return decorator(func) - - -def _flowcept_loop_task(generator_func): - interceptor = InstrumentationInterceptor.get_instance() - - def wrapper(*args, **kwargs): - whole_loop_obj = TaskObject() - whole_loop_obj.started_at = time() - whole_loop_obj.task_id = str(whole_loop_obj.started_at) - whole_loop_obj.activity_id = kwargs.pop("loop_name", "loop") - whole_loop_obj.workflow_id = kwargs.pop("workflow_id", Flowcept.current_workflow_id) - item_name = kwargs.pop("item_name", "item") - - i = 0 - for item in generator_func(*args, **kwargs): - iteration_obj = TaskObject() - iteration_obj.activity_id = whole_loop_obj.activity_id + "_iteration" - iteration_obj.parent_task_id = whole_loop_obj.task_id - iteration_obj.workflow_id = whole_loop_obj.workflow_id - iteration_obj.started_at = time() - iteration_obj.used = {"i": i} - iteration_obj.telemetry_at_start = interceptor.telemetry_capture.capture() - if type(item) in {int, float, str}: - iteration_obj.used[item_name] = item - else: - iteration_obj.used[item_name] = id(item) - iteration_obj.task_id = str(iteration_obj.started_at) - yield item - iteration_obj.ended_at = time() - iteration_obj.telemetry_at_end = interceptor.telemetry_capture.capture() - iteration_obj.status = Status.FINISHED - interceptor.intercept(iteration_obj.to_dict()) - i += 1 - interceptor.intercept(whole_loop_obj.to_dict()) - - return wrapper - - -@_flowcept_loop_task -def flowcept_loop(items, loop_name=None, item_name=None, workflow_id=None, *args, **kwargs): - """Instrumentation facility to help you capture loops.""" - for item in items: - yield item diff --git a/src/flowcept/instrumentation/decorators/flowcept_torch.py b/src/flowcept/instrumentation/decorators/flowcept_torch.py index c306a703..adaa98e4 100644 --- a/src/flowcept/instrumentation/decorators/flowcept_torch.py +++ b/src/flowcept/instrumentation/decorators/flowcept_torch.py @@ -2,9 +2,7 @@ from time import time from functools import wraps -from flowcept.commons.flowcept_dataclasses.task_object import ( - Status, -) +from flowcept.commons.vocabulary import Status from typing import List, Dict import uuid diff --git a/tests/api/task_query_api_test.py b/tests/api/task_query_api_test.py index 60fdf9f7..13d39ddd 100644 --- a/tests/api/task_query_api_test.py +++ b/tests/api/task_query_api_test.py @@ -14,8 +14,8 @@ from flowcept.commons.daos.docdb_dao.docdb_dao_base import DocumentDBDAO from flowcept.commons.flowcept_dataclasses.task_object import ( TaskObject, - Status, ) +from flowcept.commons.vocabulary import Status from flowcept.commons.flowcept_logger import FlowceptLogger from flowcept.configs import WEBSERVER_PORT, WEBSERVER_HOST, MONGO_ENABLED from flowcept.flowcept_api.task_query_api import TaskQueryAPI diff --git a/tests/decorator_tests/flowcept_task_decorator_test.py b/tests/decorator_tests/flowcept_task_decorator_test.py index 833d2ac2..c72fab1e 100644 --- a/tests/decorator_tests/flowcept_task_decorator_test.py +++ b/tests/decorator_tests/flowcept_task_decorator_test.py @@ -7,15 +7,16 @@ from time import time, sleep import flowcept.instrumentation.decorators -from flowcept import Flowcept +from flowcept import Flowcept, FlowceptLoop import unittest from flowcept.commons.flowcept_logger import FlowceptLogger from flowcept.commons.utils import assert_by_querying_tasks_until +from flowcept.commons.vocabulary import Status from flowcept.instrumentation.decorators.flowcept_task import ( flowcept_task, - lightweight_flowcept_task, flowcept_loop, + lightweight_flowcept_task, ) @@ -259,25 +260,74 @@ def test_decorated_function_timed(self): print("Overheads: " + str(overheads)) assert all(map(lambda v: v < threshold, overheads)) + def test_flowcept_loop_types(self): + + with Flowcept(): + items = range(3) + loop = FlowceptLoop(items=items) + for _ in loop: + pass + docs = Flowcept.db.query(filter={"workflow_id": Flowcept.current_workflow_id}) + assert len(docs) == len(items) + 1 + + with Flowcept(): + items = [10, 20, 30] + loop = FlowceptLoop(items=items) + for _ in loop: + pass + docs = Flowcept.db.query(filter={"workflow_id": Flowcept.current_workflow_id}) + assert len(docs) == len(items) + 1 + + with Flowcept(): + items = "abcd" + loop = FlowceptLoop(items=items) + for _ in loop: + pass + docs = Flowcept.db.query(filter={"workflow_id": Flowcept.current_workflow_id}) + assert len(docs) == len(items) + 1 + + with Flowcept(): + items = np.array([0.5, 1.0, 1.5]) + loop = FlowceptLoop(items=items, loop_name="our_loop") + for _ in loop: + loop.end_iter({"a": 1}) + docs = Flowcept.db.query(filter={"workflow_id": Flowcept.current_workflow_id, "activity_id": "our_loop_iteration"}) + assert len(docs) == len(items) + assert all(d["generated"]["a"] == 1 for d in docs) + def test_flowcept_loop_generator(self): - epochs = range(1, 3) + number_of_epochs = 3 + epochs = range(0, number_of_epochs) with Flowcept(): - for _ in flowcept_loop(items=epochs, loop_name="epochs", item_name='epoch'): + loop = FlowceptLoop(items=epochs, loop_name="epochs", item_name="epoch") + for e in loop: sleep(0.05) - + loss = random.random() + print(e, loss) + loop.end_iter({"loss": loss}) docs = Flowcept.db.query(filter={"workflow_id": Flowcept.current_workflow_id}) - assert len(docs) == 3 # 1 (parent_task) + 2 (sub_tasks) + assert len(docs) == number_of_epochs+1 # 1 (parent_task) + #epochs (sub_tasks) iteration_tasks = [] whole_loop_task = None for d in docs: if d["activity_id"] == "epochs": whole_loop_task = d + assert whole_loop_task["ended_at"] is not None + assert whole_loop_task["status"] == Status.FINISHED.value else: + assert d["started_at"] is not None + assert d["used"]["i"] >= 0 + assert d["generated"]["loss"] > 0 iteration_tasks.append(d) - assert len(iteration_tasks) == 2 - assert all(t["parent_task_id"] == whole_loop_task["task_id"] for t in iteration_tasks) - + assert len(iteration_tasks) == number_of_epochs + sorted_iteration_tasks = sorted(iteration_tasks, key=lambda x: x['used']['i']) + for i in range(len(sorted_iteration_tasks)): + t = sorted_iteration_tasks[i] + assert t["used"]["i"] == i + assert t["used"]["epoch"] == i + assert t["status"] == Status.FINISHED.value + assert t["parent_task_id"] == whole_loop_task["task_id"] diff --git a/tests/decorator_tests/ml_tests/dl_trainer.py b/tests/decorator_tests/ml_tests/dl_trainer.py index 8abb8828..8eabcdca 100644 --- a/tests/decorator_tests/ml_tests/dl_trainer.py +++ b/tests/decorator_tests/ml_tests/dl_trainer.py @@ -1,6 +1,7 @@ from uuid import uuid4 import torch +from torch.utils.data import Subset, DataLoader from torchvision import datasets, transforms from torch import nn, optim from torch.nn import functional as F @@ -88,28 +89,30 @@ def forward(self, x): class ModelTrainer(object): @staticmethod - def build_train_test_loader(batch_size=128, random_seed=0): + def build_train_test_loader(batch_size=128, random_seed=0, debug=True, subset_size=1000): torch.manual_seed(random_seed) - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "mnist_data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor()]), - ), - batch_size=batch_size, - shuffle=True, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "mnist_data", - train=False, - transform=transforms.Compose([transforms.ToTensor()]), - ), - batch_size=batch_size, - shuffle=True, + # Load the full MNIST dataset + train_dataset = datasets.MNIST( + "mnist_data", + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor()]), + ) + test_dataset = datasets.MNIST( + "mnist_data", + train=False, + transform=transforms.Compose([transforms.ToTensor()]), ) + + if debug: + # Create smaller subsets for debugging + train_dataset = Subset(train_dataset, range(subset_size)) + test_dataset = Subset(test_dataset, range(subset_size)) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + return train_loader, test_loader @staticmethod