diff --git a/rule/flag_param.go b/rule/flag_param.go index 8156121aa..0a67f3322 100644 --- a/rule/flag_param.go +++ b/rule/flag_param.go @@ -13,58 +13,48 @@ type FlagParamRule struct{} // Apply applies the rule to given file. func (*FlagParamRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure { var failures []lint.Failure - onFailure := func(failure lint.Failure) { failures = append(failures, failure) } - w := lintFlagParamRule{onFailure: onFailure} - ast.Walk(w, file.AST) - return failures -} - -// Name returns the rule name. -func (*FlagParamRule) Name() string { - return "flag-parameter" -} - -type lintFlagParamRule struct { - onFailure func(lint.Failure) -} - -func (w lintFlagParamRule) Visit(node ast.Node) ast.Visitor { - fd, ok := node.(*ast.FuncDecl) - if !ok { - return w - } - - if fd.Body == nil { - return nil // skip whole function declaration - } + for _, decl := range file.AST.Decls { + fd, ok := decl.(*ast.FuncDecl) + isFuncWithNonEmptyBody := ok && fd.Body != nil + if !isFuncWithNonEmptyBody { + continue + } - for _, p := range fd.Type.Params.List { - t := p.Type + boolParams := map[string]struct{}{} + for _, param := range fd.Type.Params.List { + if !isIdent(param.Type, "bool") { + continue + } - id, ok := t.(*ast.Ident) - if !ok { - continue + for _, paramIdent := range param.Names { + boolParams[paramIdent.Name] = struct{}{} + } } - if id.Name != "bool" { + if len(boolParams) == 0 { continue } - cv := conditionVisitor{p.Names, fd, w} + cv := conditionVisitor{boolParams, fd, onFailure} ast.Walk(cv, fd.Body) } - return w + return failures +} + +// Name returns the rule name. +func (*FlagParamRule) Name() string { + return "flag-parameter" } type conditionVisitor struct { - ids []*ast.Ident - fd *ast.FuncDecl - linter lintFlagParamRule + idents map[string]struct{} + fd *ast.FuncDecl + onFailure func(lint.Failure) } func (w conditionVisitor) Visit(node ast.Node) ast.Visitor { @@ -73,28 +63,22 @@ func (w conditionVisitor) Visit(node ast.Node) ast.Visitor { return w } - fselect := func(n ast.Node) bool { + findUsesOfIdents := func(n ast.Node) bool { ident, ok := n.(*ast.Ident) if !ok { return false } - for _, id := range w.ids { - if ident.Name == id.Name { - return true - } - } - - return false + return w.idents[ident.Name] == struct{}{} } - uses := pick(ifStmt.Cond, fselect) + uses := pick(ifStmt.Cond, findUsesOfIdents) if len(uses) < 1 { return w } - w.linter.onFailure(lint.Failure{ + w.onFailure(lint.Failure{ Confidence: 1, Node: w.fd.Type.Params, Category: "bad practice",