fix: fixing some type hints for optional params (#782)
krneta authored Nov 3, 2023
1 parent c953d2f commit dc2ac56
Showing 24 changed files with 332 additions and 309 deletions.
8 changes: 4 additions & 4 deletions src/braket/annealing/problem.py
@@ -37,16 +37,16 @@ class Problem:
    def __init__(
        self,
        problem_type: ProblemType,
-        linear: Dict[int, float] = None,
-        quadratic: Dict[Tuple[int, int], float] = None,
+        linear: Dict[int, float] | None = None,
+        quadratic: Dict[Tuple[int, int], float] | None = None,
    ):
        """
        Args:
            problem_type (ProblemType): The type of annealing problem
-            linear (Dict[int, float]): The linear terms of this problem,
+            linear (Dict[int, float] | None): The linear terms of this problem,
                as a map of variable to coefficient
-            quadratic (Dict[Tuple[int, int], float]): The quadratic terms of this problem,
+            quadratic (Dict[Tuple[int, int], float] | None): The quadratic terms of this problem,
                as a map of variables to coefficient
        Examples:
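For context on the pattern this commit fixes throughout: a parameter whose default is `None` must include `None` in its annotation, or strict type checkers (for example, mypy with `no_implicit_optional`, its modern default) reject the default value. A minimal standalone sketch, not Braket code:

from __future__ import annotations  # lets `X | None` annotations work before Python 3.10

from typing import Dict, Optional


def before(linear: Dict[int, float] = None):  # flagged: None is not a Dict[int, float]
    ...


def after(linear: Dict[int, float] | None = None):  # accepted: None is part of the type
    ...


def equivalent(linear: Optional[Dict[int, float]] = None):  # Optional[X] means X | None
    ...

Both spellings appear in this commit: `| None` in most files, and `Optional[int]` in `describe_log_streams` further below.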
88 changes: 45 additions & 43 deletions src/braket/aws/aws_quantum_job.py
@@ -64,22 +64,22 @@ def create(
        cls,
        device: str,
        source_module: str,
-        entry_point: str = None,
-        image_uri: str = None,
-        job_name: str = None,
-        code_location: str = None,
-        role_arn: str = None,
+        entry_point: str | None = None,
+        image_uri: str | None = None,
+        job_name: str | None = None,
+        code_location: str | None = None,
+        role_arn: str | None = None,
        wait_until_complete: bool = False,
-        hyperparameters: dict[str, Any] = None,
-        input_data: str | dict | S3DataSourceConfig = None,
-        instance_config: InstanceConfig = None,
-        distribution: str = None,
-        stopping_condition: StoppingCondition = None,
-        output_data_config: OutputDataConfig = None,
-        copy_checkpoints_from_job: str = None,
-        checkpoint_config: CheckpointConfig = None,
-        aws_session: AwsSession = None,
-        tags: dict[str, str] = None,
+        hyperparameters: dict[str, Any] | None = None,
+        input_data: str | dict | S3DataSourceConfig | None = None,
+        instance_config: InstanceConfig | None = None,
+        distribution: str | None = None,
+        stopping_condition: StoppingCondition | None = None,
+        output_data_config: OutputDataConfig | None = None,
+        copy_checkpoints_from_job: str | None = None,
+        checkpoint_config: CheckpointConfig | None = None,
+        aws_session: AwsSession | None = None,
+        tags: dict[str, str] | None = None,
        logger: Logger = getLogger(__name__),
    ) -> AwsQuantumJob:
        """Creates a hybrid job by invoking the Braket CreateJob API.
@@ -96,77 +96,79 @@ def create(
                tarred and uploaded. If `source_module` is an S3 URI, it must point to a
                tar.gz file. Otherwise, source_module may be a file or directory.
-            entry_point (str): A str that specifies the entry point of the hybrid job, relative to
-                the source module. The entry point must be in the format
+            entry_point (str | None): A str that specifies the entry point of the hybrid job,
+                relative to the source module. The entry point must be in the format
                `importable.module` or `importable.module:callable`. For example,
                `source_module.submodule:start_here` indicates the `start_here` function
                contained in `source_module.submodule`. If source_module is an S3 URI,
                entry point must be given. Default: source_module's name
-            image_uri (str): A str that specifies the ECR image to use for executing the hybrid job.
-                `image_uris.retrieve_image()` function may be used for retrieving the ECR image URIs
-                for the containers supported by Braket. Default = `<Braket base image_uri>`.
+            image_uri (str | None): A str that specifies the ECR image to use for executing the
+                hybrid job. `image_uris.retrieve_image()` function may be used for retrieving the
+                ECR image URIs for the containers supported by Braket.
+                Default = `<Braket base image_uri>`.
-            job_name (str): A str that specifies the name with which the hybrid job is created.
-                Allowed pattern for hybrid job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`
+            job_name (str | None): A str that specifies the name with which the hybrid job is
+                created. Allowed pattern for hybrid job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`
                Default: f'{image_uri_type}-{timestamp}'.
-            code_location (str): The S3 prefix URI where custom code will be uploaded.
+            code_location (str | None): The S3 prefix URI where custom code will be uploaded.
                Default: f's3://{default_bucket_name}/jobs/{job_name}/script'.
-            role_arn (str): A str providing the IAM role ARN used to execute the
+            role_arn (str | None): A str providing the IAM role ARN used to execute the
                script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`.
            wait_until_complete (bool): `True` if we should wait until the hybrid job completes.
                This would tail the hybrid job logs as it waits. Otherwise `False`.
                Default: `False`.
-            hyperparameters (dict[str, Any]): Hyperparameters accessible to the hybrid job.
+            hyperparameters (dict[str, Any] | None): Hyperparameters accessible to the hybrid job.
                The hyperparameters are made accessible as a dict[str, str] to the hybrid job.
                For convenience, this accepts other types for keys and values, but `str()`
                is called to convert them before being passed on. Default: None.
-            input_data (str | dict | S3DataSourceConfig): Information about the training
+            input_data (str | dict | S3DataSourceConfig | None): Information about the training
                data. Dictionary maps channel names to local paths or S3 URIs. Contents found
                at any local paths will be uploaded to S3 at
                f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. If a local
                path, S3 URI, or S3DataSourceConfig is provided, it will be given a default
                channel name "input".
                Default: {}.
-            instance_config (InstanceConfig): Configuration of the instance(s) for running the
-                classical code for the hybrid job. Default:
+            instance_config (InstanceConfig | None): Configuration of the instance(s) for running
+                the classical code for the hybrid job. Default:
                `InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.
-            distribution (str): A str that specifies how the hybrid job should be distributed.
-                If set to "data_parallel", the hyperparameters for the hybrid job will be set
-                to use data parallelism features for PyTorch or TensorFlow. Default: None.
+            distribution (str | None): A str that specifies how the hybrid job should be
+                distributed. If set to "data_parallel", the hyperparameters for the hybrid job will
+                be set to use data parallelism features for PyTorch or TensorFlow. Default: None.
-            stopping_condition (StoppingCondition): The maximum length of time, in seconds,
+            stopping_condition (StoppingCondition | None): The maximum length of time, in seconds,
                and the maximum number of quantum tasks that a hybrid job can run before being
                forcefully stopped.
                Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
-            output_data_config (OutputDataConfig): Specifies the location for the output of the
-                hybrid job.
+            output_data_config (OutputDataConfig | None): Specifies the location for the output of
+                the hybrid job.
                Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
                kmsKeyId=None).
-            copy_checkpoints_from_job (str): A str that specifies the hybrid job ARN whose
+            copy_checkpoints_from_job (str | None): A str that specifies the hybrid job ARN whose
                checkpoint you want to use in the current hybrid job. Specifying this value will
                copy over the checkpoint data from `use_checkpoints_from_job`'s checkpoint_config
                s3Uri to the current hybrid job's checkpoint_config s3Uri, making it available at
                checkpoint_config.localPath during the hybrid job execution. Default: None
-            checkpoint_config (CheckpointConfig): Configuration that specifies the location where
-                checkpoint data is stored.
+            checkpoint_config (CheckpointConfig | None): Configuration that specifies the location
+                where checkpoint data is stored.
                Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
                s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').
-            aws_session (AwsSession): AwsSession for connecting to AWS Services.
+            aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
                Default: AwsSession()
-            tags (dict[str, str]): Dict specifying the key-value pairs for tagging this hybrid job.
+            tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this
+                hybrid job.
                Default: {}.
            logger (Logger): Logger object with which to write logs, such as quantum task statuses
@@ -210,11 +212,11 @@ def create(

        return job

-    def __init__(self, arn: str, aws_session: AwsSession = None):
+    def __init__(self, arn: str, aws_session: AwsSession | None = None):
        """
        Args:
            arn (str): The ARN of the hybrid job.
-            aws_session (AwsSession): The `AwsSession` for connecting to AWS services.
+            aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the hybrid job.
        """
@@ -486,7 +488,7 @@ def _read_and_deserialize_results(temp_dir: str, job_name: str) -> dict[str, Any

    def download_result(
        self,
-        extract_to: str = None,
+        extract_to: str | None = None,
        poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL,
    ) -> None:
@@ -495,7 +497,7 @@ def download_result(
        the results are extracted to the current directory.
        Args:
-            extract_to (str): The directory to which the results are extracted. The results
+            extract_to (str | None): The directory to which the results are extracted. The results
                are extracted to a folder titled with the hybrid job name within this directory.
                Default= `Current working directory`.
            poll_timeout_seconds (float): The polling timeout, in seconds, for `download_result()`.
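As a usage-level illustration of the updated `create` and `download_result` signatures above: every parameter retyped as `X | None` can simply be omitted. A sketch assuming valid AWS credentials; the device ARN, module path, and entry point are illustrative placeholders:

from braket.aws import AwsQuantumJob

# Only the required arguments plus entry_point are supplied; all other
# optional parameters fall back to the documented defaults.
job = AwsQuantumJob.create(
    device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",  # illustrative ARN
    source_module="algorithm_script.py",        # hypothetical local script
    entry_point="algorithm_script:start_here",  # hypothetical entry point
    wait_until_complete=True,                   # tail logs until the job finishes
)
job.download_result()  # extract_to=None extracts into the current working directory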
22 changes: 11 additions & 11 deletions src/braket/aws/aws_quantum_task.py
@@ -100,11 +100,11 @@ def create(
        ],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: int,
-        device_parameters: dict[str, Any] = None,
+        device_parameters: dict[str, Any] | None = None,
        disable_qubit_rewiring: bool = False,
-        tags: dict[str, str] = None,
-        inputs: dict[str, float] = None,
-        gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] = None,
+        tags: dict[str, str] | None = None,
+        inputs: dict[str, float] | None = None,
+        gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
@@ -129,23 +129,23 @@ def create(
                `shots=0` is only available on simulators and means that the simulator
                will compute the exact results based on the quantum task specification.
-            device_parameters (dict[str, Any]): Additional parameters to send to the device.
+            device_parameters (dict[str, Any] | None): Additional parameters to send to the device.
            disable_qubit_rewiring (bool): Whether to run the circuit with the exact qubits chosen,
                without any rewiring downstream, if this is supported by the device.
                Only applies to digital, gate-based circuits (as opposed to annealing problems).
                If ``True``, no qubit rewiring is allowed; if ``False``, qubit rewiring is allowed.
                Default: False
-            tags (dict[str, str]): Tags, which are Key-Value pairs to add to this quantum task.
-                An example would be:
+            tags (dict[str, str] | None): Tags, which are Key-Value pairs to add to this quantum
+                task. An example would be:
                `{"state": "washington"}`
-            inputs (dict[str, float]): Inputs to be passed along with the
+            inputs (dict[str, float] | None): Inputs to be passed along with the
                IR. If the IR supports inputs, the inputs will be updated with this value.
                Default: {}.
-            gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]]):
+            gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None):
                A `Dict` for user defined gate calibration. The calibration is defined for
                for a particular `Gate` on a particular `QubitSet` and is represented by
                a `PulseSequence`.
@@ -203,15 +203,15 @@ def create(
    def __init__(
        self,
        arn: str,
-        aws_session: AwsSession = None,
+        aws_session: AwsSession | None = None,
        poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the quantum task.
-            aws_session (AwsSession): The `AwsSession` for connecting to AWS services.
+            aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the quantum task.
            poll_timeout_seconds (float): The polling timeout for `result()`. Default: 5 days.
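The `__init__` change above covers the common pattern of rehydrating a task from its ARN. A small sketch; the ARN, region, and account ID are made up:

from braket.aws import AwsQuantumTask

# aws_session is omitted (None), so per the docstring above a session is
# created in the quantum task's own region.
task = AwsQuantumTask(arn="arn:aws:braket:us-west-2:123456789012:quantum-task/example-task-id")
print(task.state())  # current status, e.g. "COMPLETED"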
4 changes: 2 additions & 2 deletions src/braket/aws/aws_quantum_task_batch.py
@@ -60,7 +60,7 @@ def __init__(
        max_workers: int = MAX_CONNECTIONS_DEFAULT,
        poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
-        inputs: Union[dict[str, float], list[dict[str, float]]] = None,
+        inputs: Union[dict[str, float], list[dict[str, float]]] | None = None,
        *aws_quantum_task_args,
        **aws_quantum_task_kwargs,
    ):
@@ -88,7 +88,7 @@ def __init__(
                in seconds. Default: 5 days.
            poll_interval_seconds (float): The polling interval for results in seconds.
                Default: 1 second.
-            inputs (Union[dict[str, float], list[dict[str, float]]]): Inputs to be passed
+            inputs (Union[dict[str, float], list[dict[str, float]]] | None): Inputs to be passed
                along with the IR. If the IR supports inputs, the inputs will be updated
                with this value. Default: {}.
        """
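The `inputs` union above accepts either one dict shared by every task in the batch or one dict per task. A hypothetical helper, not the SDK's actual implementation, showing how such a `Union[...] | None` parameter is typically normalized:

from __future__ import annotations

from typing import Union


def normalize_inputs(
    inputs: Union[dict[str, float], list[dict[str, float]]] | None,
    task_count: int,
) -> list[dict[str, float]]:
    """Fan a single inputs dict out to every task; validate per-task lists."""
    if inputs is None:
        inputs = {}  # the documented default
    if isinstance(inputs, dict):
        # One dict: give each task its own copy.
        return [dict(inputs) for _ in range(task_count)]
    if len(inputs) != task_count:
        raise ValueError(f"expected {task_count} input dicts, got {len(inputs)}")
    return list(inputs)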
20 changes: 10 additions & 10 deletions src/braket/aws/aws_session.py
@@ -40,17 +40,17 @@ class AwsSession(object):

    def __init__(
        self,
-        boto_session: boto3.Session = None,
-        braket_client: client = None,
-        config: Config = None,
-        default_bucket: str = None,
+        boto_session: boto3.Session | None = None,
+        braket_client: client | None = None,
+        config: Config | None = None,
+        default_bucket: str | None = None,
    ):
        """
        Args:
-            boto_session (Session): A boto3 session object.
-            braket_client (client): A boto3 Braket client.
-            config (Config): A botocore Config object.
-            default_bucket (str): The name of the default bucket of the AWS Session.
+            boto_session (Session | None): A boto3 session object.
+            braket_client (client | None): A boto3 Braket client.
+            config (Config | None): A botocore Config object.
+            default_bucket (str | None): The name of the default bucket of the AWS Session.
        """
        if (
            boto_session
@@ -716,7 +716,7 @@ def describe_log_streams(
        self,
        log_group: str,
        log_stream_prefix: str,
-        limit: int = None,
+        limit: Optional[int] = None,
        next_token: Optional[str] = None,
    ) -> dict[str, Any]:
        """
@@ -725,7 +725,7 @@ def describe_log_streams(
        Args:
            log_group (str): Name of the log group.
            log_stream_prefix (str): Prefix for log streams to include.
-            limit (int): Limit for number of log streams returned.
+            limit (Optional[int]): Limit for number of log streams returned.
                default is 50.
            next_token (Optional[str]): The token for the next set of items to return.
                Would have been received in a previous call.
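Since every `AwsSession` constructor argument above is now explicitly optional, callers can supply any subset. A sketch with an illustrative region and retry policy; AWS credentials are assumed to be configured:

import boto3
from botocore.config import Config

from braket.aws import AwsSession

# Custom boto3 session and botocore Config; braket_client and default_bucket
# are left as None and resolved by AwsSession itself.
session = AwsSession(
    boto_session=boto3.Session(region_name="us-east-1"),
    config=Config(retries={"max_attempts": 10, "mode": "adaptive"}),
)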