Skip to content

Commit

Permalink
Merge pull request #358 from aldakata/main
Browse files Browse the repository at this point in the history
SLURM set default num workers to numbas default
  • Loading branch information
andrewilyas authored May 6, 2024
2 parents 9834429 + 352d1ef commit 1071c57
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions ffcv/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
FFCV loader
"""
import enum
from os import environ
from os import environ, sched_getaffinity
import ast
from multiprocessing import cpu_count
from re import sub
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(self,
self.recompile = recompile

if self.num_workers < 1:
self.num_workers = cpu_count()
self.num_workers = len(sched_getaffinity(0))

Compiler.set_num_threads(self.num_workers)

Expand Down
4 changes: 2 additions & 2 deletions ffcv/pipeline/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from multiprocessing import cpu_count
import torch as ch
import warnings

from os import sched_getaffinity

class Compiler:

Expand All @@ -15,7 +15,7 @@ def set_enabled(cls, b):
@classmethod
def set_num_threads(cls, n):
if n < 1 :
n = cpu_count()
n = len(sched_getaffinity(0))
cls.num_threads = n
set_num_threads(n)
ch.set_num_threads(n)
Expand Down
4 changes: 2 additions & 2 deletions ffcv/writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
from typing import Callable, List, Mapping
from os import SEEK_END, path
from os import SEEK_END, path, sched_getaffinity
import numpy as np
from time import sleep
import ctypes
Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(self, fname: str, fields: Mapping[str, Field],
self.num_workers = num_workers
# We use all cores by default
if self.num_workers < 1:
self.num_workers = cpu_count()
self.num_workers = len(sched_getaffinity(0))

if not is_power_of_2(page_size):
raise ValueError(f'page_size isnt a power of 2')
Expand Down

0 comments on commit 1071c57

Please sign in to comment.