Adding Jit profile #150
base: master
Conversation
HaoKang-Timmy
commented
Oct 10, 2021
- Adding class JitProfile, which can count MACs and params
- Fix some bugs in the counters

count_Transformer for nn.Transformer:
The answer is 2.76945964418343e+25 44151808.0 (no embedding ** 2).
For one word, about 25G FLOPs and 44M parameters.
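The trace-walking approach behind JitProfile can be sketched as: walk each node of the traced graph, scan the registered `(operators, func)` handler pairs, and accumulate per-op costs. The node tuples, attribute dicts, and `linear_macs` formula below are hypothetical stand-ins for real `torch.jit` trace nodes, not the actual thop API:

```python
# Illustrative sketch of a trace-walking MAC counter.
# Node kinds and attached shapes are hypothetical stand-ins for
# what a torch.jit trace would provide.

def linear_macs(attrs):
    # one multiply-accumulate per weight element of the linear layer
    return attrs["in"] * attrs["out"]

# handler table in the same (operators, func) pair shape as the PR's loop
handlers = [(("aten::linear",), linear_macs)]

# fake traced nodes: (kind, attributes)
nodes = [("aten::linear", {"in": 512, "out": 512}),
         ("aten::linear", {"in": 512, "out": 2048})]

total_macs = 0
for kind, attrs in nodes:
    for operators, func in handlers:   # linear scan, as in the PR
        if kind in operators:
            total_macs += func(attrs)

print(total_macs)  # 512*512 + 512*2048 = 1310720
```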
@@ -12,11 +12,11 @@ def count_parameters(m, x, y):
     total_params = 0
     for p in m.parameters():
         total_params += torch.DoubleTensor([p.numel()])
-    m.total_params[0] = total_params
+    m.total_params[0] = counter_parameters(m.parameters())
Why call counter_parameters() here? total_params already provides the number of parameters of the model.
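The reviewer's point can be illustrated without the extra helper: `total_params` already accumulates `p.numel()` over the module's parameters, so a second counting call is redundant. A minimal stand-alone sketch, using tuples of shapes in place of real torch parameters:

```python
# Minimal sketch: sum parameter counts directly, with no separate
# counter_parameters() helper. Shapes stand in for p.shape of real
# torch parameters; numel is the product of the dimensions.
import math

def count_parameters(param_shapes):
    return sum(math.prod(shape) for shape in param_shapes)

# e.g. a linear layer with weight (10, 5) and bias (10,)
print(count_parameters([(10, 5), (10,)]))  # 60
```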
    results = dict()
    graph = trace(model, args)
    for node in graph.nodes:
        for operators, func in handlers:
This is an O(N) operation and will become slow as the number of handlers grows. I suggest reworking it into a dictionary.
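The suggested rework can be sketched as keying the handler table by op name, turning the inner scan into an O(1) lookup. The `aten::` op names and handler bodies below are illustrative placeholders, not the actual thop handlers:

```python
# Hedged sketch of the dictionary dispatch the reviewer suggests:
# handlers keyed by op kind, so each traced node resolves its counting
# function in O(1) instead of scanning a list of (operators, func) pairs.

handlers = {
    "aten::matmul": lambda node: "count matmul flops",
    "aten::add": lambda node: "count add flops",
}

def dispatch(node_kind, node=None):
    func = handlers.get(node_kind)  # O(1) hash lookup
    if func is None:
        return None                 # unhandled op: skip (or warn)
    return func(node)

print(dispatch("aten::matmul"))  # count matmul flops
```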