From 7edf0e479600279c4485562d880a947be7f01ea7 Mon Sep 17 00:00:00 2001 From: Lukas Garbas Date: Sat, 30 Nov 2024 05:38:29 +0100 Subject: [PATCH] Fix default device setting --- transformer_ranker/ranker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_ranker/ranker.py b/transformer_ranker/ranker.py index 57d1362..37503a6 100644 --- a/transformer_ranker/ranker.py +++ b/transformer_ranker/ranker.py @@ -72,6 +72,8 @@ def run( """ self._confirm_ranker_setup(estimator=estimator, layer_aggregator=layer_aggregator) + device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) + # Load all transformers into hf cache self._preload_transformers(models, device)