Skip to content

Commit

Permalink
Improved parameter counting and model summary utilities with detailed…
Browse files Browse the repository at this point in the history
… breakdown
  • Loading branch information
shishir13 committed Dec 1, 2024
1 parent 301486b commit a077cd0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
5 changes: 5 additions & 0 deletions MNIST_99.4/check_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from models.model import MNISTModel
from utils import print_model_summary

model = MNISTModel()
print_model_summary(model)
21 changes: 16 additions & 5 deletions MNIST_99.4/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ def count_parameters(model):
"""
Counts the total number of trainable parameters in the model.
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = 0
for name, parameter in model.named_parameters():
if parameter.requires_grad:
param_count = parameter.numel()
print(f"{name}: {param_count:,} parameters")
total_params += param_count
print(f"\nTotal trainable parameters: {total_params:,}")
return total_params

def has_batch_norm(model):
"""
Expand Down Expand Up @@ -38,12 +45,16 @@ def print_model_summary(model):
Prints a summary of the model including parameter count
and layer types used.
"""
print("Model Summary:")
print("-" * 40)
param_count = count_parameters(model)
bn = has_batch_norm(model)
dropout = has_dropout(model)
fc = has_fully_connected(model)

print(f"Total Parameters: {param_count}")
print(f"Batch Normalization Used: {'Yes' if bn else 'No'}")
print(f"Dropout Used: {'Yes' if dropout else 'No'}")
print(f"Fully Connected Layers Used: {'Yes' if fc else 'No'}")
print(f"\nArchitecture Requirements:")
print(f"- Total Parameters: {param_count:,} {'[PASS]' if param_count <= 20000 else '[FAIL]'}")
print(f"- Batch Normalization: {'Yes [PASS]' if bn else 'No [FAIL]'}")
print(f"- Dropout: {'Yes [PASS]' if dropout else 'No [FAIL]'}")
print(f"- Fully Connected Layers: {'Yes [PASS]' if fc else 'No [FAIL]'}")
print("-" * 40)

0 comments on commit a077cd0

Please sign in to comment.