diff --git a/amlslurm/__init__.py b/amlslurm/__init__.py index 4cdbc9d..68cbcef 100644 --- a/amlslurm/__init__.py +++ b/amlslurm/__init__.py @@ -95,14 +95,17 @@ def sbatch(vargs=None): ) from azure.ai.ml import command + from azure.ai.ml.entities import Environment import argparse parser = argparse.ArgumentParser(description='sbatch: submit jobs to Azure Machine Learning') parser.prog = "sbatch" parser.add_argument('-a', '--array', default="None", type=str, help='index for array jobs') + parser.add_argument('--container', default="None", type=str, help='container environment for the job to run in') + parser.add_argument('-e', '--environment', default="None", type=str, help='Azure Machine Learning environment, should be enclosed in quotes, may use @latest') + parser.add_argument('-N', '--nodes', default=1, type=int, help='amount of nodes to use for the job') parser.add_argument('-p', '--partition', type=str, required=True, help='set compute partition where the job should be run. Use to view available partitions') - parser.add_argument('-N', '--nodes', default=1, type=int, help='amount of nodes to use for the job') parser.add_argument('-w', '--wrap', type=str, help='command line to be executed, should be enclosed with quotes') parser.add_argument('script', nargs='?', default="None", type=str, help='script to be executed') args = parser.parse_args(vargs) @@ -115,11 +118,24 @@ def sbatch(vargs=None): print("Conflict: provide either script to execute as argument or commandline to execute through --wrap option") exit(-1) + if (args.container != "None"): + env_docker_image = Environment( + image=args.container, + name="sbatch-container-image", + description="Environment created from a Docker image.", + ) + ml_client.environments.create_or_update(env_docker_image) + args.environment = "sbatch-container-image@latest" + + if (args.environment == "None"): + args.environment = "ubuntu2004-mofed@latest" + + print(args.environment) if (args.script != "None"): command_job = command( code=pwd + "/" + args.script, command=args.script, - environment="docker-test1:4", + environment=args.environment, instance_count=args.nodes, compute=args.partition, ) @@ -127,7 +143,7 @@ def sbatch(vargs=None): if (args.wrap is not None): command_job = command( command=args.wrap, - environment="docker-test1:4", + environment=args.environment, instance_count=args.nodes, compute=args.partition, )