diff --git a/crystal_diffusion/sample_diffusion.py b/crystal_diffusion/sample_diffusion.py index 81f09e8c..ea598c63 100644 --- a/crystal_diffusion/sample_diffusion.py +++ b/crystal_diffusion/sample_diffusion.py @@ -47,7 +47,7 @@ def main(args: Optional[Any] = None): "--output", required=True, help="path to outputs - will store files here" ) parser.add_argument( - "--device", default="gpu", help="Device to use. Defaults to cuda." + "--device", default="cuda", help="Device to use. Defaults to cuda." ) args = parser.parse_args(args) if os.path.exists(args.output):