diff --git a/pyproject.toml b/pyproject.toml index 5633f15..824ac07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ select = [ # flake8-bugbear # "B", # flake8-simplify -# "SIM", + "SIM", # isort "I", ] diff --git a/tensorboardX/comet_utils.py b/tensorboardX/comet_utils.py index fcb0a89..31e2231 100644 --- a/tensorboardX/comet_utils.py +++ b/tensorboardX/comet_utils.py @@ -35,7 +35,7 @@ def wrapper(*args, **kwargs): if self._logging is None and comet_installed: self._logging = False try: - if 'api_key' not in self._comet_config.keys(): + if 'api_key' not in self._comet_config: comet_ml.init() if comet_ml.get_global_experiment() is not None: logger.warning("You have already created a comet \ diff --git a/tensorboardX/summary.py b/tensorboardX/summary.py index c005752..549a03a 100644 --- a/tensorboardX/summary.py +++ b/tensorboardX/summary.py @@ -108,7 +108,7 @@ def hparams(hparam_dict=None, metric_dict=None): hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_BOOL"))) continue - if isinstance(v, int) or isinstance(v, float): + if isinstance(v, (int, float)): v = make_np(v)[0] ssi.hparams[k].number_value = v hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64"))) @@ -126,7 +126,7 @@ def hparams(hparam_dict=None, metric_dict=None): content=content.SerializeToString())) ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)]) - mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()] + mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict] exp = Experiment(hparam_infos=hps, metric_infos=mts) content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION) @@ -388,26 +388,26 @@ def make_video(tensor, fps): # encode sequence of images into gif string clip = mpy.ImageSequenceClip(list(tensor), fps=fps) - - filename = tempfile.NamedTemporaryFile(suffix='.gif', delete=False).name - - if moviepy.version.__version__.startswith("0."): - logger.warning('Upgrade to moviepy >= 1.0.0 to supress the progress bar.') - clip.write_gif(filename, verbose=False) - elif moviepy.version.__version__.startswith("1."): - # moviepy >= 1.0.0 use logger=None to suppress output. - clip.write_gif(filename, verbose=False, logger=None) - else: - # Moviepy >= 2.0.0.dev1 removed the verbose argument - clip.write_gif(filename, logger=None) - - with open(filename, 'rb') as f: - tensor_string = f.read() - - try: - os.remove(filename) - except OSError: - logger.warning('The temporary file used by moviepy cannot be deleted.') + with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as fp: + filename = fp.name + + if moviepy.version.__version__.startswith("0."): + logger.warning('Upgrade to moviepy >= 1.0.0 to supress the progress bar.') + clip.write_gif(filename, verbose=False) + elif moviepy.version.__version__.startswith("1."): + # moviepy >= 1.0.0 use logger=None to suppress output. + clip.write_gif(filename, verbose=False, logger=None) + else: + # Moviepy >= 2.0.0.dev1 removed the verbose argument + clip.write_gif(filename, logger=None) + + with open(filename, 'rb') as f: + tensor_string = f.read() + + try: + os.remove(filename) + except OSError: + logger.warning('The temporary file used by moviepy cannot be deleted.') return Summary.Image(height=h, width=w, colorspace=c, encoded_image_string=tensor_string) diff --git a/tensorboardX/torchvis.py b/tensorboardX/torchvis.py index 83af457..1691999 100644 --- a/tensorboardX/torchvis.py +++ b/tensorboardX/torchvis.py @@ -25,9 +25,9 @@ def __init__(self, *args, **init_kwargs): def register(self, *args, **init_kwargs): # Sets tensorboard as the default visualization format if not specified - formats = ['tensorboard'] if not args else args + formats = args if args else ['tensorboard'] for format in formats: - if self.subscribers.get(format) is None and format in vis_formats.keys(): + if self.subscribers.get(format) is None and format in vis_formats: self.subscribers[format] = vis_formats[format](**init_kwargs.get(format, {})) def unregister(self, *args): diff --git a/tensorboardX/visdom_writer.py b/tensorboardX/visdom_writer.py index 674f6b8..17b89a6 100644 --- a/tensorboardX/visdom_writer.py +++ b/tensorboardX/visdom_writer.py @@ -63,8 +63,7 @@ def add_scalar(self, tag, scalar_value, global_step=None, main_tag='default'): [scalar_value] if exists else [scalar_value] plot_name = f'{main_tag}-{tag}' # If there is no global_step provided, follow sequential order - x_val = len(self.scalar_dict[main_tag][tag] - ) if not global_step else global_step + x_val = global_step if global_step else len(self.scalar_dict[main_tag][tag]) if exists: # Update our existing Visdom window self.vis.line( @@ -110,7 +109,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None): 'run_14h-arctanx' with the corresponding values. """ - for key in tag_scalar_dict.keys(): + for key in tag_scalar_dict: self.add_scalar(key, tag_scalar_dict[key], global_step, main_tag) @_check_connection diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py index 2c1336b..5a38cbc 100644 --- a/tensorboardX/writer.py +++ b/tensorboardX/writer.py @@ -326,7 +326,7 @@ def __append_to_scalar_dict(self, tag, scalar_value, global_step, {writer_id : [[timestamp, step, value], ...], ...}. """ from .x2num import make_np - if tag not in self.scalar_dict.keys(): + if tag not in self.scalar_dict: self.scalar_dict[tag] = [] self.scalar_dict[tag].append( [timestamp, global_step, float(make_np(scalar_value).squeeze())]) @@ -483,7 +483,7 @@ def add_scalars( fw_logdir = self._get_file_writer().get_logdir() for tag, scalar_value in tag_scalar_dict.items(): fw_tag = os.path.join(str(fw_logdir), main_tag, tag) - if fw_tag in self.all_writers.keys(): + if fw_tag in self.all_writers: fw = self.all_writers[fw_tag] else: fw = FileWriter(logdir=fw_tag) @@ -1001,11 +1001,7 @@ def add_embedding( # new funcion to append to the config file a new embedding append_pbtxt(metadata, label_img, self._get_file_writer().get_logdir(), subdir, global_step, tag) - if tag is not None: - template_filename = f"{tag}.json" - - else: - template_filename = None + template_filename = f'{tag}.json' if tag is not None else None self._get_comet_logger().log_embedding(mat, metadata, label_img, template_filename=template_filename)