Improved documentation and added unit tests for mlayer module (#19)
* Improved documentation and added unit tests for mlayer module

* fix gitignore
sacadena authored Dec 19, 2023
1 parent ace8724 commit 7f8610c
Showing 6 changed files with 674 additions and 519 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -135,4 +135,7 @@ venv.bak/
 dmypy.json
 
 # Pyre type checker
-.pyre/
+.pyre/
+
+# Notebooks
+notebooks/
76 changes: 76 additions & 0 deletions README.md
@@ -7,3 +7,79 @@
[![License](https://img.shields.io/github/license/sacadena/ptrnets)](https://img.shields.io/github/license/sacadena/ptrnets)

Collection of pretrained networks in PyTorch, readily available for transfer-learning tasks such as neural system identification.

## Installation

```bash
pip install ptrnets
```

## Usage
Find a list of all available models like this:

```python
from ptrnets import AVAILABLE_MODELS

print(AVAILABLE_MODELS)
```

Import a model like this:

```python
from ptrnets import simclr_resnet50x2

model = simclr_resnet50x2(pretrained=True)
```
You can access intermediate representations in two ways:

### Probing the model
You can conveniently access intermediate representations of a forward pass using the `ptrnets.utils.mlayer.probe_model` function. Example:
```python
import torch
from ptrnets import resnet50
from ptrnets.utils.mlayer import probe_model

model = resnet50(pretrained=True)
available_layers = [name for name, _ in model.named_modules()]
layer_name = "layer2.1"
assert layer_name in available_layers, f"Layer {layer_name} not available. Choose from {available_layers}"

model_probe = probe_model(model, layer_name)  # forward pass now returns the activations at layer_name

x = torch.rand(1, 3, 224, 224)
output = model_probe(x)
```

**Note**: if the input is not large enough to complete a full forward pass through the network, layers after the probed one may fail; you might need to wrap the call in a `try`/`except` block to catch the resulting `RuntimeError`.
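A minimal sketch of that guard, reusing `model_probe` from the example above (the input sizes here are only illustrative assumptions; pick whatever fits your architecture):

```python
x_small = torch.rand(1, 3, 64, 64)  # may be too small for deeper layers

try:
    output = model_probe(x_small)
except RuntimeError as err:
    # A layer downstream of the probe could not handle the small input;
    # retry with a larger, known-good resolution.
    print(f"Forward pass failed: {err}")
    output = model_probe(torch.rand(1, 3, 224, 224))
```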

### Clipping the model

`ptrnets.utils.mlayer.clip_model` creates a copy of the model up to a specific layer. Because the model is smaller, a forward pass can run faster.
However, the output is only guaranteed to be the same as the original model's if the architecture is fully sequential up until that layer.

Example:
```python
import torch
from ptrnets import vgg16
from ptrnets.utils.mlayer import clip_model, probe_model

model = vgg16(pretrained=True)
available_layers = [name for name, _ in model.named_modules()]
layer_name = "features.18"
assert layer_name in available_layers, f"Layer {layer_name} not available. Choose from {available_layers}"

model_clipped = clip_model(model, layer_name) # Creates new model up to the layer

x = torch.rand(1, 3, 224, 224)
output = model_clipped(x)

assert torch.allclose(output, probe_model(model, layer_name)(x)), "Output of clipped model is not the same as the original model"
```

## Contributing
Pull requests are welcome.
Please see instructions [here](https://github.com/sacadena/ptrnets/blob/master/CONTRIBUTING.rst).



