From 36733df8bde194e440a44455951fc1c1e436788a Mon Sep 17 00:00:00 2001 From: Jirka B Date: Thu, 31 Oct 2024 18:48:06 +0100 Subject: [PATCH] docs: specify directives --- .github/CONTRIBUTING.md | 2 +- docs/source/pages/implement.rst | 10 +++++----- docs/source/pages/lightning.rst | 8 ++++---- docs/source/pages/overview.rst | 8 ++++---- src/torchmetrics/audio/dnsmos.py | 6 ++++-- src/torchmetrics/audio/nisqa.py | 6 ++++-- src/torchmetrics/audio/pesq.py | 6 ++++-- src/torchmetrics/audio/srmr.py | 5 +++-- src/torchmetrics/audio/stoi.py | 3 ++- .../classification/calibration_error.py | 2 +- src/torchmetrics/classification/cohen_kappa.py | 4 ++-- src/torchmetrics/classification/dice.py | 2 +- src/torchmetrics/classification/hinge.py | 4 ++-- src/torchmetrics/classification/jaccard.py | 6 +++--- .../classification/matthews_corrcoef.py | 6 +++--- .../classification/precision_fixed_recall.py | 6 +++--- .../classification/precision_recall_curve.py | 6 +++--- src/torchmetrics/classification/ranking.py | 6 +++--- .../classification/recall_fixed_precision.py | 6 +++--- src/torchmetrics/classification/roc.py | 18 +++++++++--------- src/torchmetrics/collections.py | 4 ++-- src/torchmetrics/detection/_mean_ap.py | 8 ++++---- src/torchmetrics/detection/mean_ap.py | 6 +++--- src/torchmetrics/functional/audio/dnsmos.py | 3 ++- src/torchmetrics/functional/audio/nisqa.py | 3 ++- src/torchmetrics/functional/audio/pesq.py | 3 ++- src/torchmetrics/functional/audio/srmr.py | 5 +++-- src/torchmetrics/functional/audio/stoi.py | 3 ++- .../functional/classification/dice.py | 6 ++++-- src/torchmetrics/functional/image/gradients.py | 3 ++- src/torchmetrics/functional/image/psnr.py | 4 ++-- .../functional/multimodal/clip_iqa.py | 3 ++- .../functional/multimodal/clip_score.py | 3 ++- .../functional/regression/log_mse.py | 4 ++-- src/torchmetrics/image/fid.py | 5 ++--- src/torchmetrics/image/inception.py | 3 ++- src/torchmetrics/image/kid.py | 3 ++- src/torchmetrics/image/lpip.py | 6 ++---- src/torchmetrics/image/mifid.py | 6 ++++-- .../image/perceptual_path_length.py | 3 ++- src/torchmetrics/image/psnrb.py | 2 +- src/torchmetrics/metric.py | 4 ++-- src/torchmetrics/multimodal/clip_iqa.py | 3 ++- src/torchmetrics/multimodal/clip_score.py | 3 ++- src/torchmetrics/regression/kl_divergence.py | 4 ++-- src/torchmetrics/regression/log_mse.py | 4 ++-- src/torchmetrics/retrieval/base.py | 7 ++++--- .../retrieval/precision_recall_curve.py | 5 +++-- src/torchmetrics/wrappers/multitask.py | 4 ++-- 49 files changed, 132 insertions(+), 108 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 7b7211b9d61..b923d6ea262 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -103,7 +103,7 @@ def my_func(param_a: int, param_b: Optional[float] = None) -> str: >>> my_func(1, 2) 3 - .. note:: If you want to add something. + .. hint:: If you want to add something. """ p = param_b if param_b else 0 return str(param_a + p) diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index 1620ce29cd9..5ab044e1be1 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -257,7 +257,7 @@ and tests gets formatted in the following way: 3. ``new_metric(...)``: essentially wraps the ``_update`` and ``_compute`` private functions into one public function that makes up the functional interface for the metric. - .. note:: + .. hint:: The `functional mean squared error `_ metric is a is a great example of how to divide the logic. @@ -270,9 +270,9 @@ and tests gets formatted in the following way: ``_new_metric_compute(...)`` function in its ``compute``. No logic should really be implemented in the module interface. We do this to not have duplicate code to maintain. - .. note:: - The module `MeanSquaredError `_ - metric that corresponds to the above functional example showcases these steps. + .. note:: + The module `MeanSquaredError `_ + metric that corresponds to the above functional example showcases these steps. 4. Remember to add binding to the different relevant ``__init__`` files. @@ -291,7 +291,7 @@ and tests gets formatted in the following way: so that different combinations of inputs and parameters get tested. 5. (optional) If your metric raises any exception, please add tests that showcase this. - .. note:: + .. hint:: The `test file for MSE `_ metric shows how to implement such tests. diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index a96196396b8..0fd27a4d474 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -13,7 +13,7 @@ TorchMetrics in PyTorch Lightning TorchMetrics was originally created as part of `PyTorch Lightning `_, a powerful deep learning research framework designed for scaling models without boilerplate. -.. note:: +.. caution:: TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always keeping both frameworks up-to-date for the best experience. @@ -69,9 +69,9 @@ LightningModule `self.log `_. - .. note:: using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. Install + .. hint:: + Using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. Install as ``pip install torchmetrics['audio']`` or alternatively ``pip install librosa onnxruntime-gpu requests`` (if you do not have GPU enabled machine install ``onnxruntime`` instead of ``onnxruntime-gpu``) diff --git a/src/torchmetrics/functional/audio/nisqa.py b/src/torchmetrics/functional/audio/nisqa.py index 6696f93965d..1a1c066ebf0 100644 --- a/src/torchmetrics/functional/audio/nisqa.py +++ b/src/torchmetrics/functional/audio/nisqa.py @@ -66,7 +66,8 @@ def non_intrusive_speech_quality_assessment(preds: Tensor, fs: int) -> Tensor: """`Non-Intrusive Speech Quality Assessment`_ (NISQA v2.0) [1], [2]. - .. note:: Using this metric requires you to have ``librosa`` and ``requests`` installed. Install as + .. hint:: + Usingsing this metric requires you to have ``librosa`` and ``requests`` installed. Install as ``pip install librosa requests``. Args: diff --git a/src/torchmetrics/functional/audio/pesq.py b/src/torchmetrics/functional/audio/pesq.py index 516014434bf..04cf1ebace2 100644 --- a/src/torchmetrics/functional/audio/pesq.py +++ b/src/torchmetrics/functional/audio/pesq.py @@ -40,7 +40,8 @@ def perceptual_evaluation_speech_quality( This metric is a wrapper for the `pesq package`_. Note that input will be moved to `cpu` to perform the metric calculation. - .. note:: using this metrics requires you to have ``pesq`` install. Either install as ``pip install + .. hint:: + Usingsing this metrics requires you to have ``pesq`` install. Either install as ``pip install torchmetrics[audio]`` or ``pip install pesq``. Note that ``pesq`` will compile with your currently installed version of numpy, meaning that if you upgrade numpy at some point in the future you will most likely have to reinstall ``pesq``. diff --git a/src/torchmetrics/functional/audio/srmr.py b/src/torchmetrics/functional/audio/srmr.py index d098366df6b..26d27f00999 100644 --- a/src/torchmetrics/functional/audio/srmr.py +++ b/src/torchmetrics/functional/audio/srmr.py @@ -202,11 +202,12 @@ def speech_reverberation_modulation_energy_ratio( Note: this argument is inherited from `SRMRpy`_. As the translated code is based to pytorch, setting `fast=True` may slow down the speed for calculating this metric on GPU. - .. note:: using this metrics requires you to have ``gammatone`` and ``torchaudio`` installed. + .. hint:: + Usingsing this metrics requires you to have ``gammatone`` and ``torchaudio`` installed. Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio`` and ``pip install git+https://github.com/detly/gammatone``. - .. note:: + .. attention:: This implementation is experimental, and might not be consistent with the matlab implementation `SRMRToolbox`_, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have diff --git a/src/torchmetrics/functional/audio/stoi.py b/src/torchmetrics/functional/audio/stoi.py index 48e9e78510b..e029b721ef8 100644 --- a/src/torchmetrics/functional/audio/stoi.py +++ b/src/torchmetrics/functional/audio/stoi.py @@ -39,7 +39,8 @@ def short_time_objective_intelligibility( calculations on CPU, all input will automatically be moved to CPU to perform the metric calculation before being moved back to the original device. - .. note:: using this metrics requires you to have ``pystoi`` install. Either install as ``pip install + .. hint:: + Usingsing this metrics requires you to have ``pystoi`` install. Either install as ``pip install torchmetrics[audio]`` or ``pip install pystoi`` Args: diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 845ed7162d7..0a0ffb457de 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -107,10 +107,12 @@ def dice( - ``'samples'``: Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample). - .. note:: What is considered a sample in the multi-dimensional multi-class case + .. tip:: + What is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_average``. - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, + .. hint:: + If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, the value for the class will be ``nan``. mdmc_average: diff --git a/src/torchmetrics/functional/image/gradients.py b/src/torchmetrics/functional/image/gradients.py index 03f751f0870..e87d4a75198 100644 --- a/src/torchmetrics/functional/image/gradients.py +++ b/src/torchmetrics/functional/image/gradients.py @@ -70,7 +70,8 @@ def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: [5., 5., 5., 5., 5.], [0., 0., 0., 0., 0.]]) - .. note:: The implementation follows the 1-step finite difference method as followed + .. note:: + The implementation follows the 1-step finite difference method as followed by the TF implementation. The values are organized such that the gradient of [I(x+1, y)-[I(x, y)]] are at the (x, y) location diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index adb80e2bc77..7bd93ba94e1 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -134,8 +134,8 @@ def peak_signal_noise_ratio( >>> peak_signal_noise_ratio(pred, target) tensor(2.5527) - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. """ if dim is None and reduction != "elementwise_mean": diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py index 49e710e248a..0fe74cc1855 100644 --- a/src/torchmetrics/functional/multimodal/clip_iqa.py +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -271,7 +271,8 @@ def clip_image_quality_assessment( available prompts (see above). If tuple is provided, it must be of length 2 and the first string must be a positive prompt and the second string must be a negative prompt. - .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with + .. hint:: + If using the default `clip_iqa` model, the package `piq` must be installed. Either install with `pip install piq` or `pip install torchmetrics[multimodal]`. Returns: diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index 920eb6972e6..f70f37d534b 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -135,7 +135,8 @@ def clip_score( textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer to 100 the better. - .. note:: Metric is not scriptable + .. caution:: + Metric is not scriptable Args: images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors diff --git a/src/torchmetrics/functional/regression/log_mse.py b/src/torchmetrics/functional/regression/log_mse.py index 96d2938a8ee..3ba27ab9e99 100644 --- a/src/torchmetrics/functional/regression/log_mse.py +++ b/src/torchmetrics/functional/regression/log_mse.py @@ -68,8 +68,8 @@ def mean_squared_log_error(preds: Tensor, target: Tensor) -> Tensor: >>> mean_squared_log_error(x, y) tensor(0.0207) - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. """ sum_squared_log_error, num_obs = _mean_squared_log_error_update(preds, target) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 8c2e5d3cf76..ac559ed0c68 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -211,9 +211,8 @@ class FrechetInceptionDistance(Metric): that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype` method of the metric. - .. note:: using this metrics requires you to have torch 1.9 or higher installed - - .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 2b125542a17..20d53d10f2b 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -49,7 +49,8 @@ class InceptionScore(Metric): ``normalize`` is set to ``False`` images are expected to have dtype uint8 and take values in the ``[0, 255]`` range. All images will be resized to 299 x 299 which is the size of the original training data. - .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 975edf800eb..018fc7a7511 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -97,7 +97,8 @@ class KernelInceptionDistance(Metric): effect and update method expects to have the tensor given to `imgs` argument to be in the correct shape and type that is compatible to the custom feature extractor. - .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index f1adb73648a..1893fb734ba 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -47,12 +47,10 @@ class LearnedPerceptualImagePatchSimilarity(Metric): Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the chosen backbone (see `net_type` arg). - .. note:: using this metrics requires you to have ``torchvision`` package installed. Either install as + .. hint:: + Using this metrics requires you to have ``torchvision`` package installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision``. - .. note:: this metric is not scriptable when using ``torch<1.8``. Please update your pytorch installation - if this is a issue. - As input to ``forward`` and ``update`` the metric accepts the following input - ``img1`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)`` diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index 7d7237b7416..a1e2d2f4c0a 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -84,10 +84,12 @@ class MemorizationInformedFrechetInceptionDistance(Metric): flag ``real`` determines if the images should update the statistics of the real distribution or the fake distribution. - .. note:: using this metrics requires you to have ``scipy`` install. Either install as ``pip install + .. hint:: + Using this metrics requires you to have ``scipy`` install. Either install as ``pip install torchmetrics[image]`` or ``pip install scipy`` - .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index 4c312404934..117dca6f8cb 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -51,7 +51,8 @@ class PerceptualPathLength(Metric): if `conditional=False`, and `forward(z: Tensor, labels: Tensor) -> Tensor` if `conditional=True`. The returned tensor should have shape `(num_samples, C, H, W)` and be scaled to the range [0, 255]. - .. note:: using this metric with the default feature extractor requires that ``torchvision`` is installed. + .. hint:: + Using this metric with the default feature extractor requires that ``torchvision`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision`` As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/image/psnrb.py b/src/torchmetrics/image/psnrb.py index bac58b84e46..c9fb157d6ff 100644 --- a/src/torchmetrics/image/psnrb.py +++ b/src/torchmetrics/image/psnrb.py @@ -34,7 +34,7 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric): Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function. This metric is a modified version of PSNR that better supports evaluation of images with blocked artifacts, that oftens occur in compressed images. - .. note:: + .. attention:: Metric only supports grayscale images. If you have RGB images, please convert them to grayscale first. As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index dd912aaa28e..279537aac51 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -239,11 +239,11 @@ def add_state( - If the metric state is a ``list``, the synced value will be a ``list`` containing the combined elements from all processes. - .. note:: + .. important:: When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow the format discussed in the above note. - .. note:: + .. caution:: The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index 7a66de4ae3e..a28a360a0f1 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -112,7 +112,8 @@ class CLIPImageQualityAssessment(Metric): positive prompt and the second string must be a negative prompt. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with + .. hint:: + If using the default `clip_iqa` model, the package `piq` must be installed. Either install with `pip install piq` or `pip install torchmetrics[image]`. Raises: diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index f385fbc145d..92ca7ad6b4f 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -54,7 +54,8 @@ class CLIPScore(Metric): textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer to 100 the better. - .. note:: Metric is not scriptable + .. caution:: + Metric is not scriptable As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/regression/kl_divergence.py b/src/torchmetrics/regression/kl_divergence.py index 10fc6d90bc9..253b7a05002 100644 --- a/src/torchmetrics/regression/kl_divergence.py +++ b/src/torchmetrics/regression/kl_divergence.py @@ -64,8 +64,8 @@ class KLDivergence(Metric): ValueError: If ``reduction`` is not one of ``'mean'``, ``'sum'``, ``'none'`` or ``None``. - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. Example: >>> from torch import tensor diff --git a/src/torchmetrics/regression/log_mse.py b/src/torchmetrics/regression/log_mse.py index 85e01dba314..da190016844 100644 --- a/src/torchmetrics/regression/log_mse.py +++ b/src/torchmetrics/regression/log_mse.py @@ -52,8 +52,8 @@ class MeanSquaredLogError(Metric): >>> mean_squared_log_error(preds, target) tensor(0.0397) - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. """ diff --git a/src/torchmetrics/retrieval/base.py b/src/torchmetrics/retrieval/base.py index 5446d8668b1..f9a0a4f8cc4 100644 --- a/src/torchmetrics/retrieval/base.py +++ b/src/torchmetrics/retrieval/base.py @@ -50,10 +50,11 @@ class RetrievalMetric(Metric, ABC): - ``indexes`` (:class:`~torch.Tensor`): A long tensor of shape ``(N, ...)`` which indicate to which query a prediction belongs - .. note:: ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten - to single dimension once provided. + .. hint:: + The ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened + to single dimension once provided. - .. note:: + .. attention:: Predictions will be first grouped by ``indexes`` and then the real metric, defined by overriding the `_metric` method, will be computed as the mean of the scores over each query. diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py index 701e89b1433..9b4f4d38d0e 100644 --- a/src/torchmetrics/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/retrieval/precision_recall_curve.py @@ -303,9 +303,10 @@ class RetrievalRecallAtFixedPrecision(RetrievalPrecisionRecallCurve): - ``indexes`` (:class:`~torch.Tensor`): A long tensor of shape ``(N, ...)`` which indicate to which query a prediction belongs - .. note:: All ``indexes``, ``preds`` and ``target`` must have the same dimension. + .. important:: + All ``indexes``, ``preds`` and ``target`` must have the same dimension. - .. note:: + .. attention:: Predictions will be first grouped by ``indexes`` and then `RetrievalRecallAtFixedPrecision` will be computed as the mean of the `RetrievalRecallAtFixedPrecision` over each query. diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index fa2f04db97d..04ddd87ad71 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -43,8 +43,8 @@ class MultitaskWrapper(WrapperMetric): postfix: A string to append after the keys of the output dict. If not provided, will default to an empty string. - .. note:: - The use pre prefix and postfix allows for easily creating task wrappers for training, validation and test. + .. tip:: + The use prefix and postfix allows for easily creating task wrappers for training, validation and test. The arguments are only changing the output keys of the computed metrics and not the input keys. This means that a ``MultitaskWrapper`` initialized as ``MultitaskWrapper({"task": Metric()}, prefix="train_")`` will still expect the input to be a dictionary with the key "task", but the output will be a dictionary with the key