diff --git a/check_api/src/main/java/com/google/errorprone/util/ASTHelpers.java b/check_api/src/main/java/com/google/errorprone/util/ASTHelpers.java index 5be3fc1bd27..d8c1dffd099 100644 --- a/check_api/src/main/java/com/google/errorprone/util/ASTHelpers.java +++ b/check_api/src/main/java/com/google/errorprone/util/ASTHelpers.java @@ -1724,13 +1724,15 @@ public static TargetType targetType(VisitorState state) { Type type = new TargetTypeVisitor(current, state, parent).visit(parent.getLeaf(), null); if (type == null) { - if (CONSTANT_CASE_LABEL_TREE != null + Tree actualTree = null; + if (YIELD_TREE != null && YIELD_TREE.isAssignableFrom(parent.getLeaf().getClass())) { + actualTree = parent.getParentPath().getParentPath().getParentPath().getLeaf(); + } else if (CONSTANT_CASE_LABEL_TREE != null && CONSTANT_CASE_LABEL_TREE.isAssignableFrom(parent.getLeaf().getClass())) { - type = - getType( - TargetTypeVisitor.getSwitchExpression( - parent.getParentPath().getParentPath().getLeaf())); + actualTree = parent.getParentPath().getParentPath().getLeaf(); } + + type = getType(TargetTypeVisitor.getSwitchExpression(actualTree)); if (type == null) { return null; } @@ -1739,6 +1741,7 @@ public static TargetType targetType(VisitorState state) { } @Nullable private static final Class CONSTANT_CASE_LABEL_TREE = constantCaseLabelTree(); + @Nullable private static final Class YIELD_TREE = yieldTree(); @Nullable private static Class constantCaseLabelTree() { @@ -1749,6 +1752,15 @@ private static Class constantCaseLabelTree() { } } + @Nullable + private static Class yieldTree() { + try { + return Class.forName("com.sun.source.tree.YieldTree"); + } catch (ClassNotFoundException e) { + return null; + } + } + private static boolean canHaveTargetType(Tree tree) { // Anything that isn't an expression can't have a target type. if (!(tree instanceof ExpressionTree)) { @@ -1831,7 +1843,11 @@ public Type visitCase(CaseTree tree, Void unused) { } @Nullable - private static ExpressionTree getSwitchExpression(Tree tree) { + private static ExpressionTree getSwitchExpression(@Nullable Tree tree) { + if (tree == null) { + return null; + } + if (tree instanceof SwitchTree) { return ((SwitchTree) tree).getExpression(); } diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/threadsafety/ImmutableCheckerTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/threadsafety/ImmutableCheckerTest.java index 43c2e09ce5f..801763b502a 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/threadsafety/ImmutableCheckerTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/threadsafety/ImmutableCheckerTest.java @@ -2993,4 +2993,62 @@ public void switchExpressionsResultingInGenericTypes_doesNotThrow() { "}") .doTest(); } + + @Test + public void switchExpressionsYieldStatement_doesNotThrow() { + assumeTrue(RuntimeVersion.isAtLeast14()); + compilationHelper + .addSourceLines( + "Test.java", + "import java.util.function.Supplier;", + "class Test {", + " String test(String mode) {", + " return switch (mode) {", + " case \"random\" -> {", + " yield \"foo\";", + " }", + " default -> throw new IllegalArgumentException();", + " };", + " }", + "}") + .doTest(); + } + + @Test + public void switchExpressionsMethodReference_doesNotThrow() { + assumeTrue(RuntimeVersion.isAtLeast14()); + compilationHelper + .addSourceLines( + "Test.java", + "import java.util.function.Supplier;", + "class Test {", + " Supplier test(String mode) {", + " return switch (mode) {", + " case \"random\" -> Math::random;", + " default -> throw new IllegalArgumentException();", + " };", + " }", + "}") + .doTest(); + } + + @Test + public void switchExpressionsYieldStatementMethodReference_doesNotThrow() { + assumeTrue(RuntimeVersion.isAtLeast14()); + compilationHelper + .addSourceLines( + "Test.java", + "import java.util.function.Supplier;", + "class Test {", + " Supplier test(String mode) {", + " return switch (mode) {", + " case \"random\" -> {", + " yield Math::random;", + " }", + " default -> throw new IllegalArgumentException();", + " };", + " }", + "}") + .doTest(); + } }