Optim-wip: Improve docs for Sphinx #983

Open

wants to merge 30 commits into base: optim-wip

Changes from all commits (30 commits)
3173fd9
Sphinx doc improvements
ProGamerGov Jun 27, 2022
fdba4e2
Add function links to atlas docs
ProGamerGov Jun 29, 2022
76836b9
Improve ToRGB docs
ProGamerGov Jun 29, 2022
bb0984f
Add more doc refs
ProGamerGov Jun 29, 2022
e01478c
Fix lint error
ProGamerGov Jun 29, 2022
ef48587
Improve ToRGB docs
ProGamerGov Jul 1, 2022
3ab53ae
Improve docs for Sphinx (#550)
ProGamerGov Jul 2, 2022
57ea951
More doc improvements
ProGamerGov Jul 5, 2022
3c6c24d
Improve docs (#554)
ProGamerGov Jul 5, 2022
07c9e60
Add missing InceptionV1 InceptionModule docs
ProGamerGov Jul 6, 2022
953780e
Fix TransformationRobustness doc formatting & add missing RedirectedR…
ProGamerGov Jul 6, 2022
5333128
http -> https
ProGamerGov Jul 10, 2022
44af560
Improve ChannelReducer docs
ProGamerGov Jul 11, 2022
acebbd8
Fix tensor type hints
ProGamerGov Jul 15, 2022
61e18e4
Add missing return docs to get_model_layers
ProGamerGov Jul 16, 2022
6dbfc3d
Fix doc parameter type formatting
ProGamerGov Jul 16, 2022
82ca242
Fix duplicated circuits type hint
ProGamerGov Jul 16, 2022
95ed9f9
Remove unused type hints
ProGamerGov Jul 16, 2022
8e77eb7
Doc fix
ProGamerGov Jul 16, 2022
d6f0def
Add function aliases to docs
ProGamerGov Jul 17, 2022
910c38d
Improve docstring type formatting
ProGamerGov Jul 18, 2022
876d737
:class: -> :func:
ProGamerGov Jul 18, 2022
8cbca6d
Fix accidental indent
ProGamerGov Jul 19, 2022
485481d
Improve doc types for ActivationFetcher
ProGamerGov Jul 20, 2022
0fa87de
Add hyperlink ref to circuits argument
ProGamerGov Jul 21, 2022
01c59d2
Improve reducer docs
ProGamerGov Jul 22, 2022
d456347
Fix spelling
ProGamerGov Jul 27, 2022
e704243
Fix lint error
ProGamerGov Jul 27, 2022
330f009
Fix docstring types
ProGamerGov Jul 28, 2022
1f0420b
Update transforms.py
ProGamerGov Aug 4, 2022
2 changes: 1 addition & 1 deletion captum/optim/__init__.py
@@ -1,6 +1,6 @@
"""optim submodule."""

from captum.optim import models
from captum.optim import models # noqa: F401
from captum.optim._core import loss, optimization # noqa: F401
from captum.optim._core.optimization import InputOptimization # noqa: F401
from captum.optim._param.image import images, transforms # noqa: F401
13 changes: 7 additions & 6 deletions captum/optim/_core/output_hook.py
@@ -101,24 +101,25 @@ def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None:
"""
Args:

model (nn.Module): The reference to PyTorch model instance.
targets (nn.Module or list of nn.Module): The target layers to
model (nn.Module): The reference to PyTorch model instance.
targets (nn.Module or list of nn.Module): The target layers to
collect activations from.
"""
super(ActivationFetcher, self).__init__()
super().__init__()
self.model = model
self.layers = ModuleOutputsHook(targets)

def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
"""
Args:

input_t (tensor or tuple of tensors, optional): The input to use
input_t (torch.Tensor or tuple of torch.Tensor, optional): The input to use
with the specified model.

Returns:
activations_dict: An dict containing the collected activations. The keys
for the returned dictionary are the target layers.
activations_dict (ModuleOutputMapping): A dict containing the collected
activations. The keys for the returned dictionary are the target
layers.
"""

try:
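A minimal usage sketch of ActivationFetcher based on the docstring above; the toy model, target layer choice, and input shape are assumptions for illustration only:

    import torch
    import torch.nn as nn
    from captum.optim._core.output_hook import ActivationFetcher

    # Any nn.Module works; a small stand-in model keeps the sketch self-contained.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))

    # Collect activations from the last conv layer.
    fetcher = ActivationFetcher(model, targets=[model[2]])
    activations = fetcher(torch.randn(1, 3, 64, 64))

    # The returned mapping is keyed by the target layer modules.
    print(activations[model[2]].shape)  # expected: torch.Size([1, 16, 60, 60])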
316 changes: 191 additions & 125 deletions captum/optim/_param/image/transforms.py

Large diffs are not rendered by default.

49 changes: 36 additions & 13 deletions captum/optim/_utils/circuits.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional

import torch
import torch.nn as nn
@@ -11,7 +11,7 @@ def extract_expanded_weights(
model: nn.Module,
target1: nn.Module,
target2: nn.Module,
crop_shape: Optional[Union[Tuple[int, int], IntSeqOrIntType]] = None,
crop_shape: Optional[IntSeqOrIntType] = None,
model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224),
crop_func: Optional[Callable] = center_crop,
) -> torch.Tensor:
@@ -20,24 +20,47 @@ def extract_expanded_weights(
literally adjacent in a neural network, or where the weights aren’t directly
represented in a single weight tensor.

Example::

>>> # Load InceptionV1 model with nonlinear layers replaced by
>>> # their linear equivalents
>>> linear_model = opt.models.googlenet(
>>> pretrained=True, use_linear_modules_only=True
>>> ).eval()
>>> # Extract weight interactions between target layers
>>> W_3a_3b = opt.circuits.extract_expanded_weights(
>>> linear_model, linear_model.mixed3a, linear_model.mixed3b, 5
>>> )
>>> # Display results for channel 147 of mixed3a and channel 379 of
>>> # mixed3b, in human readable format
>>> W_3a_3b_hm = opt.weights_to_heatmap_2d(
>>> W_3a_3b[379, 147, ...] / W_3a_3b[379, ...].max()
>>> )
>>> opt.show(W_3a_3b_hm)

Voss, et al., "Visualizing Weights", Distill, 2021.
See: https://distill.pub/2020/circuits/visualizing-weights/

Args:
model (nn.Module): The reference to PyTorch model instance.
target1 (nn.module): The starting target layer. Must be below the layer
specified for target2.
target2 (nn.Module): The end target layer. Must be above the layer
specified for target1.
crop_shape (int or tuple of ints, optional): Specify the exact output size
to crop out.
model_input (tensor or tuple of tensors, optional): The input to use

model (nn.Module): The reference to PyTorch model instance.
target1 (nn.Module): The starting target layer. Must be below the layer
specified for ``target2``.
target2 (nn.Module): The end target layer. Must be above the layer
specified for ``target1``.
crop_shape (int, list of int, or tuple of int, optional): Specify the exact
output size to crop out. Set to ``None`` for no cropping.
Default: ``None``
model_input (torch.Tensor or tuple of torch.Tensor, optional): The input to use
with the specified model.
crop_func (Callable, optional): Specify a function to crop away the padding
Default: ``torch.zeros(1, 3, 224, 224)``
crop_func (Callable, optional): Specify a function to crop away the padding
from the output weights.
Default: :func:`.center_crop`

Returns:
*tensor*: A tensor containing the expanded weights in the form of:
(target2 output channels, target1 output channels, height, width)
tensor (torch.Tensor): A tensor containing the expanded weights in the form
of: (target2 output channels, target1 output channels, height, width)
"""
if isinstance(model_input, torch.Tensor):
model_input = model_input.to(next(model.parameters()).device)
140 changes: 77 additions & 63 deletions captum/optim/_utils/image/atlas.py
@@ -14,20 +14,20 @@ def normalize_grid(

Args:

xy_grid (torch.tensor): The xy coordinate grid tensor to normalize,
xy_grid (torch.Tensor): The xy coordinate grid tensor to normalize,
with a shape of: [n_points, n_axes].
min_percentile (float, optional): The minimum percentile to use when
normalizing the tensor. Value must be in the range [0, 1].
Default: 0.01
Default: ``0.01``
max_percentile (float, optional): The maximum percentile to use when
normalizing the tensor. Value must be in the range [0, 1].
Default: 0.99
Default: ``0.99``
relative_margin (float, optional): The relative margin to use when
normalizing the tensor.
Default: 0.1
Default: ``0.1``

Returns:
normalized_grid (torch.tensor): A normalized xy coordinate grid tensor.
normalized_grid (torch.Tensor): A normalized xy coordinate grid tensor.
"""

assert xy_grid.dim() == 2
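A short sketch of normalize_grid as described in the docstring above; the random point cloud stands in for a real 2D embedding and is an assumption for illustration:

    import torch
    from captum.optim._utils.image.atlas import normalize_grid

    # Raw 2D coordinates for 100 activation samples, e.g. from a UMAP projection.
    xy_grid = torch.randn(100, 2)

    # Clamp outliers to the given percentiles and rescale into [0, 1].
    xy_norm = normalize_grid(xy_grid, min_percentile=0.01, max_percentile=0.99)
    assert xy_norm.shape == (100, 2)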
@@ -56,8 +56,8 @@ def calc_grid_indices(
This function draws a 2D grid across the irregular grid of points, and then groups
point indices based on the grid cell they fall within. The grid cells are then
filled with 1D tensors that have anywhere from 0 to n_indices values in them. The
sets of grid indices can then be used with the compute_avg_cell_samples function
to create atlas grid cell direction vectors.
sets of grid indices can then be used with the :func:`compute_avg_cell_samples`
function to create atlas grid cell direction vectors.

Indices are stored for grid cells in an xy matrix, where the outer lists represent
x positions and the inner lists represent y positions. Each grid cell is filled
@@ -71,23 +71,31 @@

Each cell in the above example would contain a list of indices inside a tensor for
that particular cell, like this:
indices = [
[tensor([0, 5]), tensor([1]), tensor([2, 3])],
[tensor([]), tensor([4]), tensor([])],
[tensor([6, 7, 8]), tensor([]), tensor([])],
]

::

indices = [
[tensor([0, 5]), tensor([1]), tensor([2, 3])],
[tensor([]), tensor([4]), tensor([])],
[tensor([6, 7, 8]), tensor([]), tensor([])],
]

Args:
xy_grid (torch.tensor): The xy coordinate grid activation samples, with a shape

xy_grid (torch.Tensor): The xy coordinate grid activation samples, with a shape
of: [n_points, 2].
grid_size (Tuple[int, int]): The grid_size of grid cells to use. The grid_size
variable should be in the format of: [width, height].
x_extent (Tuple[float, float], optional): The x axis range to use.
Default: (0.0, 1.0)
y_extent (Tuple[float, float], optional): The y axis range to use.
Default: (0.0, 1.0)
grid_size (tuple of int): The number of grid cells to use across the height
and width dimensions. The ``grid_size`` variable should be in the format
of: [width, height].
x_extent (tuple of float, optional): The x axis range to use, in the format
of: (min, max).
Default: ``(0.0, 1.0)``
y_extent (tuple of float, optional): The y axis range to use, in the format
of: (min, max).
Default: ``(0.0, 1.0)``

Returns:
indices (list of list of torch.Tensors): List of lists of grid indices
indices (list of list of torch.Tensor): List of lists of grid indices
stored inside tensors to use. Each 1D tensor of indices has a size of:
0 to n_indices.
"""
@@ -121,33 +129,35 @@ def compute_avg_cell_samples(
"""
Create direction vectors for sets of activation samples, attribution samples, and
grid indices. Grid cells without the minimum number of points as specified by
min_density will be ignored. The calc_grid_indices function can be used to produce
the values required for the grid_indices variable.
``min_density`` will be ignored. The :func:`calc_grid_indices` function can be used
to produce the values required for the ``grid_indices`` variable.

Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/

Args:

grid_indices (list of list of torch.tensor): List of lists of grid indices
grid_indices (list of list of torch.Tensor): List of lists of grid indices
stored inside tensors to use. Each 1D tensor of indices has a size of:
0 to n_indices.
raw_samples (torch.tensor): Raw unmodified activation or attribution samples,
raw_samples (torch.Tensor): Raw unmodified activation or attribution samples,
with a shape of: [n_samples, n_channels].
grid_size (Tuple[int, int]): The grid_size of grid cells to use. The grid_size
variable should be in the format of: [width, height].
grid_size (tuple of int): The number of grid cells to use across the height
and width dimensions. The ``grid_size`` variable should be in the format
of: [width, height].
min_density (int, optional): The minimum number of points for a cell to be
counted.
Default: 8
Default: ``8``

Returns:
cell_vecs (torch.tensor): A tensor containing all the direction vectors that
were created, stacked along the batch dimension with a shape of:
[n_vecs, n_channels].
cell_coords (list of Tuple[int, int, int]): List of coordinates for grid
spatial positions of each direction vector, and the number of samples used
for the cell. The list for each cell is in the format of:
[x_coord, y_coord, number_of_samples_used].
cell_vecs_and_cell_coords: A 2 element tuple of: ``(cell_vecs, cell_coords)``.
- cell_vecs (torch.Tensor): A tensor containing all the direction vectors
that were created, stacked along the batch dimension with a shape of:
[n_vecs, n_channels].
- cell_coords (list of tuple of int): List of coordinates for grid
spatial positions of each direction vector, and the number of samples
used for the cell. The list for each cell is in the format of:
[x_coord, y_coord, number_of_samples_used].
"""
assert raw_samples.dim() == 2
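Building on the grid indices above, a sketch of compute_avg_cell_samples; the 512-channel activations and the min_density value are illustrative assumptions:

    import torch
    from captum.optim._utils.image.atlas import compute_avg_cell_samples

    # Raw activation samples matching the 100 xy points, one 512-d vector each.
    raw_samples = torch.randn(100, 512)

    cell_vecs, cell_coords = compute_avg_cell_samples(
        grid_indices, raw_samples, grid_size=(3, 3), min_density=8
    )
    # cell_vecs: [n_vecs, 512]; cell_coords: list of (x, y, n_samples) tuples.
    print(cell_vecs.shape, cell_coords[:2])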

@@ -174,39 +184,43 @@ def create_atlas_vectors(
) -> Tuple[torch.Tensor, List[Tuple[int, int, int]]]:
"""
Create direction vectors by splitting an irregular grid of activation samples into
cells. Grid cells without the minimum number of points as specified by min_density
will be ignored.
cells. Grid cells without the minimum number of points as specified by
``min_density`` will be ignored.

Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/

Args:

xy_grid (torch.tensor): The xy coordinate grid activation samples, with a shape
xy_grid (torch.Tensor): The xy coordinate grid activation samples, with a shape
of: [n_points, 2].
raw_activations (torch.tensor): Raw unmodified activation samples, with a shape
raw_activations (torch.Tensor): Raw unmodified activation samples, with a shape
of: [n_samples, n_channels].
grid_size (Tuple[int, int]): The size of grid cells to use. The grid_size
variable should be in the format of: [width, height].
grid_size (tuple of int): The number of grid cells to use across the height
and width dimensions. The ``grid_size`` variable should be in the format
of: [width, height].
min_density (int, optional): The minimum number of points for a cell to be
counted.
Default: 8
Default: ``8``
normalize (bool, optional): Whether or not to remove outliers from an xy
coordinate grid tensor, and rescale it to [0, 1].
Default: True
x_extent (Tuple[float, float], optional): The x axis range to use.
Default: (0.0, 1.0)
y_extent (Tuple[float, float], optional): The y axis range to use.
Default: (0.0, 1.0)
Default: ``True``
x_extent (tuple of float, optional): The x axis range to use, in the format
of: (min, max).
Default: ``(0.0, 1.0)``
y_extent (tuple of float, optional): The y axis range to use, in the format
of: (min, max).
Default: ``(0.0, 1.0)``

Returns:
grid_vecs (torch.tensor): A tensor containing all the direction vectors that
were created, stacked along the batch dimension, with a shape of:
[n_vecs, n_channels].
cell_coords (list of Tuple[int, int, int]): List of coordinates for grid
spatial positions of each direction vector, and the number of samples used
for the cell. The list for each cell is in the format of:
[x_coord, y_coord, number_of_samples_used].
grid_vecs_and_cell_coords: A 2 element tuple of: ``(grid_vecs, cell_coords)``.
- grid_vecs (torch.Tensor): A tensor containing all the direction vectors
that were created, stacked along the batch dimension, with a shape
of: [n_vecs, n_channels].
- cell_coords (list of tuple of int): List of coordinates for grid
spatial positions of each direction vector, and the number of samples
used for the cell. The list for each cell is in the format of:
[x_coord, y_coord, number_of_samples_used].
"""

assert xy_grid.dim() == 2 and xy_grid.size(1) == 2
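create_atlas_vectors wraps the normalization, grid indexing, and averaging steps shown above into one call; a hedged sketch, with the grid size and minimum density again chosen only for illustration:

    import torch
    from captum.optim._utils.image.atlas import create_atlas_vectors

    xy_grid = torch.randn(100, 2)            # raw, un-normalized sample coordinates
    raw_activations = torch.randn(100, 512)  # matching activation samples

    grid_vecs, cell_coords = create_atlas_vectors(
        xy_grid, raw_activations, grid_size=(4, 4), min_density=2, normalize=True
    )
    # grid_vecs: [n_vecs, 512]; cell_coords: list of (x, y, n_samples) tuples.
    print(grid_vecs.shape, len(cell_coords))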
@@ -235,19 +249,19 @@ def create_atlas(

Args:

cells (list of torch.tensor or torch.tensor): A list or stack of NCHW image
cells (list of torch.Tensor or torch.Tensor): A list or stack of NCHW image
tensors made with atlas direction vectors.
coords (list of Tuple[int, int] or list of Tuple[int, int, int]): A list of
coordinates to use for the atlas image tensors. The first 2 values in each
coordinate list should be: [x, y, ...].
grid_size (Tuple[int, int]): The size of grid cells to use. The grid_size
variable should be in the format of: [width, height].
coords (list of tuple of int): A list of coordinates to use for the atlas image
tensors. The first 2 values in each coordinate list should be: [x, y, ...].
grid_size (tuple of int): The number of grid cells to use across the height
and width dimensions. The ``grid_size`` variable should be in the format
of: [width, height].
base_tensor (Callable, optional): What to use for the atlas base tensor. Basic
choices are: torch.ones or torch.zeros.
Default: torch.ones
choices are: :func:`torch.ones` or :func:`torch.zeros`.
Default: :func:`torch.ones`

Returns:
atlas_canvas (torch.tensor): The full activation atlas visualization, with a
atlas_canvas (torch.Tensor): The full activation atlas visualization, with a
shape of NCHW.
"""

Expand All @@ -262,7 +276,7 @@ def create_atlas(

# cell_b -> number of images
# cell_c -> image channel
# cell_h -> image hight
# cell_h -> image height
# cell_w -> image width
cell_b, cell_c, cell_h, cell_w = cells[0].shape
atlas_canvas = base_tensor(
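To finish the sketch, create_atlas lays the rendered cells out on a single canvas; the 32 x 32 cell images below are placeholders standing in for visualizations that would normally be rendered from grid_vecs:

    import torch
    from captum.optim._utils.image.atlas import create_atlas

    # One NCHW image tensor per surviving grid cell, in the same order as
    # cell_coords from the create_atlas_vectors sketch above.
    cells = [torch.rand(1, 3, 32, 32) for _ in cell_coords]

    atlas_canvas = create_atlas(cells, cell_coords, grid_size=(4, 4))
    print(atlas_canvas.shape)  # expected NCHW canvas, e.g. torch.Size([1, 3, 128, 128])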