diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py index 9177d5c62c..d2c61cc70f 100644 --- a/captum/optim/__init__.py +++ b/captum/optim/__init__.py @@ -1,6 +1,6 @@ """optim submodule.""" -from captum.optim import models +from captum.optim import models # noqa: F401 from captum.optim._core import loss, optimization # noqa: F401 from captum.optim._core.optimization import InputOptimization # noqa: F401 from captum.optim._param.image import images, transforms # noqa: F401 diff --git a/captum/optim/_core/output_hook.py b/captum/optim/_core/output_hook.py index 4903155e74..d7bd8affd4 100644 --- a/captum/optim/_core/output_hook.py +++ b/captum/optim/_core/output_hook.py @@ -101,11 +101,11 @@ def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None: """ Args: - model (nn.Module): The reference to PyTorch model instance. - targets (nn.Module or list of nn.Module): The target layers to + model (nn.Module): The reference to PyTorch model instance. + targets (nn.Module or list of nn.Module): The target layers to collect activations from. """ - super(ActivationFetcher, self).__init__() + super().__init__() self.model = model self.layers = ModuleOutputsHook(targets) @@ -113,12 +113,13 @@ def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping: """ Args: - input_t (tensor or tuple of tensors, optional): The input to use + input_t (torch.Tensor or tuple of torch.Tensor, optional): The input to use with the specified model. Returns: - activations_dict: An dict containing the collected activations. The keys - for the returned dictionary are the target layers. + activations_dict (ModuleOutputMapping): A dict containing the collected + activations. The keys for the returned dictionary are the target + layers. """ try: diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index 4ec8762637..e76500050d 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -20,9 +20,9 @@ def __init__(self, background: Optional[torch.Tensor] = None) -> None: """ Args: - background (tensor, optional): An NCHW image tensor to be used as the + background (torch.Tensor, optional): An NCHW image tensor to be used as the Alpha channel's background. - Default: None + Default: ``None`` """ super().__init__() self.background = background @@ -36,7 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): RGBA image tensor to blend into an RGB image tensor. Returns: - **blended** (torch.Tensor): RGB image tensor. + blended (torch.Tensor): RGB image tensor. """ assert x.dim() == 4 assert x.size(1) == 4 @@ -60,7 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): RGBA image tensor. Returns: - **rgb** (torch.Tensor): RGB image tensor without the alpha channel. + rgb (torch.Tensor): RGB image tensor without the alpha channel. """ assert x.dim() == 4 assert x.size(1) == 4 @@ -71,13 +71,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ToRGB(nn.Module): """Transforms arbitrary channels to RGB. We use this to ensure our image parametrization itself can be decorrelated. So this goes between - the image parametrization and the normalization/sigmoid step. + the image parametrization and the normalization / sigmoid step, like in + :class:`captum.optim.images.NaturalImage`. + We offer two precalculated transforms: Karhunen-Loève (KLT) and I1I2I3. KLT corresponds to the empirically measured channel correlations on imagenet. I1I2I3 corresponds to an approximation for natural images from Ohta et al.[0] + + While the default transform matrices should work for the vast majority of use + cases, you can also use your own 3x3 transform matrix. If you wish to calculate + your own KLT transform matrix on a custom dataset, then please see + :func:`captum.optim.dataset.dataset_klt_matrix` for an example of how to do so. + [0] Y. Ohta, T. Kanade, and T. Sakai, "Color information for region segmentation," Computer Graphics and Image Processing, vol. 13, no. 3, pp. 222–241, 1980 https://www.sciencedirect.com/science/article/pii/0146664X80900477 + + Example:: + + >>> to_rgb = opt.transforms.ToRGB() + >>> x = torch.randn(1, 3, 224, 224) + >>> decorrelated_colors = to_rgb(x, inverse=True) + >>> recorrelated_colors = to_rgb(decorrelated_colors) + + .. note:: The ``ToRGB`` transform is included by default inside + :class:`.NaturalImage`. """ @staticmethod @@ -86,7 +104,7 @@ def klt_transform() -> torch.Tensor: Karhunen-Loève transform (KLT) measured on ImageNet Returns: - **transform** (torch.Tensor): A Karhunen-Loève transform (KLT) measured on + transform (torch.Tensor): A Karhunen-Loève transform (KLT) measured on the ImageNet dataset. """ # Handle older versions of PyTorch @@ -105,7 +123,7 @@ def klt_transform() -> torch.Tensor: def i1i2i3_transform() -> torch.Tensor: """ Returns: - **transform** (torch.Tensor): An approximation of natural colors transform + transform (torch.Tensor): An approximation of natural colors transform (i1i2i3). """ i1i2i3_matrix = [ @@ -119,9 +137,9 @@ def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None: """ Args: - transform (str or tensor): Either a string for one of the precalculated - transform matrices, or a 3x3 matrix for the 3 RGB channels of input - tensors. + transform (str or torch.Tensor): Either a string for one of the + precalculated transform matrices, or a 3x3 matrix for the 3 RGB + channels of input tensors. """ super().__init__() assert isinstance(transform, str) or torch.is_tensor(transform) @@ -143,12 +161,12 @@ def _forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor: """ Args: - x (torch.tensor): A CHW or NCHW RGB or RGBA image tensor. - inverse (bool, optional): Whether to recorrelate or decorrelate colors. - Default: False. + x (torch.Tensor): A CHW or NCHW RGB or RGBA image tensor. + inverse (bool, optional): Whether to recorrelate or decorrelate colors. + Default: ``False`` Returns: - chw (torch.tensor): A tensor with it's colors recorrelated or + chw (torch.Tensor): A tensor with it's colors recorrelated or decorrelated. """ @@ -197,12 +215,12 @@ def _forward_without_named_dims( Args: - x (torch.tensor): A CHW pr NCHW RGB or RGBA image tensor. - inverse (bool, optional): Whether to recorrelate or decorrelate colors. - Default: False. + x (torch.Tensor): A CHW pr NCHW RGB or RGBA image tensor. + inverse (bool, optional): Whether to recorrelate or decorrelate colors. + Default: ``False`` Returns: - chw (torch.tensor): A tensor with it's colors recorrelated or + chw (torch.Tensor): A tensor with it's colors recorrelated or decorrelated. """ @@ -244,12 +262,12 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor: Args: - x (torch.tensor): A CHW or NCHW RGB or RGBA image tensor. - inverse (bool, optional): Whether to recorrelate or decorrelate colors. - Default: False. + x (torch.Tensor): A CHW or NCHW RGB or RGBA image tensor. + inverse (bool, optional): Whether to recorrelate or decorrelate colors. + Default: ``False`` Returns: - chw (torch.tensor): A tensor with it's colors recorrelated or + chw (torch.Tensor): A tensor with it's colors recorrelated or decorrelated. """ if torch.jit.is_scripting(): @@ -263,6 +281,8 @@ class CenterCrop(torch.nn.Module): """ Center crop a specified amount from a tensor. If input are smaller than the specified crop size, padding will be applied. + + See :func:`.center_crop` for the functional version of this transform. """ __constants__ = [ @@ -291,18 +311,20 @@ def __init__( pixels_from_edges (bool, optional): Whether to treat crop size values as the number of pixels from the tensor's edge, or an exact shape in the center. - Default: False + Default: ``False`` offset_left (bool, optional): If the cropped away sides are not equal in size, offset center by +1 to the left and/or top. - This parameter is only valid when `pixels_from_edges` is False. - Default: False - padding_mode (optional, str): One of "constant", "reflect", "replicate" - or "circular". This parameter is only used if the crop size is larger - than the image size. - Default: "constant" - padding_value (float, optional): fill value for "constant" padding. This - parameter is only used if the crop size is larger than the image size. - Default: 0.0 + This parameter is only valid when ``pixels_from_edges`` is + ``False``. + Default: ``False`` + padding_mode (str, optional): One of: ``"constant"``, ``"reflect"``, + ``"replicate"``, or ``"circular"``. This parameter is only used if the + crop size is larger than the image size. + Default: ``"constant"`` + padding_value (float, optional): fill value for ``"constant"`` padding. + This parameter is only used if the crop size is larger than the image + size. + Default: ``0.0`` """ super().__init__() if not hasattr(size, "__iter__"): @@ -333,7 +355,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input (torch.Tensor): Input to center crop. Returns: - **tensor** (torch.Tensor): A center cropped *tensor*. + tensor (torch.Tensor): A center cropped NCHW tensor. """ return center_crop( @@ -358,28 +380,32 @@ def center_crop( Center crop a specified amount from a tensor. If input are smaller than the specified crop size, padding will be applied. + This function is the functional version of: :class:`.CenterCrop`. + Args: - input (tensor): A CHW or NCHW image tensor to center crop. + input (torch.Tensor): A CHW or NCHW image tensor to center crop. size (int, sequence, int): Number of pixels to center crop away. pixels_from_edges (bool, optional): Whether to treat crop size values as the number of pixels from the tensor's edge, or an exact shape in the center. - Default: False + Default: ``False`` offset_left (bool, optional): If the cropped away sides are not equal in size, offset center by +1 to the left and/or top. - This parameter is only valid when `pixels_from_edges` is False. - Default: False - padding_mode (optional, str): One of "constant", "reflect", "replicate" or - "circular". This parameter is only used if the crop size is larger than - the image size. - Default: "constant" - padding_value (float, optional): fill value for "constant" padding. This - parameter is only used if the crop size is larger than the image size. - Default: 0.0 + This parameter is only valid when ``pixels_from_edges`` is + ``False``. + Default: ``False`` + padding_mode (str, optional): One of: ``"constant"``, ``"reflect"``, + ``"replicate"``, or ``"circular"``. This parameter is only used if the crop + size is larger than the image size. + Default: ``"constant"`` + padding_value (float, optional): fill value for ``"constant"`` padding. + This parameter is only used if the crop size is larger than the image + size. + Default: ``0.0`` Returns: - **tensor**: A center cropped *tensor*. + tensor (torch.Tensor): A center cropped NCHW tensor. """ assert input.dim() == 3 or input.dim() == 4 @@ -433,7 +459,8 @@ def center_crop( class RandomScale(nn.Module): """ - Apply random rescaling on a NCHW tensor using the F.interpolate function. + Apply random rescaling on a NCHW tensor using the + :func:`torch.nn.functional.interpolate` function. """ __constants__ = [ @@ -458,21 +485,26 @@ def __init__( Args: scale (float, sequence, or torch.distribution): Sequence of rescaling - values to randomly select from, or a torch.distributions instance. + values to randomly select from, or a :mod:`torch.distributions` + instance. mode (str, optional): Interpolation mode to use. See documentation of - F.interpolate for more details. One of; "bilinear", "nearest", "area", - or "bicubic". - Default: "bilinear" + :func:`torch.nn.functional.interpolate` for more details. One of; + ``"bilinear"``, ``"nearest"``, ``"nearest-exact"``, ``"area"``, or + ``"bicubic"``. + Default: ``"bilinear"`` align_corners (bool, optional): Whether or not to align corners. See - documentation of F.interpolate for more details. - Default: False + documentation of :func:`torch.nn.functional.interpolate` for more + details. + Default: ``False`` recompute_scale_factor (bool, optional): Whether or not to recompute the - scale factor See documentation of F.interpolate for more details. - Default: False + scale factor See documentation of + :func:`torch.nn.functional.interpolate` for more details. + Default: ``False`` antialias (bool, optional): Whether or not use to anti-aliasing. This - feature is currently only available for "bilinear" and "bicubic" - modes. See documentation of F.interpolate for more details. - Default: False + feature is currently only available for ``"bilinear"`` and + ``"bicubic"`` modes. See documentation of + :func:`torch.nn.functional.interpolate` for more details. + Default: ``False`` """ super().__init__() assert mode not in ["linear", "trilinear"] @@ -508,7 +540,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: scale (float): The amount to scale the NCHW image by. Returns: - **x** (torch.Tensor): A scaled NCHW image tensor. + x (torch.Tensor): A scaled NCHW image tensor. """ if self._has_antialias: x = F.interpolate( @@ -538,7 +570,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): NCHW image tensor to randomly scale. Returns: - **x** (torch.Tensor): A randomly scaled NCHW image *tensor*. + x (torch.Tensor): A randomly scaled NCHW image tensor. """ assert x.dim() == 4 if self._is_distribution: @@ -562,11 +594,11 @@ class RandomScaleAffine(nn.Module): """ Apply random rescaling on a NCHW tensor. - This random scaling transform utilizes F.affine_grid & F.grid_sample, and as a - result has two key differences to the default RandomScale transforms This - transform either shrinks an image while adding a background, or center crops image - and then resizes it to a larger size. This means that the output image shape is the - same shape as the input image. + This random scaling transform utilizes :func:`torch.nn.functional.affine_grid` + & :func:`torch.nn.functional.grid_sample`, and as a result has two key differences + to the default RandomScale transforms This transform either shrinks an image while + adding a background, or center crops image and then resizes it to a larger size. + This means that the output image shape is the same shape as the input image. In constrast to RandomScaleAffine, the default RandomScale transform simply resizes the input image using F.interpolate. @@ -591,18 +623,21 @@ def __init__( Args: scale (float, sequence, or torch.distribution): Sequence of rescaling - values to randomly select from, or a torch.distributions instance. + values to randomly select from, or a :mod:`torch.distributions` + instance. mode (str, optional): Interpolation mode to use. See documentation of - F.grid_sample for more details. One of; "bilinear", "nearest", or - "bicubic". - Default: "bilinear" + :func:`torch.nn.functional.grid_sample` for more details. One of; + ``"bilinear"``, ``"nearest"``, or ``"bicubic"``. + Default: ``"bilinear"`` padding_mode (str, optional): Padding mode for values that fall outside of - the grid. See documentation of F.grid_sample for more details. One of; - "zeros", "border", or "reflection". - Default: "zeros" + the grid. See documentation of :func:`torch.nn.functional.grid_sample` + for more details. One of; ``"zeros"``, ``"border"``, or + ``"reflection"``. + Default: ``"zeros"`` align_corners (bool, optional): Whether or not to align corners. See - documentation of F.affine_grid & F.grid_sample for more details. - Default: False + documentation of :func:`torch.nn.functional.affine_grid` & + :func:`torch.nn.functional.grid_sample` for more details. + Default: ``False`` """ super().__init__() if isinstance(scale, torch.distributions.distribution.Distribution): @@ -637,7 +672,7 @@ def _get_scale_mat( m (float): The scale value to use. Returns: - **scale_mat** (torch.Tensor): A scale matrix. + scale_mat (torch.Tensor): A scale matrix. """ scale_mat = torch.tensor( [[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype @@ -654,7 +689,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor: scale (float): The amount to scale the NCHW image by. Returns: - **x** (torch.Tensor): A scaled NCHW image tensor. + x (torch.Tensor): A scaled NCHW image tensor. """ scale_matrix = self._get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat( x.shape[0], 1, 1 @@ -678,7 +713,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): NCHW image tensor to randomly scale. Returns: - **x** (torch.Tensor): A randomly scaled NCHW image *tensor*. + x (torch.Tensor): A randomly scaled NCHW image tensor. """ assert x.dim() == 4 if self._is_distribution: @@ -736,7 +771,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input (torch.Tensor): Input to randomly translate. Returns: - **tensor** (torch.Tensor): A randomly translated *tensor*. + tensor (torch.Tensor): A randomly translated NCHW tensor. """ insets = torch.randint( high=self.pad_range, @@ -750,8 +785,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class RandomRotation(nn.Module): """ - Apply random rotation transforms on a NCHW tensor, using a sequence of degrees or - torch.distributions instance. + Apply random rotation transforms on a NCHW tensor. """ __constants__ = [ @@ -772,19 +806,22 @@ def __init__( """ Args: - degrees (float, sequence, or torch.distribution): Tuple of degrees values - to randomly select from, or a torch.distributions instance. + degrees (float, sequence, or torch.distribution): Tuple or list of degrees + values to randomly select from, or a :mod:`torch.distributions` + instance. mode (str, optional): Interpolation mode to use. See documentation of - F.grid_sample for more details. One of; "bilinear", "nearest", or - "bicubic". - Default: "bilinear" + :func:`torch.nn.functional.grid_sample` for more details. One of; + ``"bilinear"``, ``"nearest"``, or ``"bicubic"``. + Default: ``"bilinear"`` padding_mode (str, optional): Padding mode for values that fall outside of - the grid. See documentation of F.grid_sample for more details. One of; - "zeros", "border", or "reflection". - Default: "zeros" + the grid. See documentation of :func:`torch.nn.functional.grid_sample` + for more details. One of; ``"zeros"``, ``"border"``, or + ``"reflection"``. + Default: ``"zeros"`` align_corners (bool, optional): Whether or not to align corners. See - documentation of F.affine_grid & F.grid_sample for more details. - Default: False + documentation of :func:`torch.nn.functional.affine_grid` & + :func:`torch.nn.functional.grid_sample` for more details. + Default: ``False`` """ super().__init__() if isinstance(degrees, torch.distributions.distribution.Distribution): @@ -820,7 +857,7 @@ def _get_rot_mat( theta (float): The rotation value in degrees. Returns: - **rot_mat** (torch.Tensor): A rotation matrix. + rot_mat (torch.Tensor): A rotation matrix. """ theta = theta * math.pi / 180.0 rot_mat = torch.tensor( @@ -843,7 +880,7 @@ def _rotate_tensor(self, x: torch.Tensor, theta: float) -> torch.Tensor: theta (float): The amount to rotate the NCHW image, in degrees. Returns: - **x** (torch.Tensor): A rotated NCHW image tensor. + x (torch.Tensor): A rotated NCHW image tensor. """ rot_matrix = self._get_rot_mat(theta, x.device, x.dtype)[None, ...].repeat( x.shape[0], 1, 1 @@ -867,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): NCHW image tensor to randomly rotate. Returns: - **x** (torch.Tensor): A randomly rotated NCHW image *tensor*. + x (torch.Tensor): A randomly rotated NCHW image tensor. """ assert x.dim() == 4 if self._is_distribution: @@ -899,7 +936,7 @@ def __init__(self, multiplier: float = 1.0) -> None: """ Args: - multiplier (float, optional): A float value used to scale the input. + multiplier (float, optional): A float value used to scale the input. """ super().__init__() self.multiplier = multiplier @@ -913,7 +950,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): Input to scale values of. Returns: - **tensor** (torch.Tensor): tensor with it's values scaled. + tensor (torch.Tensor): tensor with it's values scaled. """ return x * self.multiplier @@ -932,7 +969,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): RGB image tensor to convert to BGR. Returns: - **BGR tensor** (torch.Tensor): A BGR tensor. + BGR tensor (torch.Tensor): A BGR tensor. """ assert x.dim() == 4 assert x.size(1) == 3 @@ -975,7 +1012,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class GaussianSmoothing(nn.Module): """ Apply gaussian smoothing on a - 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + 1d, 2d or 3d tensor. Filtering is performed separately for each channel in the input using a depthwise convolution. """ @@ -1070,7 +1107,7 @@ def forward( x (torch.Tensor): Input to apply symmetric padding on. Returns: - **tensor** (torch.Tensor): Padded tensor. + tensor (torch.Tensor): Padded tensor. """ ctx.padding = padding x_device = x.device @@ -1093,7 +1130,7 @@ def backward( grad_output (torch.Tensor): Input to remove symmetric padding from. Returns: - **grad_input** (torch.Tensor): Unpadded tensor. + grad_input (torch.Tensor): Unpadded tensor. """ grad_input = grad_output.clone() B, C, H, W = grad_input.size() @@ -1117,7 +1154,8 @@ def __init__(self, warp: bool = False) -> None: Args: warp (bool, optional): Whether or not to make the resulting RGB colors more - distict from each other. Default is set to False. + distict from each other. + Default: ``False`` """ super().__init__() self.warp = warp @@ -1131,7 +1169,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): Input to reduce channel dimensions on. Returns: - **3 channel RGB tensor** (torch.Tensor): RGB image tensor. + x (torch.Tensor): A 3 channel RGB image tensor. """ assert x.dim() == 4 return nchannels_to_rgb(x, self.warp) @@ -1181,6 +1219,16 @@ def _center_crop(self, x: torch.Tensor) -> torch.Tensor: ] def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Randomly crop an NCHW image tensor. + + Args: + + x (torch.Tensor): The NCHW image tensor to randomly crop. + + Returns + x (torch.Tensor): The randomly cropped NCHW image tensor. + """ assert x.dim() == 4 hs = int(math.ceil((x.shape[2] - self.crop_size[0]) / 2.0)) ws = int(math.ceil((x.shape[3] - self.crop_size[1]) / 2.0)) @@ -1206,9 +1254,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self._center_crop(x) +# Define TransformationRobustness defaults externally for easier Sphinx docs formatting +_TR_TRANSLATE: List[int] = [4] * 10 +_TR_SCALE: List[float] = [0.995**n for n in range(-5, 80)] + [ + 0.998**n for n in 2 * list(range(20, 40)) +] +_TR_DEGREES: List[int] = ( + list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5 * [0] +) + + class TransformationRobustness(nn.Module): """ - This transform combines the standard transforms together for ease of use. + This transform combines the standard transforms (:class:`.RandomSpatialJitter`, + :class:`.RandomScale` & :class:`.RandomRotation`) together for ease of + use. Multiple jitter transforms can be used to create roughly gaussian distribution of jitter. @@ -1222,15 +1282,9 @@ class TransformationRobustness(nn.Module): def __init__( self, padding_transform: Optional[nn.Module] = nn.ConstantPad2d(2, value=0.5), - translate: Optional[Union[int, List[int]]] = [4] * 10, - scale: Optional[NumSeqOrTensorOrProbDistType] = [ - 0.995**n for n in range(-5, 80) - ] - + [0.998**n for n in 2 * list(range(20, 40))], - degrees: Optional[NumSeqOrTensorOrProbDistType] = list(range(-20, 20)) - + list(range(-10, 10)) - + list(range(-5, 5)) - + 5 * [0], + translate: Optional[Union[int, List[int]]] = _TR_TRANSLATE, + scale: Optional[NumSeqOrTensorOrProbDistType] = _TR_SCALE, + degrees: Optional[NumSeqOrTensorOrProbDistType] = _TR_DEGREES, final_translate: Optional[int] = 2, crop_or_pad_output: bool = False, ) -> None: @@ -1238,26 +1292,30 @@ def __init__( Args: padding_transform (nn.Module, optional): A padding module instance. No - padding will be applied before transforms if set to None. - Default: nn.ConstantPad2d(2, value=0.5) - translate (int or list of int, optional): The max horizontal and vertical - translation to use for each jitter transform. - Default: [4] * 10 + padding will be applied before transforms if set to ``None``. + Default: ``nn.ConstantPad2d(2, value=0.5)`` + translate (int or List[int], optional): The max horizontal and vertical + translation to use for each :class:`.RandomSpatialJitter` transform. + Default: ``[4] * 10`` scale (float, sequence, or torch.distribution, optional): Sequence of - rescaling values to randomly select from, or a torch.distributions - instance. If set to None, no rescaling transform will be used. - Default: A set of optimal values. + rescaling values to randomly select from, or a + :mod:`torch.distributions` instance. If set to ``None``, no + :class:`.RandomScale` transform will be used. + Default: ``[0.995**n for n in range(-5, 80)] + [0.998**n for n in 2 * + list(range(20, 40))]`` degrees (float, sequence, or torch.distribution, optional): Sequence of - degrees to randomly select from, or a torch.distributions - instance. If set to None, no rotation transform will be used. - Default: A set of optimal values. + degrees to randomly select from, or a :mod:`torch.distributions` + instance. If set to ``None``, no :class:`.RandomRotation` transform + will be used. + Default: ``list(range(-20, 20)) + list(range(-10, 10)) + + list(range(-5, 5)) + 5 * [0]`` final_translate (int, optional): The max horizontal and vertical - translation to use for the final jitter transform on fractional - pixels. - Default: 2 + translation to use for the final :class:`.RandomSpatialJitter` + transform on fractional pixels. + Default: ``2`` crop_or_pad_output (bool, optional): Whether or not to crop or pad the transformed output so that it is the same shape as the input. - Default: False + Default: ``False`` """ super().__init__() self.padding_transform = padding_transform @@ -1280,6 +1338,14 @@ def __init__( self.crop_or_pad_output = crop_or_pad_output def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An NCHW tensor. + + Returns: + x (torch.Tensor): A transformed NCHW tensor. + """ assert x.dim() == 4 crop_size = x.shape[2:] diff --git a/captum/optim/_utils/circuits.py b/captum/optim/_utils/circuits.py index 9c84d16247..d82d049fca 100644 --- a/captum/optim/_utils/circuits.py +++ b/captum/optim/_utils/circuits.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional import torch import torch.nn as nn @@ -11,7 +11,7 @@ def extract_expanded_weights( model: nn.Module, target1: nn.Module, target2: nn.Module, - crop_shape: Optional[Union[Tuple[int, int], IntSeqOrIntType]] = None, + crop_shape: Optional[IntSeqOrIntType] = None, model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224), crop_func: Optional[Callable] = center_crop, ) -> torch.Tensor: @@ -20,24 +20,47 @@ def extract_expanded_weights( literally adjacent in a neural network, or where the weights aren’t directly represented in a single weight tensor. + Example:: + + >>> # Load InceptionV1 model with nonlinear layers replaced by + >>> # their linear equivalents + >>> linear_model = opt.models.googlenet( + >>> pretrained=True, use_linear_modules_only=True + >>> ).eval() + >>> # Extract weight interactions between target layers + >>> W_3a_3b = opt.circuits.extract_expanded_weights( + >>> linear_model, linear_model.mixed3a, linear_model.mixed3b, 5 + >>> ) + >>> # Display results for channel 147 of mixed3a and channel 379 of + >>> # mixed3b, in human readable format + >>> W_3a_3b_hm = opt.weights_to_heatmap_2d( + >>> W_3a_3b[379, 147, ...] / W_3a_3b[379, ...].max() + >>> ) + >>> opt.show(W_3a_3b_hm) + Voss, et al., "Visualizing Weights", Distill, 2021. See: https://distill.pub/2020/circuits/visualizing-weights/ Args: - model (nn.Module): The reference to PyTorch model instance. - target1 (nn.module): The starting target layer. Must be below the layer - specified for target2. - target2 (nn.Module): The end target layer. Must be above the layer - specified for target1. - crop_shape (int or tuple of ints, optional): Specify the exact output size - to crop out. - model_input (tensor or tuple of tensors, optional): The input to use + + model (nn.Module): The reference to PyTorch model instance. + target1 (nn.Module): The starting target layer. Must be below the layer + specified for ``target2``. + target2 (nn.Module): The end target layer. Must be above the layer + specified for ``target1``. + crop_shape (int, list of int, or tuple of int, optional): Specify the exact + output size to crop out. Set to ``None`` for no cropping. + Default: ``None`` + model_input (torch.Tensor or tuple of torch.Tensor, optional): The input to use with the specified model. - crop_func (Callable, optional): Specify a function to crop away the padding + Default: ``torch.zeros(1, 3, 224, 224)`` + crop_func (Callable, optional): Specify a function to crop away the padding from the output weights. + Default: :func:`.center_crop` + Returns: - *tensor*: A tensor containing the expanded weights in the form of: - (target2 output channels, target1 output channels, height, width) + tensor (torch.Tensor): A tensor containing the expanded weights in the form + of: (target2 output channels, target1 output channels, height, width) """ if isinstance(model_input, torch.Tensor): model_input = model_input.to(next(model.parameters()).device) diff --git a/captum/optim/_utils/image/atlas.py b/captum/optim/_utils/image/atlas.py index 5954a3a471..3e616fd55c 100644 --- a/captum/optim/_utils/image/atlas.py +++ b/captum/optim/_utils/image/atlas.py @@ -14,20 +14,20 @@ def normalize_grid( Args: - xy_grid (torch.tensor): The xy coordinate grid tensor to normalize, + xy_grid (torch.Tensor): The xy coordinate grid tensor to normalize, with a shape of: [n_points, n_axes]. min_percentile (float, optional): The minimum percentile to use when normalizing the tensor. Value must be in the range [0, 1]. - Default: 0.01 + Default: ``0.01`` max_percentile (float, optional): The maximum percentile to use when normalizing the tensor. Value must be in the range [0, 1]. - Default: 0.99 + Default: ``0.99`` relative_margin (float, optional): The relative margin to use when normalizing the tensor. - Default: 0.1 + Default: ``0.1`` Returns: - normalized_grid (torch.tensor): A normalized xy coordinate grid tensor. + normalized_grid (torch.Tensor): A normalized xy coordinate grid tensor. """ assert xy_grid.dim() == 2 @@ -56,8 +56,8 @@ def calc_grid_indices( This function draws a 2D grid across the irregular grid of points, and then groups point indices based on the grid cell they fall within. The grid cells are then filled with 1D tensors that have anywhere from 0 to n_indices values in them. The - sets of grid indices can then be used with the compute_avg_cell_samples function - to create atlas grid cell direction vectors. + sets of grid indices can then be used with the :func:`compute_avg_cell_samples` + function to create atlas grid cell direction vectors. Indices are stored for grid cells in an xy matrix, where the outer lists represent x positions and the inner lists represent y positions. Each grid cell is filled @@ -71,23 +71,31 @@ def calc_grid_indices( Each cell in the above example would contain a list of indices inside a tensor for that particular cell, like this: - indices = [ - [tensor([0, 5]), tensor([1]), tensor([2, 3])], - [tensor([]), tensor([4]), tensor([])], - [tensor([6, 7, 8]), tensor([]), tensor([])], - ] + + :: + + indices = [ + [tensor([0, 5]), tensor([1]), tensor([2, 3])], + [tensor([]), tensor([4]), tensor([])], + [tensor([6, 7, 8]), tensor([]), tensor([])], + ] Args: - xy_grid (torch.tensor): The xy coordinate grid activation samples, with a shape + + xy_grid (torch.Tensor): The xy coordinate grid activation samples, with a shape of: [n_points, 2]. - grid_size (Tuple[int, int]): The grid_size of grid cells to use. The grid_size - variable should be in the format of: [width, height]. - x_extent (Tuple[float, float], optional): The x axis range to use. - Default: (0.0, 1.0) - y_extent (Tuple[float, float], optional): The y axis range to use. - Default: (0.0, 1.0) + grid_size (tuple of int): The number of grid cells to use across the height + and width dimensions. The ``grid_size`` variable should be in the format + of: [width, height]. + x_extent (tuple of float, optional): The x axis range to use, in the format + of: (min, max). + Default: ``(0.0, 1.0)`` + y_extent (tuple of float, optional): The y axis range to use, in the format + of: (min, max). + Default: ``(0.0, 1.0)`` + Returns: - indices (list of list of torch.Tensors): List of lists of grid indices + indices (list of list of torch.Tensor): List of lists of grid indices stored inside tensors to use. Each 1D tensor of indices has a size of: 0 to n_indices. """ @@ -121,33 +129,35 @@ def compute_avg_cell_samples( """ Create direction vectors for sets of activation samples, attribution samples, and grid indices. Grid cells without the minimum number of points as specified by - min_density will be ignored. The calc_grid_indices function can be used to produce - the values required for the grid_indices variable. + ``min_density`` will be ignored. The :func:`calc_grid_indices` function can be used + to produce the values required for the ``grid_indices`` variable. Carter, et al., "Activation Atlas", Distill, 2019. https://distill.pub/2019/activation-atlas/ Args: - grid_indices (list of list of torch.tensor): List of lists of grid indices + grid_indices (list of list of torch.Tensor): List of lists of grid indices stored inside tensors to use. Each 1D tensor of indices has a size of: 0 to n_indices. - raw_samples (torch.tensor): Raw unmodified activation or attribution samples, + raw_samples (torch.Tensor): Raw unmodified activation or attribution samples, with a shape of: [n_samples, n_channels]. - grid_size (Tuple[int, int]): The grid_size of grid cells to use. The grid_size - variable should be in the format of: [width, height]. + grid_size (tuple of int): The number of grid cells to use across the height + and width dimensions. The ``grid_size`` variable should be in the format + of: [width, height]. min_density (int, optional): The minimum number of points for a cell to be counted. - Default: 8 + Default: ``8`` Returns: - cell_vecs (torch.tensor): A tensor containing all the direction vectors that - were created, stacked along the batch dimension with a shape of: - [n_vecs, n_channels]. - cell_coords (list of Tuple[int, int, int]): List of coordinates for grid - spatial positions of each direction vector, and the number of samples used - for the cell. The list for each cell is in the format of: - [x_coord, y_coord, number_of_samples_used]. + cell_vecs_and_cell_coords: A 2 element tuple of: ``(cell_vecs, cell_coords)``. + - cell_vecs (torch.Tensor): A tensor containing all the direction vectors + that were created, stacked along the batch dimension with a shape of: + [n_vecs, n_channels]. + - cell_coords (list of tuple of int): List of coordinates for grid + spatial positions of each direction vector, and the number of samples + used for the cell. The list for each cell is in the format of: + [x_coord, y_coord, number_of_samples_used]. """ assert raw_samples.dim() == 2 @@ -174,39 +184,43 @@ def create_atlas_vectors( ) -> Tuple[torch.Tensor, List[Tuple[int, int, int]]]: """ Create direction vectors by splitting an irregular grid of activation samples into - cells. Grid cells without the minimum number of points as specified by min_density - will be ignored. + cells. Grid cells without the minimum number of points as specified by + ``min_density`` will be ignored. Carter, et al., "Activation Atlas", Distill, 2019. https://distill.pub/2019/activation-atlas/ Args: - xy_grid (torch.tensor): The xy coordinate grid activation samples, with a shape + xy_grid (torch.Tensor): The xy coordinate grid activation samples, with a shape of: [n_points, 2]. - raw_activations (torch.tensor): Raw unmodified activation samples, with a shape + raw_activations (torch.Tensor): Raw unmodified activation samples, with a shape of: [n_samples, n_channels]. - grid_size (Tuple[int, int]): The size of grid cells to use. The grid_size - variable should be in the format of: [width, height]. + grid_size (tuple of int): The number of grid cells to use across the height + and width dimensions. The ``grid_size`` variable should be in the format + of: [width, height]. min_density (int, optional): The minimum number of points for a cell to be counted. - Default: 8 + Default: ``8`` normalize (bool, optional): Whether or not to remove outliers from an xy coordinate grid tensor, and rescale it to [0, 1]. - Default: True - x_extent (Tuple[float, float], optional): The x axis range to use. - Default: (0.0, 1.0) - y_extent (Tuple[float, float], optional): The y axis range to use. - Default: (0.0, 1.0) + Default: ``True`` + x_extent (tuple of float, optional): The x axis range to use, in the format + of: (min, max). + Default: ``(0.0, 1.0)`` + y_extent (tuple of float, optional): The y axis range to use, in the format + of: (min, max). + Default: ``(0.0, 1.0)`` Returns: - grid_vecs (torch.tensor): A tensor containing all the direction vectors that - were created, stacked along the batch dimension, with a shape of: - [n_vecs, n_channels]. - cell_coords (list of Tuple[int, int, int]): List of coordinates for grid - spatial positions of each direction vector, and the number of samples used - for the cell. The list for each cell is in the format of: - [x_coord, y_coord, number_of_samples_used]. + grid_vecs_and_cell_coords: A 2 element tuple of: ``(grid_vecs, cell_coords)``. + - grid_vecs (torch.Tensor): A tensor containing all the direction vectors + that were created, stacked along the batch dimension, with a shape + of: [n_vecs, n_channels]. + - cell_coords (list of tuple of int): List of coordinates for grid + spatial positions of each direction vector, and the number of samples + used for the cell. The list for each cell is in the format of: + [x_coord, y_coord, number_of_samples_used]. """ assert xy_grid.dim() == 2 and xy_grid.size(1) == 2 @@ -235,19 +249,19 @@ def create_atlas( Args: - cells (list of torch.tensor or torch.tensor): A list or stack of NCHW image + cells (list of torch.Tensor or torch.Tensor): A list or stack of NCHW image tensors made with atlas direction vectors. - coords (list of Tuple[int, int] or list of Tuple[int, int, int]): A list of - coordinates to use for the atlas image tensors. The first 2 values in each - coordinate list should be: [x, y, ...]. - grid_size (Tuple[int, int]): The size of grid cells to use. The grid_size - variable should be in the format of: [width, height]. + coords (list of tuple of int): A list of coordinates to use for the atlas image + tensors. The first 2 values in each coordinate list should be: [x, y, ...]. + grid_size (tuple of int): The number of grid cells to use across the height + and width dimensions. The ``grid_size`` variable should be in the format + of: [width, height]. base_tensor (Callable, optional): What to use for the atlas base tensor. Basic - choices are: torch.ones or torch.zeros. - Default: torch.ones + choices are: :func:`torch.ones` or :func:`torch.zeros`. + Default: :func:`torch.ones` Returns: - atlas_canvas (torch.tensor): The full activation atlas visualization, with a + atlas_canvas (torch.Tensor): The full activation atlas visualization, with a shape of NCHW. """ @@ -262,7 +276,7 @@ def create_atlas( # cell_b -> number of images # cell_c -> image channel - # cell_h -> image hight + # cell_h -> image height # cell_w -> image width cell_b, cell_c, cell_h, cell_w = cells[0].shape atlas_canvas = base_tensor( diff --git a/captum/optim/_utils/image/common.py b/captum/optim/_utils/image/common.py index f1cdc5f477..40a7f075b5 100644 --- a/captum/optim/_utils/image/common.py +++ b/captum/optim/_utils/image/common.py @@ -27,13 +27,13 @@ def make_grid_image( tiles (torch.Tensor or list of torch.Tensor): A stack of NCHW image tensors or a list of NCHW image tensors to create a grid from. - nrow (int, optional): The number of rows to use for the grid image. - Default: 4 + images_per_row (int, optional): The number of rows to use for the grid image. + Default: ``4`` padding (int, optional): The amount of padding between images in the grid images. - padding: 2 + padding: ``2`` pad_value (float, optional): The value to use for the padding. - Default: 0.0 + Default: ``0.0`` Returns: grid_img (torch.Tensor): The full NCHW grid image. @@ -79,22 +79,27 @@ def show( """ Show CHW & NCHW tensors as an image. + Alias: ``captum.optim.images.show`` + Args: x (torch.Tensor): The tensor you want to display as an image. - figsize (Tuple[int, int], optional): height & width to use - for displaying the image figure. - scale (float): Value to multiply the input tensor by so that + figsize (tuple of int, optional): The height & width to use for displaying the + ``ImageTensor`` figure, in the format of: (height, width). + Default: ``None`` + scale (float, optional): Value to multiply the input tensor by so that it's value range is [0-255] for display. + Default: ``255.0`` images_per_row (int, optional): The number of images per row to use for the - grid image. Default is set to None for no grid image creation. - Default: None + grid image. Default is set to ``None`` for no grid image creation. + Default: ``None`` padding (int, optional): The amount of padding between images in the grid - images. This parameter only has an effect if nrow is not None. - Default: 2 + images. This parameter only has an effect if ``images_per_row`` is not + ``None``. + Default: ``2`` pad_value (float, optional): The value to use for the padding. This parameter - only has an effect if nrow is not None. - Default: 0.0 + only has an effect if ``images_per_row`` is not ``None``. + Default: ``0.0`` """ if x.dim() not in [3, 4]: @@ -127,24 +132,28 @@ def save_tensor_as_image( """ Save RGB & RGBA image tensors with a shape of CHW or NCHW as images. + Alias: ``captum.optim.images.save_tensor_as_image`` + Args: x (torch.Tensor): The tensor you want to save as an image. filename (str): The filename to use when saving the image. scale (float, optional): Value to multiply the input tensor by so that it's value range is [0-255] for saving. + Default: ``255.0`` mode (str, optional): A PIL / Pillow supported colorspace. Default is set to None for automatic RGB / RGBA detection and usage. - Default: None + Default: ``None`` images_per_row (int, optional): The number of images per row to use for the grid image. Default is set to None for no grid image creation. - Default: None + Default: ``None`` padding (int, optional): The amount of padding between images in the grid - images. This parameter only has an effect if `nrow` is not None. - Default: 2 + images. This parameter only has an effect if ``images_per_row`` is not + ``None``. + Default: ``2`` pad_value (float, optional): The value to use for the padding. This parameter - only has an effect if `nrow` is not None. - Default: 0.0 + only has an effect if ``images_per_row`` is not ``None``. + Default: ``0.0`` """ if x.dim() not in [3, 4]: @@ -170,14 +179,14 @@ def get_neuron_pos( """ Args: - H (int) The height - W (int) The width + H (int): The h position to use. + W (int): The w position to use. x (int, optional): Optionally specify and exact x location of the neuron. If - set to None, then the center x location will be used. - Default: None + set to ``None``, then the center x location will be used. + Default: ``None`` y (int, optional): Optionally specify and exact y location of the neuron. If - set to None, then the center y location will be used. - Default: None + set to ``None``, then the center y location will be used. + Default: ``None`` Return: Tuple[_x, _y] (Tuple[int, int]): The x and y dimensions of the neuron. @@ -208,17 +217,22 @@ def _dot_cossim( a specified dimension. Args: + x (torch.Tensor): The tensor that you wish to compute the cosine similarity for in relation to tensor y. y (torch.Tensor): The tensor that you wish to compute the cosine similarity for in relation to tensor x. cossim_pow (float, optional): The desired cosine similarity power to use. + Default: ``0.0`` dim (int, optional): The target dimension for computing cosine similarity. + Default: ``1`` eps (float, optional): If cossim_pow is greater than zero, the desired epsilon value to use for cosine similarity calculations. + Default: ``1e-8`` + Returns: tensor (torch.Tensor): Dot cosine similarity between x and y, along the - specified dim. + specified dim. """ dot = torch.sum(x * y, dim) @@ -241,13 +255,16 @@ def hue_to_rgb( ) -> torch.Tensor: """ Create an RGB unit vector based on a hue of the input angle. + Args: + angle (float): The hue angle to create an RGB color for. device (torch.device, optional): The device to create the angle color tensor on. - Default: torch.device("cpu") + Default: ``torch.device("cpu")`` warp (bool, optional): Whether or not to make colors more distinguishable. - Default: True + Default: ``True`` + Returns: color_vec (torch.Tensor): A color vector. """ @@ -288,11 +305,12 @@ def nchannels_to_rgb( Args: - x (torch.Tensor): NCHW image tensor to transform into RGB image. - warp (bool, optional): Whether or not to make colors more distinguishable. - Default: True + x (torch.Tensor): NCHW image tensor to transform into RGB image. + warp (bool, optional): Whether or not to make colors more distinguishable. + Default: ``True`` eps (float, optional): An optional epsilon value. - Default: 1e-4 + Default: ``1e-4`` + Returns: tensor (torch.Tensor): An NCHW RGB image tensor. """ @@ -326,13 +344,15 @@ def weights_to_heatmap_2d( no excitation or inhibition. Args: - weight (torch.Tensor): A 2d tensor to create the heatmap from. - colors (list of str): A list of 5 strings containing hex triplet + + weight (torch.Tensor): A 2d tensor to create the heatmap from. + colors (list of str, optional): A list of 5 strings containing hex triplet (six digit), three-byte hexadecimal color values to use for coloring the heatmap. + Default: ``["0571b0", "92c5de", "f7f7f7", "f4a582", "ca0020"]`` Returns: - color_tensor (torch.Tensor): A weight heatmap. + color_tensor (torch.Tensor): A weight heatmap. """ assert weight.dim() == 2 diff --git a/captum/optim/_utils/reducer.py b/captum/optim/_utils/reducer.py index 2696d003d6..85f15f7bf3 100644 --- a/captum/optim/_utils/reducer.py +++ b/captum/optim/_utils/reducer.py @@ -16,20 +16,47 @@ class ChannelReducer: """ + The ChannelReducer class is a wrapper for PyTorch and NumPy based dimensionality + reduction algorithms, like those from ``sklearn.decomposition`` (ex: NMF, PCA), + ``sklearn.manifold`` (ex: TSNE), UMAP, and other libraries. This class handles + things like reshaping, algorithm search by name (for scikit-learn only), and + PyTorch tensor conversions to and from NumPy arrays. + + Example:: + + >>> reducer = opt.reducer.ChannelReducer(2, "NMF") + >>> x = torch.randn(1, 8, 128, 128).abs() + >>> output = reducer.fit_transform(x) + >>> print(output.shape) + torch.Size([1, 2, 128, 128]) + + >>> # reduction_alg attributes are easily accessible + >>> print(reducer.components.shape) + torch.Size([2, 8]) + Dimensionality reduction for the channel dimension of an input tensor. Olah, et al., "The Building Blocks of Interpretability", Distill, 2018. See here for more information: https://distill.pub/2018/building-blocks/ + Some of the possible algorithm choices: + + * https://scikit-learn.org/stable/modules/classes.html#module-sklearn.decomposition + * https://scikit-learn.org/stable/modules/classes.html#module-sklearn.manifold + * https://umap-learn.readthedocs.io/en/latest/ + Args: - n_components (int, optional): The number of channels to reduce the target + + n_components (int, optional): The number of channels to reduce the target dimension to. - reduction_alg (str or callable, optional): The desired dimensionality - reduction algorithm to use. The default reduction_alg is set to NMF from - sklearn, which requires users to put inputs on CPU before passing them to - fit_transform. - **kwargs (optional): Arbitrary keyword arguments used by the specified - reduction_alg. + reduction_alg (str or Callable, optional): The desired dimensionality + reduction algorithm to use. The default ``reduction_alg`` is set to NMF + from sklearn, which requires users to put inputs on CPU before passing them + to :func:`ChannelReducer.fit_transform`. Name strings are only supported + for ``sklearn.decomposition`` & ``sklearn.manifold`` class names. + Default: ``NMF`` + **kwargs (Any, optional): Arbitrary keyword arguments used by the specified + ``reduction_alg``. """ def __init__( @@ -47,14 +74,42 @@ def __init__( self._reducer = reduction_alg(n_components=n_components, **kwargs) def _get_reduction_algo_instance(self, name: str) -> Union[None, Callable]: + """ + Search through a library for a ``reduction_alg`` matching the provided str + name. + + Args: + + name (str): The name of the reduction_alg to search for. + + Returns: + reduction_alg (Callable or None): The ``reduction_alg`` if it was found, + otherwise None. + """ if hasattr(sklearn.decomposition, name): obj = sklearn.decomposition.__getattribute__(name) if issubclass(obj, BaseEstimator): return obj + elif hasattr(sklearn.manifold, name): + obj = sklearn.manifold.__getattribute__(name) + if issubclass(obj, BaseEstimator): + return obj return None @classmethod def _apply_flat(cls, func: Callable, x: torch.Tensor) -> torch.Tensor: + """ + Flatten inputs, run them through the reduction_alg, and then reshape them back + to their original size using the resized dimension. + + Args: + + func (Callable): The ``reduction_alg`` transform function being used. + x (torch.Tensor): The tensor being transformed and reduced. + + Returns: + x (torch.Tensor): A transformed tensor. + """ orig_shape = x.shape try: return func(x.reshape([-1, x.shape[-1]])).reshape( @@ -70,14 +125,21 @@ def fit_transform( self, x: torch.Tensor, swap_2nd_and_last_dims: bool = True ) -> torch.Tensor: """ - Perform dimensionality reduction on an input tensor. + Perform dimensionality reduction on an input tensor using the specified + ``reduction_alg``'s ``.fit_transform`` function. + Args: - tensor (tensor): A tensor to perform dimensionality reduction on. - swap_2nd_and_last_dims (bool, optional): If true, input channels are + + x (torch.Tensor): A tensor to perform dimensionality reduction on. + swap_2nd_and_last_dims (bool, optional): If ``True``, input channels are expected to be in the second dimension unless the input tensor has a - shape of CHW. Default is set to True. + shape of CHW. When reducing the channel dimension, this parameter + should be set to ``True`` unless you are already using the channels + last format. + Default: ``True``. + Returns: - *tensor*: A tensor with one of it's dimensions reduced. + x (torch.Tensor): A tensor with one of it's dimensions reduced. """ if x.dim() == 3 and swap_2nd_and_last_dims: @@ -127,14 +189,20 @@ def __dir__(self) -> List: def posneg(x: torch.Tensor, dim: int = 0) -> torch.Tensor: """ - Hack that makes a matrix positive by concatination in order to simulate one-sided + Hack that makes a matrix positive by concatenation in order to simulate one-sided NMF with regular NMF. + Voss, et al., "Visualizing Weights", Distill, 2021. + See: https://distill.pub/2020/circuits/visualizing-weights/ + Args: - x (tensor): A tensor to make positive. - dim (int, optional): The dimension to concatinate the two tensor halves at. + + x (torch.Tensor): A tensor to make positive. + dim (int, optional): The dimension to concatenate the two tensor halves at. + Default: ``0`` + Returns: - tensor (torch.tensor): A positive tensor for one-sided dimensionality + tensor (torch.Tensor): A positive tensor for one-sided dimensionality reduction. """ diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index e65e281217..49a1154fe7 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -68,6 +68,14 @@ class RedirectedReluLayer(nn.Module): @torch.jit.ignore def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): A tensor to pass through RedirectedReLU. + + Returns: + x (torch.Tensor): The output of RedirectedReLU. + """ return RedirectedReLU.apply(input) @@ -207,7 +215,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x (torch.tensor): The input tensor to apply 2D convolution to. + x (torch.Tensor): The input tensor to apply 2D convolution to. Returns x (torch.Tensor): The input tensor after the 2D convolution was applied. @@ -275,6 +283,7 @@ class SkipLayer(torch.nn.Module): https://pytorch.org/docs/stable/generated/torch.nn.Identity.html Args: + args (Any): Any argument. Arguments will be safely ignored. kwargs (Any) Any keyword argument. Arguments will be safely ignored. """ @@ -287,9 +296,11 @@ def forward( ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: """ Args: + x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors. args (Any): Any argument. Arguments will be safely ignored. kwargs (Any) Any keyword argument. Arguments will be safely ignored. + Returns: x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or tensors. @@ -306,7 +317,9 @@ def skip_layers( with layers that do nothing. This is useful for removing the nonlinear ReLU layers when creating expanded weights. + Args: + model (nn.Module): A PyTorch model instance. layers (nn.Module or list of nn.Module): The layer class type to replace in the model. diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index d24c87d42d..e0660d6f93 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -15,38 +15,45 @@ def googlenet( **kwargs: Any, ) -> "InceptionV1": r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from - `"Going Deeper with Convolutions" `_. + `"Going Deeper with Convolutions" `_. + + Example:: + + >>> model = opt.models.googlenet(pretrained=True) + >>> output = model(torch.zeros(1, 3, 224, 224)) Args: - pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. - Default: False - progress (bool, optional): If True, displays a progress bar of the download to - stderr - Default: True + pretrained (bool, optional): If ``True``, returns a model pre-trained on + ImageNet. + Default: ``False`` + progress (bool, optional): If ``True``, displays a progress bar of the download + to stderr. + Default: ``True`` model_path (str, optional): Optional path for InceptionV1 model file. - Default: None - replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained - model with Redirected ReLU in place of ReLU layers. - Default: *True* when pretrained is True otherwise *False* - use_linear_modules_only (bool, optional): If True, return pretrained + Default: ``None`` + replace_relus_with_redirectedrelu (bool, optional): If ``True``, return + pretrained model with :class:`.RedirectedReLU` in place of ReLU layers. + Default: *``True``* when pretrained is True otherwise *``False``* + use_linear_modules_only (bool, optional): If ``True``, return pretrained model with all nonlinear layers replaced with linear equivalents. - Default: False - aux_logits (bool, optional): If True, adds two auxiliary branches that can + Default: ``False`` + aux_logits (bool, optional): If ``True``, adds two auxiliary branches that can improve training. - Default: False + Default: ``False`` out_features (int, optional): Number of output features in the model used for training. - Default: 1008 - transform_input (bool, optional): If True, preprocesses the input according to - the method with which it was trained on ImageNet. - Default: False - bgr_transform (bool, optional): If True and transform_input is True, perform an - RGB to BGR transform in the internal preprocessing. - Default: False + Default: ``1008`` + transform_input (bool, optional): If ``True``, preprocesses the input according + to the method with which it was trained on ImageNet. + Default: ``False`` + bgr_transform (bool, optional): If ``True`` and ``transform_input`` is + ``True``, perform an RGB to BGR transform in the internal + preprocessing. + Default: ``False`` Returns: - **InceptionV1** (InceptionV1): An Inception5h model. + model (InceptionV1): An Inception5h model instance. """ if pretrained: @@ -93,24 +100,25 @@ def __init__( """ Args: - replace_relus_with_redirectedrelu (bool, optional): If True, return - pretrained model with Redirected ReLU in place of ReLU layers. - Default: False - use_linear_modules_only (bool, optional): If True, return pretrained + replace_relus_with_redirectedrelu (bool, optional): If ``True``, return + pretrained model with :class:`.RedirectedReLU` in place of ReLU layers. + Default: ``False`` + use_linear_modules_only (bool, optional): If ``True``, return pretrained model with all nonlinear layers replaced with linear equivalents. - Default: False + Default: ``False`` aux_logits (bool, optional): If True, adds two auxiliary branches that can improve training. - Default: False + Default: ``False`` out_features (int, optional): Number of output features in the model used for training. - Default: 1008 - transform_input (bool, optional): If True, preprocesses the input according - to the method with which it was trained on ImageNet. - Default: False - bgr_transform (bool, optional): If True and transform_input is True, - perform an RGB to BGR transform in the internal preprocessing. - Default: False + Default: ``1008`` + transform_input (bool, optional): If ``True``, preprocesses the input + according to the method with which it was trained on ImageNet. + Default: ``False`` + bgr_transform (bool, optional): If ``True`` and ``transform_input`` is + ``True``, perform an RGB to BGR transform in the internal + preprocessing. + Default: ``False`` """ super().__init__() self.aux_logits = aux_logits @@ -283,20 +291,26 @@ def __init__( """ Args: - in_channels (int, optional): The number of input channels to use for the - inception module. - c1x1 (int, optional): - c3x3reduce (int, optional): - c3x3 (int, optional): - c5x5reduce (int, optional): - c5x5 (int, optional): - pool_proj (int, optional): + in_channels (int): The number of input channels to use for the first + layers of the inception module branches. + c1x1 (int): The number of output channels to use for the first layer in + the c1x1 branch. + c3x3reduce (int): The number of output channels to use for the first layer + in the c3x3 branch. + c3x3 (int): The number of output channels to use for the second layer in + the c3x3 branch. + c5x5reduce (int): The number of output channels to use for the first layer + in the c5x5 branch. + c5x5 (int): The number of output channels to use for the second layer in + the c5x5 branch. + pool_proj (int): The number of output channels to use for the second layer + in the pool branch. activ (type of nn.Module, optional): The nn.Module class type to use for activation layers. - Default: nn.ReLU + Default: :class:`torch.nn.ReLU` p_layer (type of nn.Module, optional): The nn.Module class type to use for pooling layers. - Default: nn.MaxPool2d + Default: :class:`torch.nn.MaxPool2d` """ super().__init__() self.conv_1x1 = nn.Conv2d( @@ -390,13 +404,13 @@ def __init__( in_channels (int, optional): The number of input channels to use for the auxiliary branch. - Default: 508 + Default: ``508`` out_features (int, optional): The number of output features to use for the auxiliary branch. - Default: 1008 + Default: ``1008`` activ (type of nn.Module, optional): The nn.Module class type to use for activation layers. - Default: nn.ReLU + Default: :class:`nn.ReLU` """ super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))