Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Assignment Expression (Walrus) In Conditional #220

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/codemodder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,10 @@ class ValidatedCodmods(CsvListAction):

def validate_items(self, items):
potential_names = ids + names
unrecognized_codemods = [
name for name in items if name not in potential_names
]

if unrecognized_codemods:
if unrecognized_codemods := [
name for name in items if name not in potential_names
]:
args = {
"values": unrecognized_codemods,
"choices": ", ".join(map(repr, names)),
Expand Down
6 changes: 2 additions & 4 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,8 @@ def __init__(self, context: CodemodContext) -> None:
MetadataDependent.__init__(self)
MatcherDecoratableTransformer.__init__(self)
self.context = context
dependencies = self.get_inherited_dependencies()
if dependencies:
wrapper = self.context.wrapper
if wrapper is None:
if dependencies := self.get_inherited_dependencies():
if (wrapper := self.context.wrapper) is None:
raise ValueError(
f"Attempting to instantiate {self.__class__.__name__} outside of "
+ "an active transform. This means that metadata hasn't been "
Expand Down
12 changes: 4 additions & 8 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def find_base_name(self, node) -> Optional[str]:
return self._find_imported_name(node)

case cst.Attribute():
maybe_name = self.find_base_name(node.value)
if maybe_name:
if maybe_name := self.find_base_name(node.value):
return maybe_name + "." + node.attr.value

case cst.Call():
Expand Down Expand Up @@ -182,8 +181,7 @@ def find_used_names_in_module(self):
Find all the used names in the scope of a libcst Module.
"""
names = []
scope = self.find_global_scope()
if scope is None:
if (scope := self.find_global_scope()) is None:
return [] # pragma: no cover

nodes = [x.node for x in scope.assignments]
Expand Down Expand Up @@ -276,8 +274,7 @@ def is_builtin_function(self, node: cst.Call):
return False

def find_accesses(self, node) -> Collection[Access]:
scope = self.get_metadata(ScopeProvider, node, None)
if scope:
if scope := self.get_metadata(ScopeProvider, node, None):
return scope.accesses[node]
return {}

Expand Down Expand Up @@ -465,8 +462,7 @@ def resolve_expression(self, node: cst.BaseExpression) -> cst.BaseExpression:
maybe_expr = None
match node:
case cst.Name():
maybe_expr = self._resolve_name_transitive(node)
if maybe_expr:
if maybe_expr := self._resolve_name_transitive(node):
return maybe_expr
return node

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def add_to_file(
def write(
self, dependencies: list[Dependency], dry_run: bool = False
) -> Optional[ChangeSet]:
new_dependencies = self.add(dependencies)
if new_dependencies:
if new_dependencies := self.add(dependencies):
return self.add_to_file(new_dependencies, dry_run)
return None

Expand Down
3 changes: 1 addition & 2 deletions src/codemodder/dependency_management/setup_py_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def visit_Module(self, _: cst.Module) -> bool:
return is_setup_py_file(self.filename)

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
true_name = self.find_base_name(original_node.func)
if true_name != "setuptools.setup":
if self.find_base_name(original_node.func) != "setuptools.setup":
return original_node

new_args = self.replace_arg(original_node)
Expand Down
6 changes: 2 additions & 4 deletions src/codemodder/dependency_management/setupcfg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@


def find_leading_whitespace(s):
match = re.match(r"(\s+)", s)
if match:
if match := re.match(r"(\s+)", s):
return match.group(1)
return "" # pragma: no cover

Expand Down Expand Up @@ -81,8 +80,7 @@ def build_new_lines(
"""
clean_lines = [s.strip() for s in original_lines]

newline_separated = len(defined_dependencies.split("\n")) > 1
if newline_separated:
if newline_separated := len(defined_dependencies.split("\n")) > 1:
last_dep_line = defined_dependencies.split("\n")[-1]
dep_sep = "\n"
else:
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/django_receiver_on_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def leave_FunctionDef(
# that that have different start/end numbers.
maybe_receiver_with_index = None
for i, decorator in enumerate(original_node.decorators):
true_name = self.find_base_name(decorator.decorator)
if true_name == "django.dispatch.receiver":
if self.find_base_name(decorator.decorator) == "django.dispatch.receiver":
maybe_receiver_with_index = (i, decorator)

if maybe_receiver_with_index:
Expand Down
9 changes: 3 additions & 6 deletions src/core_codemods/file_resource_leak.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ def _is_resource_call(self, value) -> Optional[cst.Call]:
return None

def _is_resource(self, call: cst.Call) -> bool:
maybe_assignment = self.find_single_assignment(call)
if maybe_assignment:
if maybe_assignment := self.find_single_assignment(call):
# is open call
if isinstance(maybe_assignment, BuiltinAssignment) and matchers.matches(
call.func, matchers.Name(value="open")
Expand Down Expand Up @@ -242,8 +241,7 @@ def _find_direct_name_assignment_targets(
name_targets = []
accesses = self.find_accesses(name)
for node in (access.node for access in accesses):
maybe_assigned = self.is_value_of_assignment(node)
if maybe_assigned:
if maybe_assigned := self.is_value_of_assignment(node):
targets = extract_targets_of_assignment(maybe_assigned)
name_targets.extend(targets)
return name_targets
Expand Down Expand Up @@ -277,8 +275,7 @@ def _sieve_targets(
def _find_transitive_assignment_targets(
self, expr
) -> tuple[list[cst.Name], list[cst.BaseAssignTargetExpression]]:
maybe_assigned = self.is_value_of_assignment(expr)
if maybe_assigned:
if maybe_assigned := self.is_value_of_assignment(expr):
named_targets, other_targets = self._sieve_targets(
extract_targets_of_assignment(maybe_assigned)
)
Expand Down
11 changes: 4 additions & 7 deletions src/core_codemods/flask_json_response_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,9 @@ def _is_tuple_with_json_string_response(
case cst.Tuple():
elements = node.elements
first = elements[0].value
maybe_vuln = self._is_json_dumps_call(
if self._is_json_dumps_call(first) or self._is_make_response_with_json(
first
) or self._is_make_response_with_json(first)
if maybe_vuln:
):
return node
return None

Expand Down Expand Up @@ -153,8 +152,7 @@ def _is_json_dumps_call(self, node: cst.BaseExpression) -> Optional[cst.Call]:
expr = self.resolve_expression(node)
match expr:
case cst.Call():
true_name = self.find_base_name(expr)
if true_name == "json.dumps":
if self.find_base_name(expr) == "json.dumps":
return expr
return None

Expand All @@ -164,8 +162,7 @@ def _is_make_response_with_json(
expr = self.resolve_expression(node)
match expr:
case cst.Call(args=[cst.Arg(first_arg), *_]):
true_name = self.find_base_name(expr)
if true_name != "flask.make_response":
if self.find_base_name(expr) != "flask.make_response":
return None
match first_arg:
case cst.Tuple():
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def on_result_found(
maybe_name = self.get_aliased_prefix_name(
original_node, self._module_name
)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
new_args = [
*updated_node.args[:1],
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/secure_flask_session_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ def flask_app_is_assigned(self):
return bool(self.flask_app_name)

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
true_name = self.find_base_name(original_node.func)
if true_name == "flask.Flask":
if self.find_base_name(original_node.func) == "flask.Flask":
self._store_flask_app(original_node)

if self.flask_app_is_assigned and self._is_config_update_call(original_node):
Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,7 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
return False

def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]:
assignment = self.find_single_assignment(node)
if assignment:
if assignment := self.find_single_assignment(node):
base_scope = assignment.scope
# TODO make this check in detect injection, to be more precise

Expand Down
3 changes: 1 addition & 2 deletions src/core_codemods/tempfile_mktemp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class TempfileMktemp(SimpleCodemod, NameResolutionMixin):

def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
return self.update_call_target(updated_node, maybe_name, "mkstemp")
3 changes: 1 addition & 2 deletions src/core_codemods/upgrade_sslcontext_minimum_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def on_result_found(self, original_node, updated_node):
maybe_name = self.get_aliased_prefix_name(
original_node.value, self._module_name
)
maybe_name = maybe_name or self._module_name
if maybe_name == self._module_name:
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
return self.update_assign_rhs(updated_node, f"{maybe_name}.TLSVersion.TLSv1_2")
Loading