Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new argument num_classes to eval scripts #21

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ You can generate videos like the one on the blog post with `video_generation.py`
https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4

Extract frames from input video and generate attention video:

```
python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
--input_path input/video.mp4 \
Expand All @@ -246,7 +247,6 @@ python video_generation.py --input_path output/attention \
--video_format avi
```


## Evaluation: k-NN classification on ImageNet
To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:
```
Expand Down
3 changes: 2 additions & 1 deletion eval_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def __getitem__(self, idx):
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
parser.add_argument("--num_classes", default=1000, type=int, help="Num classes")
args = parser.parse_args()

utils.init_distributed_mode(args)
Expand All @@ -237,6 +238,6 @@ def __getitem__(self, idx):
print("Features are ready!\nStart the k-NN classification.")
for k in args.nb_knn:
top1, top5 = knn_classifier(train_features, train_labels,
test_features, test_labels, k, args.temperature)
test_features, test_labels, k, args.temperature, args.num_classes)
print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
dist.barrier()
4 changes: 4 additions & 0 deletions video_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import glob
import sys
Expand Down Expand Up @@ -258,8 +259,10 @@ def __load_model(self):
)
state_dict = state_dict[self.args.checkpoint_key]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}

msg = model.load_state_dict(state_dict, strict=False)
print(
"Pretrained weights found at {} and loaded with msg: {}".format(
Expand All @@ -271,6 +274,7 @@ def __load_model(self):
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
)
url = None

if self.args.arch == "vit_small" and self.args.patch_size == 16:
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
elif self.args.arch == "vit_small" and self.args.patch_size == 8:
Expand Down