diff --git a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeRecipe.java b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeRecipe.java index 2e21cc00f..9c0e9e759 100644 --- a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeRecipe.java +++ b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeRecipe.java @@ -54,6 +54,7 @@ public JodaTimeVisitor getVisitor(Accumulator acc) { @Getter public static class Accumulator { private final Set unsafeVars = new HashSet<>(); + private final Map safeMethodMap = new HashMap<>(); private final VarTable varTable = new VarTable(); } diff --git a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java index 8255f0edd..74df986dd 100644 --- a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java +++ b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java @@ -19,6 +19,7 @@ import lombok.Getter; import lombok.NonNull; import lombok.RequiredArgsConstructor; +import lombok.Value; import org.jspecify.annotations.Nullable; import org.openrewrite.Cursor; import org.openrewrite.ExecutionContext; @@ -43,6 +44,8 @@ class JodaTimeScanner extends ScopeAwareVisitor { private final Map> varDependencies = new HashMap<>(); private final Map> unsafeVarsByType = new HashMap<>(); + private final Map> methodReferencedVars = new HashMap<>(); + private final Map> methodUnresolvedReferencedVars = new HashMap<>(); public JodaTimeScanner(JodaTimeRecipe.Accumulator acc) { super(new LinkedList<>()); @@ -57,13 +60,30 @@ public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) { dfs(var, allReachable); } acc.getUnsafeVars().addAll(allReachable); + + Set unsafeMethods = new HashSet<>(); + acc.getSafeMethodMap().forEach((method, isSafe) -> { + if (!isSafe) { + unsafeMethods.add(method); + return; + } + Set intersection = new HashSet<>(methodReferencedVars.getOrDefault(method, Collections.emptySet())); + intersection.retainAll(acc.getUnsafeVars()); + if (!intersection.isEmpty()) { + unsafeMethods.add(method); + } + }); + for (JavaType.Method method : unsafeMethods) { + acc.getSafeMethodMap().put(method, false); + acc.getUnsafeVars().addAll(methodReferencedVars.getOrDefault(method, Collections.emptySet())); + } return cu; } @Override - public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx) { + public J visitVariable(NamedVariable variable, ExecutionContext ctx) { if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) { - return (NamedVariable) super.visitVariable(variable, ctx); + return super.visitVariable(variable, ctx); } // TODO: handle class variables if (isClassVar(variable)) { @@ -96,27 +116,27 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx) } @Override - public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ctx) { + public J visitAssignment(J.Assignment assignment, ExecutionContext ctx) { Expression var = assignment.getVariable(); // not joda expr or not local variable if (!isJodaExpr(var) || !(var instanceof J.Identifier)) { - return assignment; + return super.visitAssignment(assignment, ctx); } J.Identifier ident = (J.Identifier) var; Optional mayBeVar = findVarInScope(ident.getSimpleName()); if (!mayBeVar.isPresent()) { - return assignment; + return super.visitAssignment(assignment, ctx); } NamedVariable variable = mayBeVar.get(); Cursor varScope = findScope(variable); List sinks = findSinks(new Cursor(getCursor(), assignment.getAssignment())); new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParentOrThrow()); processMarkersOnExpression(sinks, variable); - return assignment; + return super.visitAssignment(assignment, ctx); } @Override - public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { + public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { acc.getVarTable().addVars(method); unsafeVarsByType.getOrDefault(method.getMethodType(), Collections.emptySet()).forEach(varName -> { NamedVariable var = acc.getVarTable().getVarByName(method.getMethodType(), varName); @@ -124,7 +144,84 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex acc.getUnsafeVars().add(var); } }); - return (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx); + Set unresolvedVars = methodUnresolvedReferencedVars.remove(method.getMethodType()); + if (unresolvedVars != null) { + unresolvedVars.forEach(var -> { + NamedVariable namedVar = acc.getVarTable().getVarByName(var.getDeclaringType(), var.getVarName()); + if (namedVar != null) { + methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(namedVar); + } + }); + } + return super.visitMethodDeclaration(method, ctx); + } + + @Override + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + if (!isJodaExpr(method) || method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN)) { + return super.visitMethodInvocation(method, ctx); + } + Cursor boundary = findBoundaryCursorForJodaExpr(getCursor()); + J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes) + .visit(boundary.getValue(), ctx, boundary.getParentTreeCursor()); + + boolean isSafe = j != boundary.getValue(); + acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe); + J parent = boundary.getParentTreeCursor().getValue(); + if (parent instanceof NamedVariable) { + methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()) + .add((NamedVariable) parent); + } + if (parent instanceof J.Assignment) { + J.Assignment assignment = (J.Assignment) parent; + if (assignment.getVariable() instanceof J.Identifier) { + J.Identifier ident = (J.Identifier) assignment.getVariable(); + findVarInScope(ident.getSimpleName()) + .map(var -> methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var)); + } + } + if (parent instanceof MethodCall) { + MethodCall parentMethod = (MethodCall) parent; + int argPos = parentMethod.getArguments().indexOf(boundary.getValue()); + if (argPos == -1) { + return method; + } + String paramName = parentMethod.getMethodType().getParameterNames().get(argPos); + NamedVariable var = acc.getVarTable().getVarByName(parentMethod.getMethodType(), paramName); + if (var != null) { + methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var); + } else { + methodUnresolvedReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()) + .add(new UnresolvedVar(parentMethod.getMethodType(), paramName)); + } + } + return method; + } + + @Override + public J.Return visitReturn(J.Return _return, ExecutionContext ctx) { + if (_return.getExpression() == null) { + return _return; + } + Expression expr = _return.getExpression(); + if (!isJodaExpr(expr)) { + return _return; + } + J methodOrLambda = getCursor().dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda).getValue(); + if (methodOrLambda instanceof J.Lambda) { + return _return; + } + J.MethodDeclaration method = (J.MethodDeclaration) methodOrLambda; + Expression updatedExpr = (Expression) new JodaTimeVisitor(acc, true, scopes) + .visit(expr, ctx, getCursor().getParentTreeCursor()); + boolean isSafe = !isJodaExpr(updatedExpr); + + addReferencedVars(expr, method.getMethodType()); + acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe); + if (!isSafe) { + acc.getUnsafeVars().addAll(methodReferencedVars.get(method.getMethodType())); + } + return _return; } private void processMarkersOnExpression(List expressions, NamedVariable var) { @@ -146,7 +243,23 @@ private void processMarkersOnExpression(List expressions, NamedVaria } } - private boolean isJodaExpr(Expression expression) { + /** + * Traverses the cursor to find the first non-Joda expression in the path. + * If no non-Joda expression is found, it returns the cursor pointing + * to the last Joda expression whose parent is not an Expression. + */ + private static Cursor findBoundaryCursorForJodaExpr(Cursor cursor) { + while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) { + Cursor parent = cursor.getParentTreeCursor(); + if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) { + return cursor; + } + cursor = parent; + } + return cursor; + } + + private static boolean isJodaExpr(Expression expression) { return expression.getType() != null && expression.getType().isAssignableFrom(JODA_CLASS_PATTERN); } @@ -172,6 +285,13 @@ private void dfs(NamedVariable root, Set visited) { } } + private void addReferencedVars(Expression expr, JavaType.Method method) { + Set<@Nullable NamedVariable> referencedVars = new HashSet<>(); + new FindVarReferences().visit(expr, referencedVars, getCursor().getParentTreeCursor()); + referencedVars.remove(null); + methodReferencedVars.computeIfAbsent(method, k -> new HashSet<>()).addAll(referencedVars); + } + @RequiredArgsConstructor private class AddSafeCheckMarker extends JavaIsoVisitor { @@ -205,11 +325,12 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) { return mayBeMarker.get(); } - Cursor boundary = findBoundaryCursorForJodaExpr(); + Cursor boundary = findBoundaryCursorForJodaExpr(getCursor()); boolean isSafe = true; - // TODO: handle return statement if (boundary.getParentTreeCursor().getValue() instanceof J.Return) { - isSafe = false; + // TODO: handle return statement in lambda + isSafe = boundary.dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda) + .getValue() instanceof J.MethodDeclaration; } Expression boundaryExpr = boundary.getValue(); J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes) @@ -223,23 +344,6 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) { return new SafeCheckMarker(UUID.randomUUID(), isSafe, referencedVars); } - /** - * Traverses the cursor to find the first non-Joda expression in the path. - * If no non-Joda expression is found, it returns the cursor pointing - * to the last Joda expression whose parent is not an Expression. - */ - private Cursor findBoundaryCursorForJodaExpr() { - Cursor cursor = getCursor(); - while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) { - Cursor parent = cursor.getParentTreeCursor(); - if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) { - return cursor; - } - cursor = parent; - } - return cursor; - } - private Optional findArgumentExprCursor() { Cursor cursor = getCursor(); while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) { @@ -283,4 +387,10 @@ public Expression visitExpression(Expression expression, AtomicBoolean hasJodaTy return super.visitExpression(expression, hasJodaType); } } + + @Value + private static class UnresolvedVar { + JavaType declaringType; + String varName; + } } diff --git a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java index 32cc1e5bd..d0b56ccc1 100644 --- a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java +++ b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java @@ -84,6 +84,22 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx) return super.visitCompilationUnit(cu, ctx); } + @Override + public @NonNull J visitMethodDeclaration(@NonNull J.MethodDeclaration method, @NonNull ExecutionContext ctx) { + J.MethodDeclaration m = (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx); + if (m.getReturnTypeExpression() == null || !m.getType().isAssignableFrom(JODA_CLASS_PATTERN)) { + return m; + } + if (safeMigration && !acc.getSafeMethodMap().getOrDefault(m.getMethodType(), false)) { + return m; + } + + JavaType.Class returnType = TimeClassMap.getJavaTimeType(((JavaType.Class) m.getType()).getFullyQualifiedName()); + J.Identifier returnExpr = TypeTree.build(returnType.getClassName()).withType(returnType).withPrefix(Space.format(" ")); + return m.withReturnTypeExpression(returnExpr) + .withMethodType(m.getMethodType().withReturnType(returnType)); + } + @Override public @NonNull J visitVariableDeclarations(@NonNull J.VariableDeclarations multiVariable, @NonNull ExecutionContext ctx) { if (multiVariable.getTypeExpression() == null || !multiVariable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) { @@ -147,6 +163,13 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx) @Override public @NonNull J visitMethodInvocation(@NonNull J.MethodInvocation method, @NonNull ExecutionContext ctx) { J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); + + // internal method with Joda class as return type + if (!method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN) && + method.getType().isAssignableFrom(JODA_CLASS_PATTERN)) { + return migrateNonJodaMethod(method, m); + } + if (hasJodaType(m.getArguments()) || isJodaVarRef(m.getSelect())) { return method; } @@ -179,7 +202,9 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx) JavaType.FullyQualified jodaType = ((JavaType.Class) ident.getType()); JavaType.FullyQualified fqType = TimeClassMap.getJavaTimeType(jodaType.getFullyQualifiedName()); - + if (fqType == null) { + return ident; + } return ident.withType(fqType) .withFieldType(ident.getFieldType().withType(fqType)); } @@ -218,6 +243,19 @@ private J migrateMethodCall(MethodCall original, MethodCall updated) { return original; } + private J.MethodInvocation migrateNonJodaMethod(J.MethodInvocation original, J.MethodInvocation updated) { + if (safeMigration && !acc.getSafeMethodMap().getOrDefault(updated.getMethodType(), false)) { + return original; + } + JavaType.Class returnType = (JavaType.Class) updated.getMethodType().getReturnType(); + JavaType.Class updatedReturnType = TimeClassMap.getJavaTimeType(returnType.getFullyQualifiedName()); + if (updatedReturnType == null) { + return original; // unhandled case + } + return updated.withMethodType(updated.getMethodType().withReturnType(updatedReturnType)) + .withName(updated.getName().withType(updatedReturnType)); + } + private boolean hasJodaType(List exprs) { for (Expression expr : exprs) { JavaType exprType = expr.getType(); diff --git a/src/main/java/org/openrewrite/java/migrate/lombok/LombokUtils.java b/src/main/java/org/openrewrite/java/migrate/lombok/LombokUtils.java new file mode 100644 index 000000000..20c1ae081 --- /dev/null +++ b/src/main/java/org/openrewrite/java/migrate/lombok/LombokUtils.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lombok; + +import lombok.AccessLevel; +import org.jspecify.annotations.Nullable; +import org.openrewrite.internal.StringUtils; +import org.openrewrite.java.tree.Expression; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; + +import static lombok.AccessLevel.*; +import static org.openrewrite.java.tree.J.Modifier.Type.*; + +class LombokUtils { + + static boolean isGetter(J.MethodDeclaration method) { + if (method.getMethodType() == null) { + return false; + } + // Check signature: no parameters + if (!(method.getParameters().get(0) instanceof J.Empty) || method.getReturnTypeExpression() == null) { + return false; + } + // Check body: just a return statement + if (method.getBody() == null || + method.getBody().getStatements().size() != 1 || + !(method.getBody().getStatements().get(0) instanceof J.Return)) { + return false; + } + // Check field is declared on method type + JavaType.FullyQualified declaringType = method.getMethodType().getDeclaringType(); + Expression returnExpression = ((J.Return) method.getBody().getStatements().get(0)).getExpression(); + if (returnExpression instanceof J.Identifier) { + J.Identifier identifier = (J.Identifier) returnExpression; + if (identifier.getFieldType() != null && declaringType == identifier.getFieldType().getOwner()) { + // Check return: type and matching field name + return hasMatchingTypeAndName(method, identifier.getType(), identifier.getSimpleName()); + } + } else if (returnExpression instanceof J.FieldAccess) { + J.FieldAccess fieldAccess = (J.FieldAccess) returnExpression; + Expression target = fieldAccess.getTarget(); + if (target instanceof J.Identifier && ((J.Identifier) target).getFieldType() != null && + declaringType == ((J.Identifier) target).getFieldType().getOwner()) { + // Check return: type and matching field name + return hasMatchingTypeAndName(method, fieldAccess.getType(), fieldAccess.getSimpleName()); + } + } + return false; + } + + private static boolean hasMatchingTypeAndName(J.MethodDeclaration method, @Nullable JavaType type, String simpleName) { + if (method.getType() == type) { + String deriveGetterMethodName = deriveGetterMethodName(type, simpleName); + return method.getSimpleName().equals(deriveGetterMethodName); + } + return false; + } + + private static String deriveGetterMethodName(@Nullable JavaType type, String fieldName) { + if (type == JavaType.Primitive.Boolean) { + boolean alreadyStartsWithIs = fieldName.length() >= 3 && + fieldName.substring(0, 3).matches("is[A-Z]"); + if (alreadyStartsWithIs) { + return fieldName; + } else { + return "is" + StringUtils.capitalize(fieldName); + } + } + return "get" + StringUtils.capitalize(fieldName); + } + + static AccessLevel getAccessLevel(J.MethodDeclaration modifiers) { + if (modifiers.hasModifier(Public)) { + return PUBLIC; + } else if (modifiers.hasModifier(Protected)) { + return PROTECTED; + } else if (modifiers.hasModifier(Private)) { + return PRIVATE; + } + return PACKAGE; + } + +} diff --git a/src/main/java/org/openrewrite/java/migrate/lombok/UseLombokGetter.java b/src/main/java/org/openrewrite/java/migrate/lombok/UseLombokGetter.java new file mode 100644 index 000000000..b8de2d23f --- /dev/null +++ b/src/main/java/org/openrewrite/java/migrate/lombok/UseLombokGetter.java @@ -0,0 +1,109 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lombok; + +import lombok.AccessLevel; +import lombok.EqualsAndHashCode; +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.JavaParser; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.tree.Expression; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static java.util.Comparator.comparing; +import static lombok.AccessLevel.PUBLIC; + +@Value +@EqualsAndHashCode(callSuper = false) +public class UseLombokGetter extends Recipe { + + @Override + public String getDisplayName() { + return "Convert getter methods to annotations"; + } + + @Override + public String getDescription() { + //language=markdown + return "Convert trivial getter methods to `@Getter` annotations on their respective fields."; + } + + @Override + public Set getTags() { + return Collections.singleton("lombok"); + } + + @Override + public TreeVisitor getVisitor() { + return new JavaIsoVisitor() { + @Override + public J.@Nullable MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { + if (LombokUtils.isGetter(method)) { + Expression returnExpression = ((J.Return) method.getBody().getStatements().get(0)).getExpression(); + if (returnExpression instanceof J.Identifier && + ((J.Identifier) returnExpression).getFieldType() != null) { + doAfterVisit(new FieldAnnotator( + ((J.Identifier) returnExpression).getFieldType(), + LombokUtils.getAccessLevel(method))); + return null; + } else if (returnExpression instanceof J.FieldAccess && + ((J.FieldAccess) returnExpression).getName().getFieldType() != null) { + doAfterVisit(new FieldAnnotator( + ((J.FieldAccess) returnExpression).getName().getFieldType(), + LombokUtils.getAccessLevel(method))); + return null; + } + } + return method; + } + }; + } + + + @Value + @EqualsAndHashCode(callSuper = false) + static class FieldAnnotator extends JavaIsoVisitor { + + JavaType field; + AccessLevel accessLevel; + + @Override + public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext ctx) { + for (J.VariableDeclarations.NamedVariable variable : multiVariable.getVariables()) { + if (variable.getName().getFieldType() == field) { + maybeAddImport("lombok.Getter"); + maybeAddImport("lombok.AccessLevel"); + String suffix = accessLevel == PUBLIC ? "" : String.format("(AccessLevel.%s)", accessLevel.name()); + return JavaTemplate.builder("@Getter" + suffix) + .imports("lombok.Getter", "lombok.AccessLevel") + .javaParser(JavaParser.fromJavaVersion().classpath("lombok")) + .build().apply(getCursor(), multiVariable.getCoordinates().addAnnotation(comparing(J.Annotation::getSimpleName))); + } + } + return multiVariable; + } + } +} diff --git a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java index 9a30c3c02..2265fab7e 100644 --- a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java +++ b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java @@ -198,13 +198,13 @@ private void print(Date date) { } @Test - void localVarUsedReferencedInReturnStatement() { // not supported yet + void localVarUsedReferencedInReturnStatement() { // language=java rewriteRun( java( """ - import org.joda.time.DateTime; - import org.joda.time.DateTimeZone; + import org.joda.time.DateTime; + import org.joda.time.DateTimeZone; class A { public DateTime foo(String city) { @@ -218,6 +218,24 @@ public DateTime foo(String city) { return dt.plus(2); } } + """, + """ + import java.time.Duration; + import java.time.ZoneId; + import java.time.ZonedDateTime; + + class A { + public ZonedDateTime foo(String city) { + ZoneId dtz; + if ("london".equals(city)) { + dtz = ZoneId.of("Europe/London"); + } else { + dtz = ZoneId.of("America/New_York"); + } + ZonedDateTime dt = ZonedDateTime.now(dtz); + return dt.plus(Duration.ofMillis(2)); + } + } """ ) ); @@ -262,6 +280,105 @@ public void bar(ZonedDateTime dt) { ); } + @Test + void migrateMethodWithSafeReturnExpression() { + //language=java + rewriteRun( + java( + """ + import org.joda.time.DateTime; + import org.joda.time.Interval; + + class A { + public DateTime foo(DateTime dt) { + Interval interval = new Interval(dt, dt.plusDays(1)); + return interval.getEnd(); + } + + private static class Bar { + public void bar(DateTime dt) { + DateTime d = foo(dt); + System.out.println(d.getMillis()); + } + } + } + """, + """ + import org.threeten.extra.Interval; + + import java.time.ZoneId; + import java.time.ZonedDateTime; + + class A { + public ZonedDateTime foo(ZonedDateTime dt) { + Interval interval = Interval.of(dt.toInstant(), dt.plusDays(1).toInstant()); + return interval.getEnd().atZone(ZoneId.systemDefault()); + } + + private static class Bar { + public void bar(ZonedDateTime dt) { + ZonedDateTime d = foo(dt); + System.out.println(d.toInstant().toEpochMilli()); + } + } + } + """ + ) + ); + } + + @Test + void migrateMethodWithSafeReturnExpressionAndUnsafeParam() { + //language=java + rewriteRun( + java( + """ + import org.joda.time.DateTime; + import org.joda.time.Interval; + + class A { + public DateTime foo(DateTime dt) { + DateTime d = dt.toDateMidnight(); + DateTime d2 = DateTime.now(); + Interval interval = new Interval(d2, d2.plusDays(1)); + return interval.getEnd(); + } + + private static class Bar { + public void bar() { + DateTime d = foo(new DateTime()); + System.out.println(d.getMillis()); + } + } + } + """, + """ + import org.joda.time.DateTime; + import org.threeten.extra.Interval; + + import java.time.ZoneId; + import java.time.ZonedDateTime; + + class A { + public ZonedDateTime foo(DateTime dt) { + DateTime d = dt.toDateMidnight(); + ZonedDateTime d2 = ZonedDateTime.now(); + Interval interval = Interval.of(d2.toInstant(), d2.plusDays(1).toInstant()); + return interval.getEnd().atZone(ZoneId.systemDefault()); + } + + private static class Bar { + public void bar() { + ZonedDateTime d = foo(new DateTime()); + System.out.println(d.toInstant().toEpochMilli()); + } + } + } + """ + ) + ); + } + @Test void doNotMigrateUnsafeMethodParam() { //language=java diff --git a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java index c402db853..15c484d85 100644 --- a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java +++ b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java @@ -21,6 +21,8 @@ import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -142,7 +144,7 @@ public void foo(String city) { } @Test - void localVarUsedReferencedInReturnStatement() { // not supported yet + void localVarUsedReferencedInReturnStatement() { JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); // language=java rewriteRun( @@ -167,11 +169,7 @@ public DateTime foo(String city) { """ ) ); - // The local variable dt used in return statement. - assertEquals(2, scanner.getAcc().getUnsafeVars().size()); - for (J.VariableDeclarations.NamedVariable var : scanner.getAcc().getUnsafeVars()) { - assertTrue(var.getSimpleName().equals("dtz") || var.getSimpleName().equals("dt")); - } + assertThat(scanner.getAcc().getUnsafeVars()).isEmpty(); } @Test @@ -213,22 +211,22 @@ void detectUnsafeVarsInInitializer() { spec -> spec.recipe(toRecipe(() -> scanner)), java( """ - import org.joda.time.Interval; + import org.joda.time.Period; import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.List; class A { - public Interval interval() { - return new Interval(10, 20); + public Period period() { + return new Period(); } public void foo() { List list = Stream.of(1, 2, 3).peek(i -> { - Interval i1 = interval(); - Interval i2 = new Interval(i, 100); - if (i1 != null && !i1.contains(i2)) { - System.out.println("i1 does not contain i2"); + Period p1 = period(); + Period p2 = new Period(i, 100); + if (p1 != null && p1.plus(p2).getDays() > 10) { + System.out.println("Hello world!"); } }).toList(); } @@ -238,7 +236,7 @@ public void foo() { ); assertThat(scanner.getAcc().getUnsafeVars().stream().map(J.VariableDeclarations.NamedVariable::getSimpleName)) .hasSize(2) - .containsExactlyInAnyOrder("i1", "i2"); + .containsExactlyInAnyOrder("p1", "p2"); } @Test @@ -249,28 +247,28 @@ void detectUnsafeVarsInChainedLambdaExpressions() { spec -> spec.recipe(toRecipe(() -> scanner)), java( """ - import org.joda.time.Interval; + import org.joda.time.Period; import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.List; class A { - public Interval interval() { - return new Interval(10, 20); + public Period period() { + return new Period(); } public void foo() { List list = Stream.of(1, 2, 3).peek(i -> { - Interval i1 = interval(); - Interval i2 = new Interval(i, 100); - if (i1 != null && !i1.contains(i2)) { - System.out.println("i1 does not contain i2"); + Period p1 = period(); + Period p2 = new Period(i, 100); + if (p1 != null && p1.plus(p2).getDays() > 10) { + System.out.println("Hello world!"); } }).peek(i -> { - Interval i3 = interval(); - Interval i4 = new Interval(i, 100); - if (i3 != null && !i3.contains(i4)) { - System.out.println("i3 does not contain i4"); + Period p3 = period(); + Period p4 = new Period(i, 100); + if (p3 != null && p3.plus(p4).getDays() > 10) { + System.out.println("Hello world!"); } }).toList(); } @@ -280,6 +278,150 @@ public void foo() { ); assertThat(scanner.getAcc().getUnsafeVars().stream().map(J.VariableDeclarations.NamedVariable::getSimpleName)) .hasSize(4) - .containsExactlyInAnyOrder("i1", "i2", "i3", "i4"); + .containsExactlyInAnyOrder("p1", "p2", "p3", "p4"); + } + + @Test + void hasSafeMethods() { + JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); + // language=java + rewriteRun( + spec -> spec.recipe(toRecipe(() -> scanner)), + java( + """ + import org.joda.time.DateTime; + + class A { + private DateTime dateTime() { + DateTime dt = new DateTime(); + return dt; + } + public void print() { + System.out.println(dateTime()); + } + } + """ + ) + ); + assertThat(scanner.getAcc().getSafeMethodMap()).hasSize(1); + assertThat(scanner.getAcc().getSafeMethodMap().entrySet().stream().filter(Map.Entry::getValue).map(e -> e.getKey().toString())) + .containsExactlyInAnyOrder("A{name=dateTime,return=org.joda.time.DateTime,parameters=[]}"); + assertThat(scanner.getAcc().getUnsafeVars()).isEmpty(); + } + + @Test + void methodInvocationBeforeDeclaration() { + JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); + // language=java + rewriteRun( + spec -> spec.recipe(toRecipe(() -> scanner)), + java( + """ + import org.joda.time.DateTime; + + class A { + public void print() { + System.out.println(dateTime()); + } + private DateTime dateTime() { + DateTime dt = new DateTime(); + return dt; + } + } + """ + ) + ); + assertThat(scanner.getAcc().getSafeMethodMap()).hasSize(1); + assertThat(scanner.getAcc().getSafeMethodMap().entrySet().stream().filter(Map.Entry::getValue).map(e -> e.getKey().toString())) + .containsExactlyInAnyOrder("A{name=dateTime,return=org.joda.time.DateTime,parameters=[]}"); + assertThat(scanner.getAcc().getUnsafeVars()).isEmpty(); + } + + @Test + void safeMethodWithUnsafeParam() { + JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); + // language=java + rewriteRun( + spec -> spec.recipe(toRecipe(() -> scanner)), + java( + """ + import org.joda.time.DateTime; + import org.joda.time.Period; + + class A { + private DateTime dateTime(Period period) { + DateTime dt = new DateTime(); + System.out.println(period); + return dt; + } + public void print() { + System.out.println(dateTime(new Period())); + } + } + """ + ) + ); + assertThat(scanner.getAcc().getSafeMethodMap()).hasSize(1); + assertThat(scanner.getAcc().getSafeMethodMap().entrySet().stream().filter(Map.Entry::getValue).map(e -> e.getKey().toString())) + .containsExactlyInAnyOrder("A{name=dateTime,return=org.joda.time.DateTime,parameters=[org.joda.time.Period]}"); + assertThat(scanner.getAcc().getUnsafeVars().stream().map(J.VariableDeclarations.NamedVariable::getSimpleName)) + .containsExactlyInAnyOrder("period"); + } + + @Test + void unsafeMethodDueToUnhandledUsage() { + JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); + // language=java + rewriteRun( + spec -> spec.recipe(toRecipe(() -> scanner)), + java( + """ + import org.joda.time.DateTime; + + class A { + private DateTime dateTime() { + DateTime dt = new DateTime(); + return dt; + } + public void print() { + dateTime().toDateMidnight(); + } + } + """ + ) + ); + assertThat(scanner.getAcc().getSafeMethodMap()).hasSize(1); + assertThat(scanner.getAcc().getSafeMethodMap().entrySet().stream().filter(Map.Entry::getValue)).isEmpty(); + assertThat(scanner.getAcc().getUnsafeVars().stream().map(J.VariableDeclarations.NamedVariable::getSimpleName)) + .containsExactlyInAnyOrder("dt"); + } + + @Test + void unsafeMethodDueToIndirectUnhandledUsage() { + JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); + // language=java + rewriteRun( + spec -> spec.recipe(toRecipe(() -> scanner)), + java( + """ + import org.joda.time.DateTime; + + class A { + private DateTime dateTime() { + DateTime dt = new DateTime(); + return dt; + } + public void print() { + DateTime dt = dateTime(); + dt.toDateMidnight(); + } + } + """ + ) + ); + assertThat(scanner.getAcc().getSafeMethodMap()).hasSize(1); + assertThat(scanner.getAcc().getSafeMethodMap().entrySet().stream().filter(Map.Entry::getValue)).isEmpty(); + assertThat(scanner.getAcc().getUnsafeVars().stream().map(J.VariableDeclarations.NamedVariable::getSimpleName)) + .containsExactlyInAnyOrder("dt", "dt"); } } diff --git a/src/test/java/org/openrewrite/java/migrate/lombok/UseLombokGetterTest.java b/src/test/java/org/openrewrite/java/migrate/lombok/UseLombokGetterTest.java new file mode 100644 index 000000000..ec62019a5 --- /dev/null +++ b/src/test/java/org/openrewrite/java/migrate/lombok/UseLombokGetterTest.java @@ -0,0 +1,531 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lombok; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.java.Assertions.java; + +class UseLombokGetterTest implements RewriteTest { + + @Override + public void defaults(RecipeSpec spec) { + spec.recipe(new UseLombokGetter()); + } + + @DocumentExample + @Test + void replaceGetter() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + public int getFoo() { + return foo; + } + } + """, + """ + import lombok.Getter; + + class A { + + @Getter + int foo = 9; + } + """ + ) + ); + } + + @Test + void replaceGetterWithFieldAccess() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + public int getFoo() { + return this.foo; + } + } + """, + """ + import lombok.Getter; + + class A { + + @Getter + int foo = 9; + } + """ + ) + ); + } + + @Test + void replaceGetterWithMultiVariable() { + // Technically this adds a new public getter not there previously, but we'll tolerate it + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9, bar = 10; + + public int getFoo() { + return foo; + } + } + """, + """ + import lombok.Getter; + + class A { + + @Getter + int foo = 9, bar = 10; + } + """ + ) + ); + } + + @Test + void replacePackageGetter() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + int getFoo() { + return foo; + } + } + """, + """ + import lombok.AccessLevel; + import lombok.Getter; + + class A { + + @Getter(AccessLevel.PACKAGE) + int foo = 9; + } + """ + ) + ); + } + + @Test + void replaceProtectedGetter() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + protected int getFoo() { + return foo; + } + } + """, + """ + import lombok.AccessLevel; + import lombok.Getter; + + class A { + + @Getter(AccessLevel.PROTECTED) + int foo = 9; + } + """ + ) + ); + } + + @Test + void replacePrivateGetter() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + private int getFoo() { + return foo; + } + } + """, + """ + import lombok.AccessLevel; + import lombok.Getter; + + class A { + + @Getter(AccessLevel.PRIVATE) + int foo = 9; + } + """ + ) + ); + } + + @Test + void replaceJustTheMatchingGetter() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + int ba; + + public A() { + ba = 1; + } + + public int getFoo() { + return foo; + } + + public int getMoo() {//method name wrong + return ba; + } + } + """, + """ + import lombok.Getter; + + class A { + + @Getter + int foo = 9; + + int ba; + + public A() { + ba = 1; + } + + public int getMoo() {//method name wrong + return ba; + } + } + """ + ) + ); + } + + @Test + void noChangeWhenMethodNameDoesntMatch() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + public A() { + } + + int getfoo() {//method name wrong + return foo; + } + } + """ + ) + ); + } + + @Test + void noChangeWhenReturnTypeDoesntMatch() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + + public A() { + } + + long getFoo() { //return type wrong + return foo; + } + } + """ + ) + ); + } + + @Test + void noChangeWhenFieldIsNotReturned() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + int ba = 10; + + public A() { + } + + int getFoo() { + return 5;//does not return variable + } + } + """ + ) + ); + } + + @Test + void noChangeWhenDifferentFieldIsReturned() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + int ba = 10; + + public A() { + } + + int getFoo() { + return ba;//returns wrong variable + } + } + """ + ) + ); + } + + @Test + void noChangeWhenSideEffects1() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + int ba = 10; + + public A() { + } + + int getfoo() { + foo++;//does extra stuff + return foo; + } + } + """ + ) + ); + } + + @Test + void noChangeWhenSideEffects2() { + rewriteRun(// language=java + java( + """ + class A { + + int foo = 9; + int ba = 10; + + public A() { + } + + int getFoo() { + ba++;//does unrelated extra stuff + return foo; + } + } + """ + ) + ); + } + + @Test + void replacePrimitiveBoolean() { + rewriteRun(// language=java + java( + """ + class A { + + boolean foo = true; + + public boolean isFoo() { + return foo; + } + } + """, + """ + import lombok.Getter; + + class A { + + @Getter + boolean foo = true; + } + """ + ) + ); + } + + @Test + void replacePrimitiveBooleanStartingWithIs() { + rewriteRun(// language=java + java( + """ + class A { + + boolean isFoo = true; + + public boolean isFoo() { + return isFoo; + } + } + """, + """ + import lombok.Getter; + + class A { + + @Getter + boolean isFoo = true; + } + """ + ) + ); + } + + @Test + void noChangeWhenPrimitiveBooleanUsesGet() { + rewriteRun(// language=java + java( + """ + class A { + + boolean foo = true; + + boolean getFoo() { + return foo; + } + } + """ + ) + ); + } + + @Test + void replaceBoolean() { + rewriteRun(// language=java + java( + """ + class A { + + Boolean foo = true; + + public Boolean getFoo() { + return foo; + } + } + """, + """ + import lombok.Getter; + + class A { + + @Getter + Boolean foo = true; + } + """ + ) + ); + } + + @Test + void noChangeWhenBooleanUsesIs() { + rewriteRun(// language=java + java( + """ + class A { + + Boolean foo = true; + + Boolean isFoo() { + return foo; + } + } + """ + ) + ); + } + + @Test + void noChangeWhenBooleanUsesIs2() { + rewriteRun(// language=java + java( + """ + class A { + + Boolean isfoo = true; + + Boolean isFoo() { + return isfoo; + } + } + """ + ) + ); + } + + @Test + void noChangeNestedClassGetter() { + rewriteRun(// language=java + java( + """ + class Outer { + int foo = 9; + + class Inner { + public int getFoo() { + return foo; + } + } + } + """ + ) + ); + } +}