diff --git a/rule/unchecked-type-assertion.go b/rule/unchecked-type-assertion.go index f5801083b..10b9ff060 100644 --- a/rule/unchecked-type-assertion.go +++ b/rule/unchecked-type-assertion.go @@ -96,8 +96,34 @@ func (w *lintUnchekedTypeAssertion) requireNoTypeAssert(expr ast.Expr) { } } +func (w *lintUnchekedTypeAssertion) handleIfStmt(n *ast.IfStmt) { + ifCondition, ok := n.Cond.(*ast.BinaryExpr) + if !ok { + return + } + + w.requireNoTypeAssert(ifCondition.X) + w.requireNoTypeAssert(ifCondition.Y) +} + +func (w *lintUnchekedTypeAssertion) requireBinaryExpressionWithoutTypeAssertion(expr ast.Expr) { + binaryExpr, ok := expr.(*ast.BinaryExpr) + if ok { + w.requireNoTypeAssert(binaryExpr.X) + w.requireNoTypeAssert(binaryExpr.Y) + } +} + +func (w *lintUnchekedTypeAssertion) handleCaseClause(n *ast.CaseClause) { + for _, expr := range n.List { + w.requireNoTypeAssert(expr) + w.requireBinaryExpressionWithoutTypeAssertion(expr) + } +} + func (w *lintUnchekedTypeAssertion) handleSwitch(n *ast.SwitchStmt) { w.requireNoTypeAssert(n.Tag) + w.requireBinaryExpressionWithoutTypeAssertion(n.Tag) } func (w *lintUnchekedTypeAssertion) handleAssignment(n *ast.AssignStmt) { @@ -144,6 +170,10 @@ func (w *lintUnchekedTypeAssertion) Visit(node ast.Node) ast.Visitor { w.handleReturn(n) case *ast.AssignStmt: w.handleAssignment(n) + case *ast.IfStmt: + w.handleIfStmt(n) + case *ast.CaseClause: + w.handleCaseClause(n) } return w diff --git a/testdata/unchecked-type-assertion.go b/testdata/unchecked-type-assertion.go index b52ed176f..f14b624ed 100644 --- a/testdata/unchecked-type-assertion.go +++ b/testdata/unchecked-type-assertion.go @@ -52,7 +52,51 @@ func handleTypeSwitchWithAssignment() { } } -func handleTypeSwitchReturn() { - // Should not be a lint - return foo.(type) +func handleTypeComparison() { + if foo.(int) == 1 { // MATCH /type cast result is unchecked in foo.(int) - type assertion will panic if not matched/ + return + } +} + +func handleTypeComparisonReverse() { + if foo.(int) == 1 { // MATCH /type cast result is unchecked in foo.(int) - type assertion will panic if not matched/ + return + } +} + +func handleTypeAssignmentComparison() { + var value any + value = 42 // int + + if v := value.(int); v == 42 { // MATCH /type cast result is unchecked in value.(int) - type assertion will panic if not matched/ + fmt.Printf("Value is an integer: %d\n", v) + } +} + +func handleSwitchComparison() { + switch foo.(int) == 1 { // MATCH /type cast result is unchecked in foo.(int) - type assertion will panic if not matched/ + case true: + case false: + } +} + +func handleSwitchComparisonReverse() { + switch 1 == foo.(int) { // MATCH /type cast result is unchecked in foo.(int) - type assertion will panic if not matched/ + case true: + case false: + } +} + +func handleInnerSwitchAssertion() { + switch { + case foo.(int) == 1: // MATCH /type cast result is unchecked in foo.(int) - type assertion will panic if not matched/ + case bar.(int) == 1: // MATCH /type cast result is unchecked in bar.(int) - type assertion will panic if not matched/ + } +} + +func handleInnerSwitchAssertionReverse() { + switch { + case 1 == foo.(int): // MATCH /type cast result is unchecked in foo.(int) - type assertion will panic if not matched/ + case 1 == bar.(int): // MATCH /type cast result is unchecked in bar.(int) - type assertion will panic if not matched/ + } }