Skip to content

Commit

Permalink
Joda-Time to Java time: Add support for Method Return Statement Migra…
Browse files Browse the repository at this point in the history
…tion (#626)

* Joda-Time to Java time: Add support for Method Return Expression Migration

* Add few tests

* remove foo test

* formatting
  • Loading branch information
amishra-u authored Dec 13, 2024
1 parent ccb2010 commit 2c654e8
Show file tree
Hide file tree
Showing 5 changed files with 467 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public JodaTimeVisitor getVisitor(Accumulator acc) {
@Getter
public static class Accumulator {
private final Set<NamedVariable> unsafeVars = new HashSet<>();
private final Map<JavaType.Method, Boolean> safeMethodMap = new HashMap<>();
private final VarTable varTable = new VarTable();
}

Expand Down
168 changes: 139 additions & 29 deletions src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
Expand All @@ -43,6 +44,8 @@ class JodaTimeScanner extends ScopeAwareVisitor {

private final Map<NamedVariable, Set<NamedVariable>> varDependencies = new HashMap<>();
private final Map<JavaType, Set<String>> unsafeVarsByType = new HashMap<>();
private final Map<JavaType.Method, Set<NamedVariable>> methodReferencedVars = new HashMap<>();
private final Map<JavaType.Method, Set<UnresolvedVar>> methodUnresolvedReferencedVars = new HashMap<>();

public JodaTimeScanner(JodaTimeRecipe.Accumulator acc) {
super(new LinkedList<>());
Expand All @@ -57,13 +60,30 @@ public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
dfs(var, allReachable);
}
acc.getUnsafeVars().addAll(allReachable);

Set<JavaType.Method> unsafeMethods = new HashSet<>();
acc.getSafeMethodMap().forEach((method, isSafe) -> {
if (!isSafe) {
unsafeMethods.add(method);
return;
}
Set<NamedVariable> intersection = new HashSet<>(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
intersection.retainAll(acc.getUnsafeVars());
if (!intersection.isEmpty()) {
unsafeMethods.add(method);
}
});
for (JavaType.Method method : unsafeMethods) {
acc.getSafeMethodMap().put(method, false);
acc.getUnsafeVars().addAll(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
}
return cu;
}

@Override
public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx) {
public J visitVariable(NamedVariable variable, ExecutionContext ctx) {
if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return (NamedVariable) super.visitVariable(variable, ctx);
return super.visitVariable(variable, ctx);
}
// TODO: handle class variables
if (isClassVar(variable)) {
Expand Down Expand Up @@ -96,35 +116,112 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
}

@Override
public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
public J visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
Expression var = assignment.getVariable();
// not joda expr or not local variable
if (!isJodaExpr(var) || !(var instanceof J.Identifier)) {
return assignment;
return super.visitAssignment(assignment, ctx);
}
J.Identifier ident = (J.Identifier) var;
Optional<NamedVariable> mayBeVar = findVarInScope(ident.getSimpleName());
if (!mayBeVar.isPresent()) {
return assignment;
return super.visitAssignment(assignment, ctx);
}
NamedVariable variable = mayBeVar.get();
Cursor varScope = findScope(variable);
List<Expression> sinks = findSinks(new Cursor(getCursor(), assignment.getAssignment()));
new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParentOrThrow());
processMarkersOnExpression(sinks, variable);
return assignment;
return super.visitAssignment(assignment, ctx);
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
acc.getVarTable().addVars(method);
unsafeVarsByType.getOrDefault(method.getMethodType(), Collections.emptySet()).forEach(varName -> {
NamedVariable var = acc.getVarTable().getVarByName(method.getMethodType(), varName);
if (var != null) { // var can only be null if method is not correctly type attributed
acc.getUnsafeVars().add(var);
}
});
return (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
Set<UnresolvedVar> unresolvedVars = methodUnresolvedReferencedVars.remove(method.getMethodType());
if (unresolvedVars != null) {
unresolvedVars.forEach(var -> {
NamedVariable namedVar = acc.getVarTable().getVarByName(var.getDeclaringType(), var.getVarName());
if (namedVar != null) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(namedVar);
}
});
}
return super.visitMethodDeclaration(method, ctx);
}

@Override
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
if (!isJodaExpr(method) || method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return super.visitMethodInvocation(method, ctx);
}
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
.visit(boundary.getValue(), ctx, boundary.getParentTreeCursor());

boolean isSafe = j != boundary.getValue();
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
J parent = boundary.getParentTreeCursor().getValue();
if (parent instanceof NamedVariable) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
.add((NamedVariable) parent);
}
if (parent instanceof J.Assignment) {
J.Assignment assignment = (J.Assignment) parent;
if (assignment.getVariable() instanceof J.Identifier) {
J.Identifier ident = (J.Identifier) assignment.getVariable();
findVarInScope(ident.getSimpleName())
.map(var -> methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var));
}
}
if (parent instanceof MethodCall) {
MethodCall parentMethod = (MethodCall) parent;
int argPos = parentMethod.getArguments().indexOf(boundary.getValue());
if (argPos == -1) {
return method;
}
String paramName = parentMethod.getMethodType().getParameterNames().get(argPos);
NamedVariable var = acc.getVarTable().getVarByName(parentMethod.getMethodType(), paramName);
if (var != null) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var);
} else {
methodUnresolvedReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
.add(new UnresolvedVar(parentMethod.getMethodType(), paramName));
}
}
return method;
}

@Override
public J.Return visitReturn(J.Return _return, ExecutionContext ctx) {
if (_return.getExpression() == null) {
return _return;
}
Expression expr = _return.getExpression();
if (!isJodaExpr(expr)) {
return _return;
}
J methodOrLambda = getCursor().dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda).getValue();
if (methodOrLambda instanceof J.Lambda) {
return _return;
}
J.MethodDeclaration method = (J.MethodDeclaration) methodOrLambda;
Expression updatedExpr = (Expression) new JodaTimeVisitor(acc, true, scopes)
.visit(expr, ctx, getCursor().getParentTreeCursor());
boolean isSafe = !isJodaExpr(updatedExpr);

addReferencedVars(expr, method.getMethodType());
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
if (!isSafe) {
acc.getUnsafeVars().addAll(methodReferencedVars.get(method.getMethodType()));
}
return _return;
}

private void processMarkersOnExpression(List<Expression> expressions, NamedVariable var) {
Expand All @@ -146,7 +243,23 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
}
}

private boolean isJodaExpr(Expression expression) {
/**
* Traverses the cursor to find the first non-Joda expression in the path.
* If no non-Joda expression is found, it returns the cursor pointing
* to the last Joda expression whose parent is not an Expression.
*/
private static Cursor findBoundaryCursorForJodaExpr(Cursor cursor) {
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parent = cursor.getParentTreeCursor();
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
return cursor;
}
cursor = parent;
}
return cursor;
}

private static boolean isJodaExpr(Expression expression) {
return expression.getType() != null && expression.getType().isAssignableFrom(JODA_CLASS_PATTERN);
}

Expand All @@ -172,6 +285,13 @@ private void dfs(NamedVariable root, Set<NamedVariable> visited) {
}
}

private void addReferencedVars(Expression expr, JavaType.Method method) {
Set<@Nullable NamedVariable> referencedVars = new HashSet<>();
new FindVarReferences().visit(expr, referencedVars, getCursor().getParentTreeCursor());
referencedVars.remove(null);
methodReferencedVars.computeIfAbsent(method, k -> new HashSet<>()).addAll(referencedVars);
}

@RequiredArgsConstructor
private class AddSafeCheckMarker extends JavaIsoVisitor<ExecutionContext> {

Expand Down Expand Up @@ -205,11 +325,12 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
return mayBeMarker.get();
}

Cursor boundary = findBoundaryCursorForJodaExpr();
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
boolean isSafe = true;
// TODO: handle return statement
if (boundary.getParentTreeCursor().getValue() instanceof J.Return) {
isSafe = false;
// TODO: handle return statement in lambda
isSafe = boundary.dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda)
.getValue() instanceof J.MethodDeclaration;
}
Expression boundaryExpr = boundary.getValue();
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
Expand All @@ -223,23 +344,6 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
return new SafeCheckMarker(UUID.randomUUID(), isSafe, referencedVars);
}

/**
* Traverses the cursor to find the first non-Joda expression in the path.
* If no non-Joda expression is found, it returns the cursor pointing
* to the last Joda expression whose parent is not an Expression.
*/
private Cursor findBoundaryCursorForJodaExpr() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parent = cursor.getParentTreeCursor();
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
return cursor;
}
cursor = parent;
}
return cursor;
}

private Optional<Cursor> findArgumentExprCursor() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Expand Down Expand Up @@ -283,4 +387,10 @@ public Expression visitExpression(Expression expression, AtomicBoolean hasJodaTy
return super.visitExpression(expression, hasJodaType);
}
}

@Value
private static class UnresolvedVar {
JavaType declaringType;
String varName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)
return super.visitCompilationUnit(cu, ctx);
}

@Override
public @NonNull J visitMethodDeclaration(@NonNull J.MethodDeclaration method, @NonNull ExecutionContext ctx) {
J.MethodDeclaration m = (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
if (m.getReturnTypeExpression() == null || !m.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return m;
}
if (safeMigration && !acc.getSafeMethodMap().getOrDefault(m.getMethodType(), false)) {
return m;
}

JavaType.Class returnType = TimeClassMap.getJavaTimeType(((JavaType.Class) m.getType()).getFullyQualifiedName());
J.Identifier returnExpr = TypeTree.build(returnType.getClassName()).withType(returnType).withPrefix(Space.format(" "));
return m.withReturnTypeExpression(returnExpr)
.withMethodType(m.getMethodType().withReturnType(returnType));
}

@Override
public @NonNull J visitVariableDeclarations(@NonNull J.VariableDeclarations multiVariable, @NonNull ExecutionContext ctx) {
if (multiVariable.getTypeExpression() == null || !multiVariable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
Expand Down Expand Up @@ -147,6 +163,13 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)
@Override
public @NonNull J visitMethodInvocation(@NonNull J.MethodInvocation method, @NonNull ExecutionContext ctx) {
J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);

// internal method with Joda class as return type
if (!method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN) &&
method.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return migrateNonJodaMethod(method, m);
}

if (hasJodaType(m.getArguments()) || isJodaVarRef(m.getSelect())) {
return method;
}
Expand Down Expand Up @@ -179,7 +202,9 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)

JavaType.FullyQualified jodaType = ((JavaType.Class) ident.getType());
JavaType.FullyQualified fqType = TimeClassMap.getJavaTimeType(jodaType.getFullyQualifiedName());

if (fqType == null) {
return ident;
}
return ident.withType(fqType)
.withFieldType(ident.getFieldType().withType(fqType));
}
Expand Down Expand Up @@ -218,6 +243,19 @@ private J migrateMethodCall(MethodCall original, MethodCall updated) {
return original;
}

private J.MethodInvocation migrateNonJodaMethod(J.MethodInvocation original, J.MethodInvocation updated) {
if (safeMigration && !acc.getSafeMethodMap().getOrDefault(updated.getMethodType(), false)) {
return original;
}
JavaType.Class returnType = (JavaType.Class) updated.getMethodType().getReturnType();
JavaType.Class updatedReturnType = TimeClassMap.getJavaTimeType(returnType.getFullyQualifiedName());
if (updatedReturnType == null) {
return original; // unhandled case
}
return updated.withMethodType(updated.getMethodType().withReturnType(updatedReturnType))
.withName(updated.getName().withType(updatedReturnType));
}

private boolean hasJodaType(List<Expression> exprs) {
for (Expression expr : exprs) {
JavaType exprType = expr.getType();
Expand Down
Loading

0 comments on commit 2c654e8

Please sign in to comment.