Skip to content

Commit

Permalink
Add milabench.datasets.fake_images, improve instruments
Browse files Browse the repository at this point in the history
  • Loading branch information
breuleux committed Mar 25, 2022
1 parent 5b2d40a commit e375d90
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 30 deletions.
8 changes: 8 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[flake8]
ignore=E501,E203,W503,F722
exclude=tests/feat38.py
per-file-ignores=
    __init__.py:F401
    examples/*:F821
    tests/*:F821
    tests/test_tools.py:F401
Empty file added milabench/datasets/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions milabench/datasets/fake_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import multiprocessing
import os

from torchvision.datasets import FakeData
from tqdm import tqdm

from ..fs import XPath


def write(args):
    """Generate one chunk of fake images and save them in per-class folders.

    ``args`` is a single 4-tuple ``(image_size, count, offset, outdir)`` so
    the function can be used directly with ``Pool.map``:

    * ``image_size`` and ``count`` are forwarded to ``FakeData`` as
      ``image_size`` and ``size``.
    * ``offset`` is forwarded as ``random_offset`` and is also added to the
      running index to build each file name.
    * ``outdir`` is the directory under which one sub-directory per class
      label is created.
    """
    image_size, count, offset, outdir = args
    samples = FakeData(
        size=count,
        image_size=image_size,
        num_classes=1000,
        random_offset=offset,
    )

    progress = tqdm(enumerate(samples), total=count)
    for index, (picture, label) in progress:
        # One folder per class label; created lazily as labels show up.
        class_dir = os.path.join(outdir, str(int(label)))
        os.makedirs(class_dir, exist_ok=True)
        picture.save(os.path.join(class_dir, f"{offset + index}.jpeg"))


def generate(image_size, n, outdir):
    """Generate exactly ``n`` fake images under ``outdir`` using a process pool.

    The work is split into contiguous chunks, one per worker, and each chunk
    is handed to ``write`` as an ``(image_size, count, offset, outdir)``
    tuple, where ``offset`` is the index of the chunk's first image.

    Args:
        image_size: Forwarded unchanged to ``write``.
        n: Total number of images to generate. Nothing is done if ``n <= 0``.
        outdir: Directory that receives the images.
    """
    if n <= 0:
        # Nothing to generate; also avoids a zero chunk size below.
        return

    # Never start more workers than there are images, so that every chunk
    # has at least one image.
    p_count = min(multiprocessing.cpu_count(), 8, n)

    # Spread the remainder over the first chunks so that the chunk sizes add
    # up to exactly n (no missing and no extra images).
    base, extra = divmod(n, p_count)
    jobs = []
    offset = 0
    for worker in range(p_count):
        count = base + (1 if worker < extra else 0)
        jobs.append((image_size, count, offset, outdir))
        offset += count

    # map() blocks until every chunk is written; the context manager then
    # releases the worker processes.
    with multiprocessing.Pool(p_count) as pool:
        pool.map(write, jobs)


def generate_sets(root, sets, shape):
    """Generate every named image set under ``root``, once.

    A ``done`` marker inside ``root`` records a completed run. If the marker
    exists nothing is regenerated; if ``root`` exists without the marker it
    is removed and rebuilt.

    Args:
        root: Destination directory (wrapped in ``XPath``).
        sets: Mapping from set name to the number of images in that set.
        shape: Image size forwarded to ``generate``.
    """
    root = XPath(root)
    done_marker = root / "done"

    if done_marker.exists():
        print(f"{root} was already generated")
        return

    if root.exists():
        # A previous run did not finish: start over from a clean directory.
        print(f"{root} exists but is not marked complete; deleting")
        root.rm()

    for set_name, size in sets.items():
        print(f"Generating {set_name}")
        target = os.path.join(root, set_name)
        generate(shape, size, target)

    # Written last so that it is only present after all sets were generated.
    done_marker.touch()
121 changes: 115 additions & 6 deletions milabench/instruments.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,102 @@
import time
from threading import Thread

from hrepr import trepr
from voir.tools import gated, parametrized

from .utils import REAL_STDOUT


@gated("--display", "Display given")
def display(ov):
    """Instrument enabled by ``--display``: call ``display()`` on ``ov.given``."""
    given = ov.given
    given.display()


class Plain:
    """Renderable wrapper that shows an object as a formatted ``trepr`` string.

    The wrapped object lives in ``_object`` so that it can be swapped for a
    newer value without creating a new wrapper.
    """

    def __init__(self, x, fmt="{}"):
        # fmt is a str.format template with a single positional placeholder.
        self.fmt = fmt
        self._object = x

    def __rich__(self):
        """Return the truncated representation inserted into ``fmt``."""
        rendered = trepr(self._object, max_depth=2, sequence_max=10)
        text = str(rendered)
        return self.fmt.format(text)


@gated("--dash", "Display dash")
def dash(ov):
    """Create a simple terminal dashboard using rich.

    This displays a live table of the last value for everything given.
    Entries that carry ``total``, ``progress`` and ``descr`` are shown as a
    progress bar labelled with ``descr``; every other key gets a row holding
    its latest value, suffixed with ``units`` when the entry that created the
    row provided one.
    """
    yield ov.phases.init

    # rich is imported lazily, only once the instrument actually runs.
    from rich.console import Console
    from rich.live import Live
    from rich.progress import ProgressBar
    from rich.table import Table

    gv = ov.given

    # Current rows are stored here
    rows = {}

    # First, a table with the latest value of everything that was given
    table = Table.grid(padding=(0, 3, 0, 0))
    table.add_column("key", style="bold green")
    table.add_column("value")

    console = Console(color_system="standard", file=REAL_STDOUT)

    # The streams are bound to names before being used as decorators:
    # arbitrary expressions after "@" are only accepted from Python 3.9 on,
    # while the project metadata declares support for Python 3.8.
    console_tail = gv["?#stdout"].roll(10) | gv["?#stderr"].roll(10)

    @console_tail.subscribe
    def _show_output(txt):
        # Re-give the rolled output as a single "stdout" value.
        # NOTE(review): presumably this is what makes it show up as a row of
        # the table below -- confirm against giving's semantics.
        ov.give(stdout="".join(txt))

    visible = gv.where("!silent")

    # This updates the table every time we get new values
    @visible.subscribe
    def _update_table(values):
        if {"total", "progress", "descr"}.issubset(values.keys()):
            k = values["descr"]
            # The backslash keeps rich from reading "[...]" as markup.
            k = f"\\[{k}]"
            if k not in rows:
                progress_bar = ProgressBar(finished_style="blue", width=50)
                table.add_row(k, progress_bar)
                rows[k] = progress_bar
            progress_bar = rows[k]
            progress_bar.update(total=values["total"], completed=values["progress"])
            return

        units = values.get("units", None)

        for k, v in values.items():
            # "$..." and "#..." keys and the units themselves are not shown.
            if k.startswith(("$", "#")) or k == "units":
                continue
            if k in rows:
                # Existing row: swap the displayed object in place.
                rows[k]._object = v
            else:
                if units:
                    rows[k] = Plain(v, f"{{}} {units}")
                else:
                    rows[k] = Plain(v)
                table.add_row(k, rows[k])

    with Live(table, refresh_per_second=4, console=console):
        yield ov.phases.run_script


@parametrized("--stop", type=int, default=0, help="Number of iterations to run for")
def stop(ov):
    """Stop the run after ``--stop`` steps and report progress towards it.

    With the default of 0 the instrument does nothing. Otherwise every given
    entry that has a ``step`` key is counted: a ``progress``/``total``/
    ``descr`` entry is given for each of them, and ``ov.stop`` is wired to
    the entries that remain after skipping the first ``--stop`` ones.
    """
    # The phase is yielded once, with an explicit priority, and the stop is
    # wired once; the duplicated pre-change statements were dropped.
    yield ov.phases.load_script(priority=-100)

    # Named "limit" so that it does not shadow this function's own name.
    limit = ov.options.stop
    if limit:
        steps = ov.given.where("step")
        steps.map_indexed(
            lambda _, idx: {"progress": idx, "total": limit, "descr": "train"}
        ).give()
        steps.skip(limit) >> ov.stop


@gated("--rates")
def rates(ov):
@gated("--train-rate")
def train_rate(ov):
yield ov.phases.load_script

sync = None
Expand All @@ -33,8 +111,10 @@ def setsync(use_cuda):
ov.given["?use_cuda"].first_or_default(False) >> setsync

times = (
ov.given.keep("step", "batch_size")
ov.given.where("step", "batch")
.kmap(batch_size=lambda batch: len(batch))
.augment(time=lambda: time.time_ns())
.keep("time", "batch_size")
.pairwise()
.buffer_with_time(1.0)
)
Expand All @@ -52,7 +132,36 @@ def _(elems):
n = sum(e1["batch_size"] for e1, e2 in elems)
t /= 1_000_000_000

ov.give(rate=n / t)
if t:
ov.give(rate=n / t, units="items/s")


@gated("--loading-rate")
def loading_rate(ov):
    """Give the data loading rate, in items per second, for each batch.

    For every given ``loader``, the ``next`` method of its iterator type is
    probed and each call is timed; the result is given as ``loading_rate``
    together with ``units``. The rate is ``None`` when the probe produced no
    batch or when no time elapsed.
    """
    yield ov.phases.load_script

    def _timing():
        # Generator passed to wmap: the clock is read before the yield and
        # again once the yield returns the collected results.
        # NOTE(review): presumably these two points are the @enter and @exit
        # of the probed call -- confirm against giving/ptera.
        t0 = time.time_ns()
        results = yield
        t1 = time.time_ns()
        if "batch" not in results:
            return None

        seconds = (t1 - t0) / 1_000_000_000
        if not seconds:
            # Same guard as the training rate: avoid dividing by zero when
            # the clock did not advance between the two readings.
            return None

        data = results["batch"]
        if isinstance(data, (list, tuple)):
            # A batch made of several fields: measure the first field rather
            # than the number of fields.
            data = data[0]
        return len(data) / seconds

    # Bound to a name first: arbitrary expressions after "@" are only
    # accepted from Python 3.9 on, and the project declares Python 3.8.
    loaders = ov.given.where("loader")

    @loaders.ksubscribe
    def _(loader):
        # NOTE(review): "typ" looks unused but is named inside the probe
        # selector below; presumably ov.probe resolves it from this scope,
        # so it must keep this exact name -- confirm before touching it.
        typ = type(iter(loader))
        (
            ov.probe("typ.next(!$x:@enter, #value as batch, !!$y:@exit)")
            .wmap(_timing)
            .map(lambda x: {"loading_rate": x, "units": "items/s"})
            .give()
        )


class GPUMonitor(Thread):
Expand Down
55 changes: 35 additions & 20 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ license = "MIT"

[tool.poetry.dependencies]
python = "^3.8"
giving = "^0.3.7"
ptera = "^1.0.1"
giving = "^0.3.10"
ptera = "^1.2.0"
rich = "^10.13.0"
coleo = "^0.3.0"
requests = "^2.26.0"
Expand All @@ -17,6 +17,7 @@ GitPython = "^3.1.24"
PyYAML = "^6.0"
ovld = "^0.3.2"
GPUtil = "^1.4.0"
hrepr = "^0.4.0"

[tool.poetry.dev-dependencies]
black = "^21.10b0"
Expand All @@ -33,10 +34,12 @@ build-backend = "poetry.core.masonry.api"
milabench = "milabench.cli:main"

[tool.poetry.plugins."voir.instrument"]
profile_gpu = "milabench.instruments:profile_gpu"
dash = "milabench.instruments:dash"
display = "milabench.instruments:display"
loading_rate = "milabench.instruments:loading_rate"
profile_gpu = "milabench.instruments:profile_gpu"
stop = "milabench.instruments:stop"
rates = "milabench.instruments:rates"
train_rate = "milabench.instruments:train_rate"

[tool.isort]
multi_line_output = 3
Expand Down

0 comments on commit e375d90

Please sign in to comment.