Skip to content

Commit

Permalink
Use Assignment Expression (Walrus) In Conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
pixeebot committed Jan 23, 2024
1 parent 2c094e8 commit 582f7eb
Show file tree
Hide file tree
Showing 14 changed files with 26 additions and 49 deletions.
7 changes: 3 additions & 4 deletions src/codemodder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,10 @@ class ValidatedCodmods(CsvListAction):

def validate_items(self, items):
potential_names = ids + names
unrecognized_codemods = [
name for name in items if name not in potential_names
]

if unrecognized_codemods:
if unrecognized_codemods := [
name for name in items if name not in potential_names
]:
args = {
"values": unrecognized_codemods,
"choices": ", ".join(map(repr, names)),
Expand Down
6 changes: 2 additions & 4 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,8 @@ def __init__(self, context: CodemodContext) -> None:
MetadataDependent.__init__(self)
MatcherDecoratableTransformer.__init__(self)
self.context = context
dependencies = self.get_inherited_dependencies()
if dependencies:
wrapper = self.context.wrapper
if wrapper is None:
if dependencies := self.get_inherited_dependencies():
if (wrapper := self.context.wrapper) is None:
raise ValueError(
f"Attempting to instantiate {self.__class__.__name__} outside of "
+ "an active transform. This means that metadata hasn't been "
Expand Down
12 changes: 4 additions & 8 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def find_base_name(self, node) -> Optional[str]:
return self._find_imported_name(node)

case cst.Attribute():
maybe_name = self.find_base_name(node.value)
if maybe_name:
if maybe_name := self.find_base_name(node.value):
return maybe_name + "." + node.attr.value

case cst.Call():
Expand Down Expand Up @@ -182,8 +181,7 @@ def find_used_names_in_module(self):
Find all the used names in the scope of a libcst Module.
"""
names = []
scope = self.find_global_scope()
if scope is None:
if (scope := self.find_global_scope()) is None:
return [] # pragma: no cover

nodes = [x.node for x in scope.assignments]
Expand Down Expand Up @@ -276,8 +274,7 @@ def is_builtin_function(self, node: cst.Call):
return False

def find_accesses(self, node) -> Collection[Access]:
scope = self.get_metadata(ScopeProvider, node, None)
if scope:
if scope := self.get_metadata(ScopeProvider, node, None):
return scope.accesses[node]
return {}

Expand Down Expand Up @@ -461,8 +458,7 @@ def resolve_expression(self, node: cst.BaseExpression) -> cst.BaseExpression:
maybe_expr = None
match node:
case cst.Name():
maybe_expr = self._resolve_name_transitive(node)
if maybe_expr:
if maybe_expr := self._resolve_name_transitive(node):
return maybe_expr
return node

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def add_to_file(
def write(
self, dependencies: list[Dependency], dry_run: bool = False
) -> Optional[ChangeSet]:
new_dependencies = self.add(dependencies)
if new_dependencies:
if new_dependencies := self.add(dependencies):
return self.add_to_file(new_dependencies, dry_run)
return None

Expand Down
3 changes: 1 addition & 2 deletions src/codemodder/dependency_management/setup_py_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def visit_Module(self, _: cst.Module) -> bool:
return is_setup_py_file(self.filename)

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
true_name = self.find_base_name(original_node.func)
if true_name != "setuptools.setup":
if (true_name := self.find_base_name(original_node.func)) != "setuptools.setup":
return original_node

new_args = self.replace_arg(original_node)
Expand Down
6 changes: 2 additions & 4 deletions src/codemodder/dependency_management/setupcfg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@


def find_leading_whitespace(s):
match = re.match(r"(\s+)", s)
if match:
if match := re.match(r"(\s+)", s):
return match.group(1)
return "" # pragma: no cover

Expand Down Expand Up @@ -81,8 +80,7 @@ def build_new_lines(
"""
clean_lines = [s.strip() for s in original_lines]

newline_separated = len(defined_dependencies.split("\n")) > 1
if newline_separated:
if newline_separated := len(defined_dependencies.split("\n")) > 1:
last_dep_line = defined_dependencies.split("\n")[-1]
dep_sep = "\n"
else:
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/django_receiver_on_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def leave_FunctionDef(
# that that have different start/end numbers.
maybe_receiver_with_index = None
for i, decorator in enumerate(original_node.decorators):
true_name = self.find_base_name(decorator.decorator)
if true_name == "django.dispatch.receiver":
if (true_name := self.find_base_name(decorator.decorator)) == "django.dispatch.receiver":
maybe_receiver_with_index = (i, decorator)

if maybe_receiver_with_index:
Expand Down
9 changes: 3 additions & 6 deletions src/core_codemods/file_resource_leak.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ def _is_resource_call(self, value) -> Optional[cst.Call]:
return None

def _is_resource(self, call: cst.Call) -> bool:
maybe_assignment = self.find_single_assignment(call)
if maybe_assignment:
if maybe_assignment := self.find_single_assignment(call):
# is open call
if isinstance(maybe_assignment, BuiltinAssignment) and matchers.matches(
call.func, matchers.Name(value="open")
Expand Down Expand Up @@ -242,8 +241,7 @@ def _find_direct_name_assignment_targets(
name_targets = []
accesses = self.find_accesses(name)
for node in (access.node for access in accesses):
maybe_assigned = self.is_value_of_assignment(node)
if maybe_assigned:
if maybe_assigned := self.is_value_of_assignment(node):
targets = extract_targets_of_assignment(maybe_assigned)
name_targets.extend(targets)
return name_targets
Expand Down Expand Up @@ -277,8 +275,7 @@ def _sieve_targets(
def _find_transitive_assignment_targets(
self, expr
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
maybe_assigned = self.is_value_of_assignment(expr)
if maybe_assigned:
if maybe_assigned := self.is_value_of_assignment(expr):
named_targets, other_targets = self._sieve_targets(
extract_targets_of_assignment(maybe_assigned)
)
Expand Down
11 changes: 4 additions & 7 deletions src/core_codemods/flask_json_response_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,9 @@ def _is_tuple_with_json_string_response(
case cst.Tuple():
elements = node.elements
first = elements[0].value
maybe_vuln = self._is_json_dumps_call(
if maybe_vuln := self._is_json_dumps_call(
first
) or self._is_make_response_with_json(first)
if maybe_vuln:
) or self._is_make_response_with_json(first):
return node
return None

Expand Down Expand Up @@ -153,8 +152,7 @@ def _is_json_dumps_call(self, node: cst.BaseExpression) -> Optional[cst.Call]:
expr = self.resolve_expression(node)
match expr:
case cst.Call():
true_name = self.find_base_name(expr)
if true_name == "json.dumps":
if (true_name := self.find_base_name(expr)) == "json.dumps":
return expr
return None

Expand All @@ -164,8 +162,7 @@ def _is_make_response_with_json(
expr = self.resolve_expression(node)
match expr:
case cst.Call(args=[cst.Arg(first_arg), *_]):
true_name = self.find_base_name(expr)
if true_name != "flask.make_response":
if (true_name := self.find_base_name(expr)) != "flask.make_response":
return None
match first_arg:
case cst.Tuple():
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def on_result_found(
maybe_name = self.get_aliased_prefix_name(
original_node, self._module_name
)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
new_args = [
*updated_node.args[:1],
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/secure_flask_session_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ def flask_app_is_assigned(self):
return bool(self.flask_app_name)

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
true_name = self.find_base_name(original_node.func)
if true_name == "flask.Flask":
if (true_name := self.find_base_name(original_node.func)) == "flask.Flask":
self._store_flask_app(original_node)

if self.flask_app_is_assigned and self._is_config_update_call(original_node):
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,7 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
return False

def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]:
assignment = self.find_single_assignment(node)
if assignment:
if assignment := self.find_single_assignment(node):
base_scope = assignment.scope
# TODO make this check in detect injection, to be more precise

Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/tempfile_mktemp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class TempfileMktemp(SimpleCodemod, NameResolutionMixin):

def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
return self.update_call_target(updated_node, maybe_name, "mkstemp")
3 changes: 1 addition & 2 deletions src/core_codemods/upgrade_sslcontext_minimum_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(
original_node.value, self._module_name
)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
return self.update_assign_rhs(updated_node, f"{maybe_name}.TLSVersion.TLSv1_2")

0 comments on commit 582f7eb

Please sign in to comment.