Share Inception v3 network between FID, KID, Inception Score, and MiFID #2099
Comments
Hi! Thanks for your contribution, great first issue!
Hi @gavrieladler, thanks for raising this issue. The current solution I came up with uses a caching system and a specialized subclass of MetricCollection.
Here is some work-in-progress code:

```python
from functools import lru_cache
from typing import Dict, Sequence, Union

from torch.nn import Module
from torchmetrics import Metric, MetricCollection


class NetworkCache(Module):
    """Wrap a network with an LRU cache so repeated calls on the same input are only computed once."""

    def __init__(self, network, max_size=100):
        super().__init__()
        self.max_size = max_size
        self.network = lru_cache(maxsize=self.max_size)(network)

    def forward(self, *args, **kwargs):
        return self.network(*args, **kwargs)


class FeatureShare(MetricCollection):
    """Metric collection in which all metrics share a single (cached) copy of their feature network."""

    def __init__(
        self,
        metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
        *additional_metrics: Metric,
        network_names: Union[str, Sequence[str]],
    ):
        super().__init__(metrics, *additional_metrics)
        if isinstance(network_names, str):
            network_names = [network_names] * len(self)
        elif len(network_names) != len(self):
            raise ValueError("The number of network names should be equal to the number of metrics.")
        # grab the network from the first metric, wrap it in a cache and point every metric at that copy
        shared_net = getattr(getattr(self, list(self.keys())[0]), network_names[0])
        cached_net = NetworkCache(shared_net)
        for (_, metric), network_name in zip(self.items(), network_names):
            setattr(metric, network_name, cached_net)
```

So a very basic analysis:

```python
import time

import torch
from torchmetrics import MetricCollection
from torchmetrics.image import FrechetInceptionDistance, InceptionScore, KernelInceptionDistance

fs = FeatureShare(
    [FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()],
    network_names=["inception", "inception", "inception"],
)
mc = MetricCollection([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])

start = time.time()
for _ in range(10):
    x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
    mc.update(x, real=True)
end = time.time()
print(end - start)

start = time.time()
for _ in range(10):
    x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
    fs.update(x, real=True)
end = time.time()
print(end - start)
```

gives me 5.57 and 1.92 seconds respectively for the plain collection and the feature-sharing version. What do you think about this?
Hi @SkafteNicki,

Thanks for the detailed reply! I think this looks good and is absolutely usable as presented.

One suggestion, and I could go either way: I think there's an argument to be made that `network_names` may be unnecessary for the user to specify. Torchmetrics only supports a finite number of metrics and is able to "know" that the selected metrics should share a network. Only requiring the user to list the metrics, without having to have detailed knowledge of the inner workings of each metric, is a cleaner API; the code could set up the sharing automatically when you create any of the Inception-based metrics. Thoughts?

Like I said, the solution presented is definitely usable as is.
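To make that suggestion concrete, here is a minimal sketch of what an automatic variant could look like, assuming the `FeatureShare` class sketched above plus a hand-maintained mapping from metric class to the attribute that holds its backbone; `AutoFeatureShare` and `_NETWORK_ATTRIBUTE` are hypothetical names, not torchmetrics API:

```python
from typing import Sequence

from torchmetrics import Metric
from torchmetrics.image import FrechetInceptionDistance, InceptionScore, KernelInceptionDistance

# Assumed mapping (maintained by the library) from metric class to the attribute
# name of its feature network; the Inception-based metrics all use `inception`.
_NETWORK_ATTRIBUTE = {
    FrechetInceptionDistance: "inception",
    InceptionScore: "inception",
    KernelInceptionDistance: "inception",
}


class AutoFeatureShare(FeatureShare):
    """Like FeatureShare, but the network attribute is looked up from the metric type."""

    def __init__(self, metrics: Sequence[Metric]):
        # Assumes the collection keeps the metrics in the order they were passed in.
        names = [_NETWORK_ATTRIBUTE[type(m)] for m in metrics]
        super().__init__(metrics, network_names=names)


# usage (hypothetical):
# fs = AutoFeatureShare([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])
```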
🚀 Feature
A single copy of Inception v3 loaded on a GPU and shared between all the metrics which use it.
Motivation
FID, KID, Inception Score, and MiFID all use the same Inception v3 network, which uses 304MB of GPU memory when loaded. If you want to calculate all 4 scores while training, that's 1.2GB GPU memory per GPU. It would be ideal if the metrics could share a single copy of the network to minimize GPU memory usage.
Pitch
Torchmetrics should know when a metric that uses Inception v3 has already been loaded and share that copy between metrics, so that creating a second metric does not load a second copy of the network.
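For illustration only, a minimal sketch of one way this could be implemented internally, assuming a module-level registry keyed by backbone name; `get_shared_backbone` and `_BACKBONE_REGISTRY` are hypothetical, not existing torchmetrics code:

```python
from typing import Callable, Dict

from torch import nn

# Hypothetical module-level registry: at most one copy of each backbone is ever built.
_BACKBONE_REGISTRY: Dict[str, nn.Module] = {}


def get_shared_backbone(name: str, builder: Callable[[], nn.Module]) -> nn.Module:
    """Return the cached backbone for `name`, building it only on the first request."""
    if name not in _BACKBONE_REGISTRY:
        _BACKBONE_REGISTRY[name] = builder()
    return _BACKBONE_REGISTRY[name]


# Illustrative use inside a metric's __init__ (not the actual torchmetrics code):
# self.inception = get_shared_backbone(
#     "inception-v3",
#     lambda: torchvision.models.inception_v3(weights="DEFAULT"),
# )
```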
Alternatives
One alternative is to manually reuse the already-loaded Inception network of one metric in the `__init__` (or attributes) of another metric; so for example, today you could do something along the lines of the sketch below. However I am hesitant to rely on this as it depends on the internal details of metrics which are subject to change release to release, and I would not know from release notes if this was broken or not.
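One plausible version of that hack, sketched here with the caveat that the `inception` attribute is an undocumented internal and exactly the kind of detail that may change between releases:

```python
from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance

fid = FrechetInceptionDistance()
kid = KernelInceptionDistance()

# Point the second metric at the first metric's backbone so only one
# Inception v3 is kept in (GPU) memory. Relies on undocumented internals.
kid.inception = fid.inception
```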
Another alternative would be a combined wrapper whose `update`/`reset`/`compute` would in turn pass each call down to the individual metrics. This could get a bit messy as the `__init__` APIs do not exactly line up; a rough sketch follows.
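For concreteness, a rough sketch of that wrapper-class alternative; the class name and the handling of the `real` flag are my own choices, not an existing torchmetrics API:

```python
import torch
from torchmetrics.image import FrechetInceptionDistance, InceptionScore, KernelInceptionDistance


class CombinedInceptionMetrics:
    """Fan update/compute/reset out to each underlying metric.

    The signature mismatch shows up in update(): FID and KID take a `real` flag,
    while Inception Score only sees generated images, which is what makes a single API messy.
    """

    def __init__(self):
        self.fid = FrechetInceptionDistance()
        self.kid = KernelInceptionDistance()
        self.inception_score = InceptionScore()

    def update(self, imgs: torch.Tensor, real: bool) -> None:
        self.fid.update(imgs, real=real)
        self.kid.update(imgs, real=real)
        if not real:
            self.inception_score.update(imgs)

    def compute(self) -> dict:
        return {
            "fid": self.fid.compute(),
            "kid": self.kid.compute(),
            "inception_score": self.inception_score.compute(),
        }

    def reset(self) -> None:
        for metric in (self.fid, self.kid, self.inception_score):
            metric.reset()
```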
Additional context
It would be doubly awesome if the solution allowed you to share activations as well as the network itself, as currently I'm calculating the same activations separately on each copy of the network.
Thanks so much!