diff --git a/src/lmql/models/lmtp/utils.py b/src/lmql/models/lmtp/utils.py index 9a30a4ca..7aa9597b 100644 --- a/src/lmql/models/lmtp/utils.py +++ b/src/lmql/models/lmtp/utils.py @@ -12,6 +12,9 @@ def rename_model_args(model_args): # parse cuda if cuda: - model_args["device_map"] = "auto" + if "device_map" in model_args: + print("Warning: device_map is set, but cuda is True. Ignoring 'cuda' which would set device_map to 'auto'.") + else: + model_args["device_map"] = "auto" return model_args