Commit

easy fixes

lanpa committed Dec 28, 2024
1 parent 4a07798 commit 919c41c
Showing 11 changed files with 39 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -50,7 +50,7 @@ select = [
# Pyflakes
"F",
# pyupgrade
# "UP",
"UP",
# flake8-bugbear
# "B",
# flake8-simplify
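Note: the change above un-comments the pyupgrade ("UP") rule group in ruff's select list, which is what drives the mechanical rewrites in the files below. A minimal before/after sketch of the kind of code those rules flag (illustrative Python, not from this repository; the rule codes in the comments are cited from memory, not from the commit):

class LegacyLogger(object):                   # flagged: redundant object base class (UP004-style)
    def describe(self, name, step):
        return "{}:{}".format(name, step)     # flagged: str.format() instead of f-string (UP032-style)

class ModernLogger:                           # what the pyupgrade-style auto-fix produces
    def describe(self, name, step):
        return f"{name}:{step}"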
18 changes: 6 additions & 12 deletions tensorboardX/embedding.py
@@ -44,14 +44,9 @@ def make_tsv(metadata, save_path, metadata_header=None):

named_path = os.path.join(save_path, 'metadata.tsv')

- if sys.version_info[0] == 3:
-     with open(named_path, 'w', encoding='utf8') as f:
-         for x in metadata:
-             f.write(x + '\n')
- else:
-     with open(named_path, 'wb') as f:
-         for x in metadata:
-             f.write((x + '\n').encode('utf-8'))
+ with open(named_path, 'w', encoding='utf8') as f:
+     for x in metadata:
+         f.write(x + '\n')
maybe_upload_file(named_path)


@@ -97,17 +92,16 @@ def append_pbtxt(metadata, label_img, save_path, subdir, global_step, tag):
with open(named_path, 'a') as f:
# step = os.path.split(save_path)[-1]
f.write('embeddings {\n')
- f.write('tensor_name: "{}:{}"\n'.format(
-     tag, str(global_step).zfill(5)))
+ f.write(f'tensor_name: "{tag}:{str(global_step).zfill(5)}"\n')
f.write('tensor_path: "{}"\n'.format(join(subdir, 'tensors.tsv')))
if metadata is not None:
f.write('metadata_path: "{}"\n'.format(
join(subdir, 'metadata.tsv')))
if label_img is not None:
f.write('sprite {\n')
f.write('image_path: "{}"\n'.format(join(subdir, 'sprite.png')))
- f.write('single_image_dim: {}\n'.format(label_img.shape[3]))
- f.write('single_image_dim: {}\n'.format(label_img.shape[2]))
+ f.write(f'single_image_dim: {label_img.shape[3]}\n')
+ f.write(f'single_image_dim: {label_img.shape[2]}\n')
f.write('}\n')
f.write('}\n')
maybe_upload_file(named_path)
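A quick sanity check on the rewrites above (illustrative snippet, not part of the commit): on Python 3 the text-mode branch with encoding='utf8' is sufficient on its own, and the new f-string writes exactly the same line as the old str.format() call.

tag, global_step = "embedding", 7
old = 'tensor_name: "{}:{}"\n'.format(tag, str(global_step).zfill(5))
new = f'tensor_name: "{tag}:{str(global_step).zfill(5)}"\n'
assert old == new == 'tensor_name: "embedding:00007"\n'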
5 changes: 2 additions & 3 deletions tensorboardX/event_file_writer.py
@@ -14,7 +14,6 @@
# ==============================================================================
"""Writes events to disk in a logdir."""

- from __future__ import absolute_import, division, print_function

import multiprocessing
import os
@@ -27,7 +26,7 @@
from .record_writer import RecordWriter, directory_check


- class EventsWriter(object):
+ class EventsWriter:
'''Writes `Event` protocol buffers to an event file.'''

def __init__(self, file_prefix, filename_suffix=''):
@@ -75,7 +74,7 @@ def close(self):
return return_value


- class EventFileWriter(object):
+ class EventFileWriter:
"""Writes `Event` protocol buffers to an event file.
The `EventFileWriter` class creates an event file in the specified directory,
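For reference, dropping the explicit object base is purely cosmetic on Python 3, where every class is new-style. A small illustrative check (hypothetical class names, not from the repo):

class WithBase(object):
    pass

class WithoutBase:
    pass

# Identical method resolution order apart from the class itself.
assert WithBase.__mro__[1:] == WithoutBase.__mro__[1:] == (object,)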
2 changes: 1 addition & 1 deletion tensorboardX/global_writer.py
@@ -7,7 +7,7 @@
_writer = None


- class GlobalSummaryWriter(object):
+ class GlobalSummaryWriter:
"""A class that implements an event writer that supports concurrent logging and global logging across
different modules.
10 changes: 5 additions & 5 deletions tensorboardX/record_writer.py
@@ -61,7 +61,7 @@ def open_file(path):
return open(path, 'wb')


- class S3RecordWriter(object):
+ class S3RecordWriter:
"""Writes tensorboard protocol buffer files to S3."""

def __init__(self, path):
@@ -99,7 +99,7 @@ def close(self):
self.closed = True


- class S3RecordWriterFactory(object):
+ class S3RecordWriterFactory:
"""Factory for event protocol buffer files to S3."""

def open(self, path):
@@ -114,7 +114,7 @@ def directory_check(self, path):
register_writer_factory("s3", S3RecordWriterFactory())


- class GCSRecordWriter(object):
+ class GCSRecordWriter:
"""Writes tensorboard protocol buffer files to Google Cloud Storage."""

def __init__(self, path):
@@ -158,7 +158,7 @@ def close(self):
self.flush()


- class GCSRecordWriterFactory(object):
+ class GCSRecordWriterFactory:
"""Factory for event protocol buffer files to Google Cloud Storage."""

def open(self, path):
@@ -173,7 +173,7 @@ def directory_check(self, path):
register_writer_factory("gs", GCSRecordWriterFactory())


- class RecordWriter(object):
+ class RecordWriter:
def __init__(self, path):
self._name_to_tf_name = {}
self._tf_names = set()
3 changes: 1 addition & 2 deletions tensorboardX/summary.py
@@ -1,4 +1,3 @@
- from __future__ import absolute_import, division, print_function

import logging
import os
@@ -570,7 +569,7 @@ def _get_tensor_summary(tag, tensor, content_type, json_config):
TensorShapeProto.Dim(size=tensor.shape[2]),
]))
tensor_summary = Summary.Value(
- tag='{}_{}'.format(tag, content_type),
+ tag=f'{tag}_{content_type}',
tensor=tensor,
metadata=smd,
)
1 change: 0 additions & 1 deletion tensorboardX/torchvis.py
@@ -1,4 +1,3 @@
- from __future__ import absolute_import, division, print_function, unicode_literals

import gc
import time
16 changes: 8 additions & 8 deletions tensorboardX/utils.py
@@ -93,10 +93,10 @@ def make_grid(I, ncols=8):

def convert_to_NTCHW(tensor, input_format):
assert len(input_format) == 5, "Only 5D tensor is supported."
- assert len(set(input_format)) == len(input_format), "You can not use the same dimension shorthand twice. \
-     input_format: {}".format(input_format)
- assert len(tensor.shape) == len(input_format), "size of input tensor and input format are different. \
-     tensor shape: {}, input_format: {}".format(tensor.shape, input_format)
+ assert len(set(input_format)) == len(input_format), f"You can not use the same dimension shorthand twice. \
+     input_format: {input_format}"
+ assert len(tensor.shape) == len(input_format), f"size of input tensor and input format are different. \
+     tensor shape: {tensor.shape}, input_format: {input_format}"
input_format = input_format.upper()
index = [input_format.find(c) for c in 'NTCHW']
tensor_NTCHW = tensor.transpose(index)
@@ -105,10 +105,10 @@ def convert_to_NTCHW(tensor, input_format):

def convert_to_HWC(tensor, input_format): # tensor: numpy array
import numpy as np
- assert len(set(input_format)) == len(input_format), "You can not use the same dimension shorthand twice. \
-     input_format: {}".format(input_format)
- assert len(tensor.shape) == len(input_format), "size of input tensor and input format are different. \
-     tensor shape: {}, input_format: {}".format(tensor.shape, input_format)
+ assert len(set(input_format)) == len(input_format), f"You can not use the same dimension shorthand twice. \
+     input_format: {input_format}"
+ assert len(tensor.shape) == len(input_format), f"size of input tensor and input format are different. \
+     tensor shape: {tensor.shape}, input_format: {input_format}"
input_format = input_format.upper()

if len(input_format) == 4:
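One subtlety worth noting about the asserts above: the backslash continuation sits inside the string literal, so the second line's leading spaces are part of the message in both the old and the new form, and the f-string conversion leaves the text unchanged. A small illustrative check (made-up input, not from the repo):

input_format = "NCHWW"
old = "You can not use the same dimension shorthand twice. \
    input_format: {}".format(input_format)
new = f"You can not use the same dimension shorthand twice. \
    input_format: {input_format}"
assert old == new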
6 changes: 3 additions & 3 deletions tensorboardX/visdom_writer.py
@@ -61,7 +61,7 @@ def add_scalar(self, tag, scalar_value, global_step=None, main_tag='default'):
exists = self.scalar_dict[main_tag].get(tag) is not None
self.scalar_dict[main_tag][tag] = self.scalar_dict[main_tag][tag] + \
[scalar_value] if exists else [scalar_value]
- plot_name = '{}-{}'.format(main_tag, tag)
+ plot_name = f'{main_tag}-{tag}'
# If there is no global_step provided, follow sequential order
x_val = len(self.scalar_dict[main_tag][tag]
) if not global_step else global_step
@@ -283,7 +283,7 @@ def add_pr_curve(self, tag, labels, predictions, global_step=None, num_threshold
Y=precision,
name=tag,
opts={
- 'title': 'PR Curve for {}'.format(tag),
+ 'title': f'PR Curve for {tag}',
'xlabel': 'recall',
'ylabel': 'precision',
},
@@ -316,7 +316,7 @@ def add_pr_curve_raw(self, tag, true_positive_counts,
Y=precision,
name=tag,
opts={
- 'title': 'PR Curve for {}'.format(tag),
+ 'title': f'PR Curve for {tag}',
'xlabel': 'recall',
'ylabel': 'precision',
},
23 changes: 11 additions & 12 deletions tensorboardX/writer.py
@@ -1,14 +1,13 @@
"""Provides an API for writing protocol buffers to event files to be
consumed by TensorBoard for visualization."""

- from __future__ import absolute_import, division, print_function

import atexit
import json
import logging
import os
import time
- from typing import Dict, List, Optional, Union
+ from typing import Optional, Union

import numpy

@@ -46,7 +45,7 @@
pass


- class DummyFileWriter(object):
+ class DummyFileWriter:
"""A fake file writer that writes nothing to the disk.
"""

@@ -79,7 +78,7 @@ def reopen(self):
return


- class FileWriter(object):
+ class FileWriter:
"""Writes protocol buffers to event files to be consumed by TensorBoard.
The `FileWriter` class provides a mechanism to create an event file in a
@@ -216,7 +215,7 @@ def reopen(self):
self.event_writer.reopen()


- class SummaryWriter(object):
+ class SummaryWriter:
"""Writes entries directly to event files in the logdir to be
consumed by TensorBoard.
@@ -361,8 +360,8 @@ def _get_comet_logger(self):

def add_hparams(
self,
- hparam_dict: Dict[str, Union[bool, str, float, int]],
- metric_dict: Dict[str, float],
+ hparam_dict: dict[str, Union[bool, str, float, int]],
+ metric_dict: dict[str, float],
name: Optional[str] = None,
global_step: Optional[int] = None):
"""Add a set of hyperparameters to be compared in tensorboard.
@@ -448,7 +447,7 @@ def add_scalar(
def add_scalars(
self,
main_tag: str,
- tag_scalar_dict: Dict[str, float],
+ tag_scalar_dict: dict[str, float],
global_step: Optional[int] = None,
walltime: Optional[float] = None):
"""Adds many scalar data to summary.
@@ -742,7 +741,7 @@ def add_image_with_boxes(
global_step: Optional[int] = None,
walltime: Optional[float] = None,
dataformats: Optional[str] = 'CHW',
- labels: Optional[List[str]] = None,
+ labels: Optional[list[str]] = None,
**kwargs):
"""Add image and draw bounding boxes on the image.
@@ -1104,7 +1103,7 @@ def add_pr_curve_raw(

def add_custom_scalars_multilinechart(
self,
- tags: List[str],
+ tags: list[str],
category: str = 'default',
title: str = 'untitled'):
"""Shorthand for creating multilinechart. Similar to ``add_custom_scalars()``, but the only necessary argument
@@ -1122,7 +1121,7 @@ def add_custom_scalars_multilinechart(

def add_custom_scalars_marginchart(
self,
- tags: List[str],
+ tags: list[str],
category: str = 'default',
title: str = 'untitled'):
"""Shorthand for creating marginchart. Similar to ``add_custom_scalars()``, but the only necessary argument
@@ -1141,7 +1140,7 @@ def add_custom_scalars_marginchart(

def add_custom_scalars(
self,
- layout: Dict[str, Dict[str, List]]):
+ layout: dict[str, dict[str, list]]):
"""Create special chart by collecting charts tags in 'scalars'. Note that this function can only be called once
for each SummaryWriter() object. Because it only provides metadata to tensorboard, the function can be called
before or after the training loop. See ``examples/demo_custom_scalars.py`` for more.
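The annotation changes above swap typing.Dict/typing.List for the built-in generics of PEP 585. Evaluating dict[str, float] in a signature at runtime needs Python 3.9+, which presumably matches the Python versions this commit targets. A minimal illustrative sketch (hypothetical function, not the real SummaryWriter.add_scalars beyond what the diff shows):

from typing import Optional

def add_scalars_example(tag_scalar_dict: dict[str, float],
                        global_step: Optional[int] = None) -> None:
    # Same annotation style as the updated writer.py signatures.
    for tag, value in tag_scalar_dict.items():
        print(f"{tag}={value} (step={global_step})")

add_scalars_example({"loss": 0.25, "acc": 0.9}, global_step=1)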
3 changes: 1 addition & 2 deletions tensorboardX/x2num.py
@@ -1,5 +1,4 @@
# DO NOT alter/distruct/free input object !
- from __future__ import absolute_import, division, print_function

import logging

@@ -31,7 +30,7 @@ def make_np(x):
if 'jax' in str(type(x)):
return check_nan(np.array(x))
raise NotImplementedError(
- 'Got {}, but expected numpy array or torch tensor.'.format(type(x)))
+ f'Got {type(x)}, but expected numpy array or torch tensor.')


def prepare_pytorch(x):
