diff --git a/src/codemodder/cli.py b/src/codemodder/cli.py index c434734be..e8a97e17f 100644 --- a/src/codemodder/cli.py +++ b/src/codemodder/cli.py @@ -89,11 +89,10 @@ class ValidatedCodmods(CsvListAction): def validate_items(self, items): potential_names = ids + names - unrecognized_codemods = [ - name for name in items if name not in potential_names - ] - if unrecognized_codemods: + if unrecognized_codemods := [ + name for name in items if name not in potential_names + ]: args = { "values": unrecognized_codemods, "choices": ", ".join(map(repr, names)), diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index 2bbcbd270..8c94a3b7c 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -119,10 +119,8 @@ def __init__(self, context: CodemodContext) -> None: MetadataDependent.__init__(self) MatcherDecoratableTransformer.__init__(self) self.context = context - dependencies = self.get_inherited_dependencies() - if dependencies: - wrapper = self.context.wrapper - if wrapper is None: + if dependencies := self.get_inherited_dependencies(): + if (wrapper := self.context.wrapper) is None: raise ValueError( f"Attempting to instantiate {self.__class__.__name__} outside of " + "an active transform. This means that metadata hasn't been " diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 8f5cee5f3..28225e64b 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -59,8 +59,7 @@ def find_base_name(self, node) -> Optional[str]: return self._find_imported_name(node) case cst.Attribute(): - maybe_name = self.find_base_name(node.value) - if maybe_name: + if maybe_name := self.find_base_name(node.value): return maybe_name + "." + node.attr.value case cst.Call(): @@ -182,8 +181,7 @@ def find_used_names_in_module(self): Find all the used names in the scope of a libcst Module. """ names = [] - scope = self.find_global_scope() - if scope is None: + if (scope := self.find_global_scope()) is None: return [] # pragma: no cover nodes = [x.node for x in scope.assignments] @@ -276,8 +274,7 @@ def is_builtin_function(self, node: cst.Call): return False def find_accesses(self, node) -> Collection[Access]: - scope = self.get_metadata(ScopeProvider, node, None) - if scope: + if scope := self.get_metadata(ScopeProvider, node, None): return scope.accesses[node] return {} @@ -461,8 +458,7 @@ def resolve_expression(self, node: cst.BaseExpression) -> cst.BaseExpression: maybe_expr = None match node: case cst.Name(): - maybe_expr = self._resolve_name_transitive(node) - if maybe_expr: + if maybe_expr := self._resolve_name_transitive(node): return maybe_expr return node diff --git a/src/codemodder/dependency_management/base_dependency_writer.py b/src/codemodder/dependency_management/base_dependency_writer.py index 76ea60a28..060835d86 100644 --- a/src/codemodder/dependency_management/base_dependency_writer.py +++ b/src/codemodder/dependency_management/base_dependency_writer.py @@ -25,8 +25,7 @@ def add_to_file( def write( self, dependencies: list[Dependency], dry_run: bool = False ) -> Optional[ChangeSet]: - new_dependencies = self.add(dependencies) - if new_dependencies: + if new_dependencies := self.add(dependencies): return self.add_to_file(new_dependencies, dry_run) return None diff --git a/src/codemodder/dependency_management/setup_py_writer.py b/src/codemodder/dependency_management/setup_py_writer.py index fcba52880..87d4ada37 100644 --- a/src/codemodder/dependency_management/setup_py_writer.py +++ b/src/codemodder/dependency_management/setup_py_writer.py @@ -83,8 +83,7 @@ def visit_Module(self, _: cst.Module) -> bool: return is_setup_py_file(self.filename) def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): - true_name = self.find_base_name(original_node.func) - if true_name != "setuptools.setup": + if (true_name := self.find_base_name(original_node.func)) != "setuptools.setup": return original_node new_args = self.replace_arg(original_node) diff --git a/src/codemodder/dependency_management/setupcfg_writer.py b/src/codemodder/dependency_management/setupcfg_writer.py index 0728d4c51..0a594258d 100644 --- a/src/codemodder/dependency_management/setupcfg_writer.py +++ b/src/codemodder/dependency_management/setupcfg_writer.py @@ -12,8 +12,7 @@ def find_leading_whitespace(s): - match = re.match(r"(\s+)", s) - if match: + if match := re.match(r"(\s+)", s): return match.group(1) return "" # pragma: no cover @@ -81,8 +80,7 @@ def build_new_lines( """ clean_lines = [s.strip() for s in original_lines] - newline_separated = len(defined_dependencies.split("\n")) > 1 - if newline_separated: + if newline_separated := len(defined_dependencies.split("\n")) > 1: last_dep_line = defined_dependencies.split("\n")[-1] dep_sep = "\n" else: diff --git a/src/core_codemods/django_receiver_on_top.py b/src/core_codemods/django_receiver_on_top.py index c153fb58f..92079227c 100644 --- a/src/core_codemods/django_receiver_on_top.py +++ b/src/core_codemods/django_receiver_on_top.py @@ -29,8 +29,7 @@ def leave_FunctionDef( # that that have different start/end numbers. maybe_receiver_with_index = None for i, decorator in enumerate(original_node.decorators): - true_name = self.find_base_name(decorator.decorator) - if true_name == "django.dispatch.receiver": + if (true_name := self.find_base_name(decorator.decorator)) == "django.dispatch.receiver": maybe_receiver_with_index = (i, decorator) if maybe_receiver_with_index: diff --git a/src/core_codemods/file_resource_leak.py b/src/core_codemods/file_resource_leak.py index ee4bb4dae..966f125cd 100644 --- a/src/core_codemods/file_resource_leak.py +++ b/src/core_codemods/file_resource_leak.py @@ -136,8 +136,7 @@ def _is_resource_call(self, value) -> Optional[cst.Call]: return None def _is_resource(self, call: cst.Call) -> bool: - maybe_assignment = self.find_single_assignment(call) - if maybe_assignment: + if maybe_assignment := self.find_single_assignment(call): # is open call if isinstance(maybe_assignment, BuiltinAssignment) and matchers.matches( call.func, matchers.Name(value="open") @@ -242,8 +241,7 @@ def _find_direct_name_assignment_targets( name_targets = [] accesses = self.find_accesses(name) for node in (access.node for access in accesses): - maybe_assigned = self.is_value_of_assignment(node) - if maybe_assigned: + if maybe_assigned := self.is_value_of_assignment(node): targets = extract_targets_of_assignment(maybe_assigned) name_targets.extend(targets) return name_targets @@ -277,8 +275,7 @@ def _sieve_targets( def _find_transitive_assignment_targets( self, expr ) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]: - maybe_assigned = self.is_value_of_assignment(expr) - if maybe_assigned: + if maybe_assigned := self.is_value_of_assignment(expr): named_targets, other_targets = self._sieve_targets( extract_targets_of_assignment(maybe_assigned) ) diff --git a/src/core_codemods/flask_json_response_type.py b/src/core_codemods/flask_json_response_type.py index 273ab7c1d..6d4ce2b2d 100644 --- a/src/core_codemods/flask_json_response_type.py +++ b/src/core_codemods/flask_json_response_type.py @@ -113,10 +113,9 @@ def _is_tuple_with_json_string_response( case cst.Tuple(): elements = node.elements first = elements[0].value - maybe_vuln = self._is_json_dumps_call( + if maybe_vuln := self._is_json_dumps_call( first - ) or self._is_make_response_with_json(first) - if maybe_vuln: + ) or self._is_make_response_with_json(first): return node return None @@ -153,8 +152,7 @@ def _is_json_dumps_call(self, node: cst.BaseExpression) -> Optional[cst.Call]: expr = self.resolve_expression(node) match expr: case cst.Call(): - true_name = self.find_base_name(expr) - if true_name == "json.dumps": + if (true_name := self.find_base_name(expr)) == "json.dumps": return expr return None @@ -164,8 +162,7 @@ def _is_make_response_with_json( expr = self.resolve_expression(node) match expr: case cst.Call(args=[cst.Arg(first_arg), *_]): - true_name = self.find_base_name(expr) - if true_name != "flask.make_response": + if (true_name := self.find_base_name(expr)) != "flask.make_response": return None match first_arg: case cst.Tuple(): diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index 2dac634a6..dbdd5601c 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -79,8 +79,7 @@ def on_result_found( maybe_name = self.get_aliased_prefix_name( original_node, self._module_name ) - maybe_name = maybe_name or self._module_name - if maybe_name == self._module_name: + if (maybe_name := maybe_name or self._module_name) == self._module_name: self.add_needed_import(self._module_name) new_args = [ *updated_node.args[:1], diff --git a/src/core_codemods/secure_flask_session_config.py b/src/core_codemods/secure_flask_session_config.py index 95c2ae5ec..725c234e3 100644 --- a/src/core_codemods/secure_flask_session_config.py +++ b/src/core_codemods/secure_flask_session_config.py @@ -129,8 +129,7 @@ def flask_app_is_assigned(self): return bool(self.flask_app_name) def leave_Call(self, original_node: cst.Call, updated_node: cst.Call): - true_name = self.find_base_name(original_node.func) - if true_name == "flask.Flask": + if (true_name := self.find_base_name(original_node.func)) == "flask.Flask": self._store_flask_app(original_node) if self.flask_app_is_assigned and self._is_config_update_call(original_node): diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index df115eb14..c89aef089 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -311,8 +311,7 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: return False def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]: - assignment = self.find_single_assignment(node) - if assignment: + if assignment := self.find_single_assignment(node): base_scope = assignment.scope # TODO make this check in detect injection, to be more precise diff --git a/src/core_codemods/tempfile_mktemp.py b/src/core_codemods/tempfile_mktemp.py index c5a00fdb5..7d43adb1d 100644 --- a/src/core_codemods/tempfile_mktemp.py +++ b/src/core_codemods/tempfile_mktemp.py @@ -32,8 +32,7 @@ class TempfileMktemp(SimpleCodemod, NameResolutionMixin): def on_result_found(self, original_node, updated_node): maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) - maybe_name = maybe_name or self._module_name - if maybe_name == self._module_name: + if (maybe_name := maybe_name or self._module_name) == self._module_name: self.add_needed_import(self._module_name) self.remove_unused_import(original_node) return self.update_call_target(updated_node, maybe_name, "mkstemp") diff --git a/src/core_codemods/upgrade_sslcontext_minimum_version.py b/src/core_codemods/upgrade_sslcontext_minimum_version.py index cddf1dcdd..1aad403cc 100644 --- a/src/core_codemods/upgrade_sslcontext_minimum_version.py +++ b/src/core_codemods/upgrade_sslcontext_minimum_version.py @@ -50,8 +50,7 @@ def on_result_found(self, original_node, updated_node): maybe_name = self.get_aliased_prefix_name( original_node.value, self._module_name ) - maybe_name = maybe_name or self._module_name - if maybe_name == self._module_name: + if (maybe_name := maybe_name or self._module_name) == self._module_name: self.add_needed_import(self._module_name) self.remove_unused_import(original_node) return self.update_assign_rhs(updated_node, f"{maybe_name}.TLSVersion.TLSv1_2")