Skip to content

Commit

Permalink
Merge branch 'main' into lombok/getter
Browse files Browse the repository at this point in the history
  • Loading branch information
timtebeek authored Dec 13, 2024
2 parents 6766243 + 2c654e8 commit 74ddd7d
Show file tree
Hide file tree
Showing 9 changed files with 704 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright 2024 the original author or authors.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* https://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.openrewrite.java.migrate;

import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.ChangeType;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.search.FindMethods;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.TypeUtils;

public class IllegalArgumentExceptionToAlreadyConnectedException extends Recipe {

private static final String ILLEGAL_ARGUMENT_EXCEPTION = "java.lang.IllegalArgumentException";
private static final String ALREADY_CONNECTED_EXCEPTION = "java.nio.channels.AlreadyConnectedException";

@Override
public String getDisplayName() {
return "Replace `IllegalArgumentException` with `AlreadyConnectedException` in `DatagramChannel.send()` method";
}

@Override
public String getDescription() {
return "Replace `IllegalArgumentException` with `AlreadyConnectedException` for DatagramChannel.send() to ensure compatibility with Java 11+.";
}

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
String datagramChannelSendMethodPattern = "java.nio.channels.DatagramChannel send(java.nio.ByteBuffer, java.net.SocketAddress)";
return Preconditions.check(new UsesMethod<>(datagramChannelSendMethodPattern), new JavaIsoVisitor<ExecutionContext>() {
@Override
public J.Try visitTry(J.Try tryStatement, ExecutionContext ctx) {
J.Try try_ = super.visitTry(tryStatement, ctx);
if (FindMethods.find(try_, datagramChannelSendMethodPattern).isEmpty()) {
return try_;
}
return try_.withCatches(ListUtils.map(try_.getCatches(), catch_ -> {
if (TypeUtils.isOfClassType(catch_.getParameter().getType(), ILLEGAL_ARGUMENT_EXCEPTION)) {
maybeAddImport(ALREADY_CONNECTED_EXCEPTION);
return (J.Try.Catch) new ChangeType(ILLEGAL_ARGUMENT_EXCEPTION, ALREADY_CONNECTED_EXCEPTION, true)
.getVisitor().visit(catch_, ctx);
}
return catch_;
}));
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public JodaTimeVisitor getVisitor(Accumulator acc) {
@Getter
public static class Accumulator {
private final Set<NamedVariable> unsafeVars = new HashSet<>();
private final Map<JavaType.Method, Boolean> safeMethodMap = new HashMap<>();
private final VarTable varTable = new VarTable();
}

Expand Down
168 changes: 139 additions & 29 deletions src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
Expand All @@ -43,6 +44,8 @@ class JodaTimeScanner extends ScopeAwareVisitor {

private final Map<NamedVariable, Set<NamedVariable>> varDependencies = new HashMap<>();
private final Map<JavaType, Set<String>> unsafeVarsByType = new HashMap<>();
private final Map<JavaType.Method, Set<NamedVariable>> methodReferencedVars = new HashMap<>();
private final Map<JavaType.Method, Set<UnresolvedVar>> methodUnresolvedReferencedVars = new HashMap<>();

public JodaTimeScanner(JodaTimeRecipe.Accumulator acc) {
super(new LinkedList<>());
Expand All @@ -57,13 +60,30 @@ public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
dfs(var, allReachable);
}
acc.getUnsafeVars().addAll(allReachable);

Set<JavaType.Method> unsafeMethods = new HashSet<>();
acc.getSafeMethodMap().forEach((method, isSafe) -> {
if (!isSafe) {
unsafeMethods.add(method);
return;
}
Set<NamedVariable> intersection = new HashSet<>(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
intersection.retainAll(acc.getUnsafeVars());
if (!intersection.isEmpty()) {
unsafeMethods.add(method);
}
});
for (JavaType.Method method : unsafeMethods) {
acc.getSafeMethodMap().put(method, false);
acc.getUnsafeVars().addAll(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
}
return cu;
}

@Override
public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx) {
public J visitVariable(NamedVariable variable, ExecutionContext ctx) {
if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return (NamedVariable) super.visitVariable(variable, ctx);
return super.visitVariable(variable, ctx);
}
// TODO: handle class variables
if (isClassVar(variable)) {
Expand Down Expand Up @@ -96,35 +116,112 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
}

@Override
public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
public J visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
Expression var = assignment.getVariable();
// not joda expr or not local variable
if (!isJodaExpr(var) || !(var instanceof J.Identifier)) {
return assignment;
return super.visitAssignment(assignment, ctx);
}
J.Identifier ident = (J.Identifier) var;
Optional<NamedVariable> mayBeVar = findVarInScope(ident.getSimpleName());
if (!mayBeVar.isPresent()) {
return assignment;
return super.visitAssignment(assignment, ctx);
}
NamedVariable variable = mayBeVar.get();
Cursor varScope = findScope(variable);
List<Expression> sinks = findSinks(new Cursor(getCursor(), assignment.getAssignment()));
new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParentOrThrow());
processMarkersOnExpression(sinks, variable);
return assignment;
return super.visitAssignment(assignment, ctx);
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
acc.getVarTable().addVars(method);
unsafeVarsByType.getOrDefault(method.getMethodType(), Collections.emptySet()).forEach(varName -> {
NamedVariable var = acc.getVarTable().getVarByName(method.getMethodType(), varName);
if (var != null) { // var can only be null if method is not correctly type attributed
acc.getUnsafeVars().add(var);
}
});
return (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
Set<UnresolvedVar> unresolvedVars = methodUnresolvedReferencedVars.remove(method.getMethodType());
if (unresolvedVars != null) {
unresolvedVars.forEach(var -> {
NamedVariable namedVar = acc.getVarTable().getVarByName(var.getDeclaringType(), var.getVarName());
if (namedVar != null) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(namedVar);
}
});
}
return super.visitMethodDeclaration(method, ctx);
}

@Override
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
if (!isJodaExpr(method) || method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return super.visitMethodInvocation(method, ctx);
}
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
.visit(boundary.getValue(), ctx, boundary.getParentTreeCursor());

boolean isSafe = j != boundary.getValue();
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
J parent = boundary.getParentTreeCursor().getValue();
if (parent instanceof NamedVariable) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
.add((NamedVariable) parent);
}
if (parent instanceof J.Assignment) {
J.Assignment assignment = (J.Assignment) parent;
if (assignment.getVariable() instanceof J.Identifier) {
J.Identifier ident = (J.Identifier) assignment.getVariable();
findVarInScope(ident.getSimpleName())
.map(var -> methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var));
}
}
if (parent instanceof MethodCall) {
MethodCall parentMethod = (MethodCall) parent;
int argPos = parentMethod.getArguments().indexOf(boundary.getValue());
if (argPos == -1) {
return method;
}
String paramName = parentMethod.getMethodType().getParameterNames().get(argPos);
NamedVariable var = acc.getVarTable().getVarByName(parentMethod.getMethodType(), paramName);
if (var != null) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var);
} else {
methodUnresolvedReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
.add(new UnresolvedVar(parentMethod.getMethodType(), paramName));
}
}
return method;
}

@Override
public J.Return visitReturn(J.Return _return, ExecutionContext ctx) {
if (_return.getExpression() == null) {
return _return;
}
Expression expr = _return.getExpression();
if (!isJodaExpr(expr)) {
return _return;
}
J methodOrLambda = getCursor().dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda).getValue();
if (methodOrLambda instanceof J.Lambda) {
return _return;
}
J.MethodDeclaration method = (J.MethodDeclaration) methodOrLambda;
Expression updatedExpr = (Expression) new JodaTimeVisitor(acc, true, scopes)
.visit(expr, ctx, getCursor().getParentTreeCursor());
boolean isSafe = !isJodaExpr(updatedExpr);

addReferencedVars(expr, method.getMethodType());
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
if (!isSafe) {
acc.getUnsafeVars().addAll(methodReferencedVars.get(method.getMethodType()));
}
return _return;
}

private void processMarkersOnExpression(List<Expression> expressions, NamedVariable var) {
Expand All @@ -146,7 +243,23 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
}
}

private boolean isJodaExpr(Expression expression) {
/**
* Traverses the cursor to find the first non-Joda expression in the path.
* If no non-Joda expression is found, it returns the cursor pointing
* to the last Joda expression whose parent is not an Expression.
*/
private static Cursor findBoundaryCursorForJodaExpr(Cursor cursor) {
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parent = cursor.getParentTreeCursor();
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
return cursor;
}
cursor = parent;
}
return cursor;
}

private static boolean isJodaExpr(Expression expression) {
return expression.getType() != null && expression.getType().isAssignableFrom(JODA_CLASS_PATTERN);
}

Expand All @@ -172,6 +285,13 @@ private void dfs(NamedVariable root, Set<NamedVariable> visited) {
}
}

private void addReferencedVars(Expression expr, JavaType.Method method) {
Set<@Nullable NamedVariable> referencedVars = new HashSet<>();
new FindVarReferences().visit(expr, referencedVars, getCursor().getParentTreeCursor());
referencedVars.remove(null);
methodReferencedVars.computeIfAbsent(method, k -> new HashSet<>()).addAll(referencedVars);
}

@RequiredArgsConstructor
private class AddSafeCheckMarker extends JavaIsoVisitor<ExecutionContext> {

Expand Down Expand Up @@ -205,11 +325,12 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
return mayBeMarker.get();
}

Cursor boundary = findBoundaryCursorForJodaExpr();
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
boolean isSafe = true;
// TODO: handle return statement
if (boundary.getParentTreeCursor().getValue() instanceof J.Return) {
isSafe = false;
// TODO: handle return statement in lambda
isSafe = boundary.dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda)
.getValue() instanceof J.MethodDeclaration;
}
Expression boundaryExpr = boundary.getValue();
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
Expand All @@ -223,23 +344,6 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
return new SafeCheckMarker(UUID.randomUUID(), isSafe, referencedVars);
}

/**
* Traverses the cursor to find the first non-Joda expression in the path.
* If no non-Joda expression is found, it returns the cursor pointing
* to the last Joda expression whose parent is not an Expression.
*/
private Cursor findBoundaryCursorForJodaExpr() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parent = cursor.getParentTreeCursor();
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
return cursor;
}
cursor = parent;
}
return cursor;
}

private Optional<Cursor> findArgumentExprCursor() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Expand Down Expand Up @@ -283,4 +387,10 @@ public Expression visitExpression(Expression expression, AtomicBoolean hasJodaTy
return super.visitExpression(expression, hasJodaType);
}
}

@Value
private static class UnresolvedVar {
JavaType declaringType;
String varName;
}
}
Loading

0 comments on commit 74ddd7d

Please sign in to comment.