diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index 0531fc2c..9283912d 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -2,7 +2,7 @@ FFCV loader """ import enum -from os import environ +from os import environ, sched_getaffinity import ast from multiprocessing import cpu_count from re import sub @@ -140,7 +140,7 @@ def __init__(self, self.recompile = recompile if self.num_workers < 1: - self.num_workers = cpu_count() + self.num_workers = len(sched_getaffinity(0)) Compiler.set_num_threads(self.num_workers) diff --git a/ffcv/pipeline/compiler.py b/ffcv/pipeline/compiler.py index 987356cc..9ebae469 100644 --- a/ffcv/pipeline/compiler.py +++ b/ffcv/pipeline/compiler.py @@ -4,7 +4,7 @@ from multiprocessing import cpu_count import torch as ch import warnings - +from os import sched_getaffinity class Compiler: @@ -15,7 +15,7 @@ def set_enabled(cls, b): @classmethod def set_num_threads(cls, n): if n < 1 : - n = cpu_count() + n = len(sched_getaffinity(0)) cls.num_threads = n set_num_threads(n) ch.set_num_threads(n) diff --git a/ffcv/writer.py b/ffcv/writer.py index 1b70f74f..3615dd0a 100644 --- a/ffcv/writer.py +++ b/ffcv/writer.py @@ -1,6 +1,6 @@ from functools import partial from typing import Callable, List, Mapping -from os import SEEK_END, path +from os import SEEK_END, path, sched_getaffinity import numpy as np from time import sleep import ctypes @@ -143,7 +143,7 @@ def __init__(self, fname: str, fields: Mapping[str, Field], self.num_workers = num_workers # We use all cores by default if self.num_workers < 1: - self.num_workers = cpu_count() + self.num_workers = len(sched_getaffinity(0)) if not is_power_of_2(page_size): raise ValueError(f'page_size isnt a power of 2')