From eb4b756f9f9772c1d60fa4263d6b36c4e08cb38c Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Fri, 14 Jul 2023 17:36:09 +0200 Subject: [PATCH] fix --cuda with --device_map --- src/lmql/models/lmtp/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lmql/models/lmtp/utils.py b/src/lmql/models/lmtp/utils.py index 9a30a4ca..7aa9597b 100644 --- a/src/lmql/models/lmtp/utils.py +++ b/src/lmql/models/lmtp/utils.py @@ -12,6 +12,9 @@ def rename_model_args(model_args): # parse cuda if cuda: - model_args["device_map"] = "auto" + if "device_map" in model_args: + print("Warning: device_map is set, but cuda is True. Ignoring 'cuda' which would set device_map to 'auto'.") + else: + model_args["device_map"] = "auto" return model_args