diff --git a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java index a46a741bf..643620351 100644 --- a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java +++ b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java @@ -63,7 +63,7 @@ public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) { @Override public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx) { if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) { - return variable; + return (NamedVariable) super.visitVariable(variable, ctx); } // TODO: handle class variables if (isClassVar(variable)) { diff --git a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java index c99fb43bb..d396f3452 100644 --- a/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java +++ b/src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java @@ -249,6 +249,9 @@ private boolean isJodaVarRef(@Nullable Expression expr) { if (expr instanceof J.Identifier) { return ((J.Identifier) expr).getFieldType() != null; } + if (expr instanceof MethodCall) { + return expr.getType().isAssignableFrom(JODA_CLASS_PATTERN); + } return false; } diff --git a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java index 927ce2d3d..96873be82 100644 --- a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java +++ b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeRecipeTest.java @@ -285,4 +285,30 @@ public void bar(DateTime dt) { ) ); } + + @Test + void dontMigrateMethodInvocationIfSelectExprIsNotMigrated() { + //language=java + rewriteRun( + java( + """ + import org.joda.time.Interval; + + class A { + private Query query = new Query(); + public void foo() { + query.interval().getEndMillis(); + } + static class Query { + private Interval interval; + + public Interval interval() { + return interval; + } + } + } + """ + ) + ); + } } diff --git a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java index ca7204814..b5acde58a 100644 --- a/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java +++ b/src/test/java/org/openrewrite/java/migrate/joda/JodaTimeScannerTest.java @@ -21,6 +21,7 @@ import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.openrewrite.java.Assertions.java; @@ -203,4 +204,82 @@ public void bar(DateTime dt) { assertEquals("dt", var.getSimpleName()); } } + + @Test + void detectUnsafeVarsInInitializer() { + JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); + // language=java + rewriteRun( + spec -> spec.recipe(toRecipe(() -> scanner)), + java( + """ + import org.joda.time.Interval; + import java.util.stream.Collectors; + import java.util.stream.Stream; + import java.util.List; + + class A { + public Interval interval() { + return new Interval(10, 20); + } + + public void foo() { + List list = Stream.of(1, 2, 3).peek(i -> { + Interval i1 = interval(); + Interval i2 = new Interval(i, 100); + if (i1 != null && !i1.contains(i2)) { + System.out.println("i1 does not contain i2"); + } + }).toList(); + } + } + """ + ) + ); + assertThat(scanner.getAcc().getUnsafeVars().stream().map(J.VariableDeclarations.NamedVariable::getSimpleName)) + .hasSize(2) + .containsExactlyInAnyOrder("i1", "i2"); + } + + @Test + void detectUnsafeVarsInChainedLambdaExpressions() { + JodaTimeScanner scanner = new JodaTimeScanner(new JodaTimeRecipe.Accumulator()); + // language=java + rewriteRun( + spec -> spec.recipe(toRecipe(() -> scanner)), + java( + """ + import org.joda.time.Interval; + import java.util.stream.Collectors; + import java.util.stream.Stream; + import java.util.List; + + class A { + public Interval interval() { + return new Interval(10, 20); + } + + public void foo() { + List list = Stream.of(1, 2, 3).peek(i -> { + Interval i1 = interval(); + Interval i2 = new Interval(i, 100); + if (i1 != null && !i1.contains(i2)) { + System.out.println("i1 does not contain i2"); + } + }).peek(i -> { + Interval i3 = interval(); + Interval i4 = new Interval(i, 100); + if (i3 != null && !i3.contains(i4)) { + System.out.println("i3 does not contain i4"); + } + }).toList(); + } + } + """ + ) + ); + assertThat(scanner.getAcc().getUnsafeVars().stream().map(J.VariableDeclarations.NamedVariable::getSimpleName)) + .hasSize(4) + .containsExactlyInAnyOrder("i1", "i2", "i3", "i4"); + } }