From 2a7da711b60a003df06f2f0a0e50cc83ef5e900d Mon Sep 17 00:00:00 2001 From: "pixeebot[bot]" <104101892+pixeebot[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 21:27:12 +0000 Subject: [PATCH] Use Assignment Expression (Walrus) In Conditional --- sigoptlite/builders.py | 6 ++---- sigoptlite/driver.py | 3 +-- sigoptlite/models.py | 3 +-- sigoptlite/sources.py | 3 +-- sigoptlite/validators.py | 9 +++------ 5 files changed, 8 insertions(+), 16 deletions(-) diff --git a/sigoptlite/builders.py b/sigoptlite/builders.py index 1c5c14a..3668e39 100644 --- a/sigoptlite/builders.py +++ b/sigoptlite/builders.py @@ -254,8 +254,7 @@ def validate_input_dict(cls, input_dict): def create_object(cls, **input_dict): cls.set_object(input_dict, "assignments", LocalAssignments) cls.set_list_of_objects(input_dict, field="values", local_class=MetricEvaluationBuilder) - values = input_dict.get("values") - if values: + if values := input_dict.get("values"): input_dict["metric_evaluations"] = {me.name: me for me in values} input_dict.pop("values", None) cls.set_object(input_dict, "task", LocalTaskBuilder) @@ -273,8 +272,7 @@ class MetricEvaluationBuilder(BuilderBase): def validate_input_dict(cls, input_dict): assert isinstance(input_dict["name"], str) assert isinstance(input_dict["value"], (int, float)) - value_stddev = input_dict.get("value_stddev", None) - if value_stddev is not None: + if (value_stddev := input_dict.get("value_stddev", None)) is not None: assert isinstance(input_dict["value_stddev"], (int, float)) assert input_dict["value_stddev"] >= 0 assert set(input_dict.keys()) == {"name", "value", "value_stddev"} diff --git a/sigoptlite/driver.py b/sigoptlite/driver.py index f144998..406ffcd 100644 --- a/sigoptlite/driver.py +++ b/sigoptlite/driver.py @@ -112,7 +112,6 @@ def path_to_route(self, path, method): def request(self, method, path, data, headers): route = self.path_to_route(path, method) - handler = self.routes.get(route, {}).get(method) - if handler is None: + if (handler := self.routes.get(route, {}).get(method)) is None: raise Exception(f"{PRODUCT_NAME} only supports the following routes: {' '.join(self.routes.keys())}") return handler(data) diff --git a/sigoptlite/models.py b/sigoptlite/models.py index 8cd09ba..14bd8a8 100644 --- a/sigoptlite/models.py +++ b/sigoptlite/models.py @@ -254,8 +254,7 @@ def get_optimized_measurements_for_maximization(self, experiment): return [self.get_value_for_maximization(metric) for metric in experiment.optimized_metrics] def get_value_for_maximization(self, metric): - value = self.get_metric_evaluation_by_name(metric.name).value - if value is None: + if (value := self.get_metric_evaluation_by_name(metric.name).value) is None: raise Exception(f"Metric `{metric.name}` is not in observation data.") if metric.is_minimized: return -value diff --git a/sigoptlite/sources.py b/sigoptlite/sources.py index f4b8829..baeba58 100644 --- a/sigoptlite/sources.py +++ b/sigoptlite/sources.py @@ -291,8 +291,7 @@ def make_assignments_from_point(experiment, point): def get_point_from_assignments(experiment, assignments): point = numpy.empty(experiment.dimension) for i, parameter in enumerate(experiment.parameters): - parameter_value = assignments.get(parameter.name, None) - if parameter_value is None: + if (parameter_value := assignments.get(parameter.name, None)) is None: parameter_value = replacement_value_if_missing(parameter) if parameter.has_log_transformation: parameter_value = numpy.log10(parameter_value) diff --git a/sigoptlite/validators.py b/sigoptlite/validators.py index 6a010bd..4b3b871 100644 --- a/sigoptlite/validators.py +++ b/sigoptlite/validators.py @@ -204,8 +204,7 @@ def validate_experiment(experiment, cls_name): if not experiment.parallel_bandwidth == 1: raise ValueError(f"{cls_name} must have parallel_bandwidth == 1") - observation_budget = experiment.observation_budget - if observation_budget is None: + if (observation_budget := experiment.observation_budget) is None: if experiment.num_solutions > 1: raise ValueError(f"observation_budget is required for a {cls_name} with multiple solutions") if experiment.requires_pareto_frontier_optimization: @@ -250,8 +249,7 @@ def validate_experiment(experiment, cls_name): validate_conditionals_for_experiment(experiment) # Check feature viability of multitask - tasks = experiment.tasks - if tasks: + if tasks := experiment.tasks: if experiment.requires_pareto_frontier_optimization: raise ValueError(f"{cls_name} cannot have both tasks and multiple optimized metrics") if experiment.has_constraint_metrics: @@ -441,8 +439,7 @@ def validate_constraints_for_experiment(experiment): term_types = [] for term in terms: - coeff = term.weight - if coeff == 0: + if (coeff := term.weight) == 0: continue name = term.name if name in integer_params_names: