11package errstack
22
33import (
4+ "fmt"
45 "go/ast"
56 "go/token"
7+ "go/types"
68 "reflect"
79 "slices"
810
@@ -197,6 +199,9 @@ func (res *Result) analyzeOriginalFunctionBlock(
197199 return
198200 }
199201 info := model .NewInfo (pass )
202+ matchWrapping := res .conf .WrapperFunctions .Match
203+ replaceWith := res .conf .WrapperFunctions .ReplaceWith
204+ replaceWithFunction := res .conf .WrapperFunctions .ReplaceWithFunction
200205
201206 visited [block ] = true
202207 log .Log ("Visiting block %v\n " , block )
@@ -210,7 +215,7 @@ func (res *Result) analyzeOriginalFunctionBlock(
210215 switch node := n .(type ) {
211216 case * ast.CallExpr :
212217 fn := res .TryAddCallExpr (info , cfgs , node )
213- if fn == nil || ! fn .IsWrapping {
218+ if fn == nil || ! matchWrapping ( fn .Pkg , fn . Name ) {
214219 return true
215220 }
216221 var wrapping bool
@@ -223,7 +228,55 @@ func (res *Result) analyzeOriginalFunctionBlock(
223228 if wrapping {
224229 fn .IsWrapping = true
225230 log .Log ("Node unnecessarily wraps error with stacktrace %s\n " , info .FormatNode (node ))
226- pass .Reportf (node .Pos (), "%s call unnecessarily wraps error with stacktrace. Replace with errors.WithMessage() or fmt.Errorf()" , fn .Name )
231+ errorArgument := res .getErrorArgument (cfgs , info , node )
232+ var fixes []analysis.SuggestedFix
233+ if errorArgument != nil {
234+ if len (node .Args ) == 1 {
235+ fixes = []analysis.SuggestedFix {
236+ {
237+ Message : "Remove unnecessary error wrapping" ,
238+ TextEdits : []analysis.TextEdit {
239+ {
240+ Pos : node .Pos (),
241+ End : node .End (),
242+ NewText : []byte (info .FormatNode (errorArgument )),
243+ },
244+ },
245+ },
246+ }
247+ } else {
248+ message := "Replace unnecessary error wrapping"
249+ newText := info .FormatNode (node )
250+ if len (node .Args ) == 2 {
251+ newText = replaceWith (fn .Pkg , fn .Name , newText )
252+ } else {
253+ newText = replaceWithFunction (fn .Pkg , fn .Name , newText )
254+ }
255+ if newText != "" {
256+ fixes = []analysis.SuggestedFix {
257+ {
258+ Message : message ,
259+ TextEdits : []analysis.TextEdit {
260+ {
261+ Pos : node .Pos (),
262+ End : node .End (),
263+ NewText : []byte (newText ),
264+ },
265+ },
266+ },
267+ }
268+ }
269+ }
270+ }
271+ pass .Report (analysis.Diagnostic {
272+ Pos : node .Pos (),
273+ End : node .End (),
274+ Category : "" ,
275+ Message : fmt .Sprintf ("%s call unnecessarily wraps error with stacktrace. Replace with errors.WithMessage() or fmt.Errorf()" , fn .Name ),
276+ URL : "" ,
277+ SuggestedFixes : fixes ,
278+ Related : nil ,
279+ })
227280 }
228281 return true
229282 }
@@ -243,7 +296,7 @@ func (res *Result) analyzeOriginalFunctionBlock(
243296 for i , expr := range assignStmt .Lhs {
244297 if id , idOk := expr .(* ast.Ident ); idOk && id != nil {
245298 obj := info .Types .ObjectOf (id )
246- if obj == nil || obj . Type (). String () != "error" {
299+ if ! isObjectError ( obj ) {
247300 continue
248301 }
249302 objPos := info .Fset .Position (obj .Pos ())
@@ -336,8 +389,8 @@ func (res *Result) analyzeCallStack(
336389 case * ast.Ident :
337390 log .Log ("Ident %s\n " , info .FormatNode (node ))
338391 if obj := info .Types .ObjectOf (node ); obj != nil {
339- log .Log ("Ident Object %s \n " , obj . Type (). String () )
340- if obj . Type (). String () == "error" {
392+ log .Log ("Ident Object error \n " )
393+ if isObjectError ( obj ) {
341394 log .Log ("Ident Object is error\n " )
342395 if variables [info .Fset .Position (obj .Pos ())] {
343396 log .Log ("Ident Object is error and variables[%t]\n " , variables [info .Fset .Position (obj .Pos ())])
@@ -355,3 +408,59 @@ func (res *Result) analyzeCallStack(
355408 }
356409 return nil
357410}
411+
412+ func (res * Result ) getErrorArgument (cfgs * ctrlflow.CFGs , info * model.Info , call * ast.CallExpr ) ast.Expr {
413+ if len (call .Args ) == 0 {
414+ return nil
415+ }
416+ for _ , rootArg := range call .Args {
417+ untypedArg := rootArg
418+ for {
419+ switch arg := untypedArg .(type ) {
420+ case * ast.Ident :
421+ obj := info .Types .ObjectOf (arg )
422+ if isObjectError (obj ) {
423+ return rootArg
424+ }
425+ case * ast.CallExpr :
426+ fn := res .TryAddCallExpr (info , cfgs , arg )
427+ if fn != nil {
428+ return rootArg
429+ }
430+ case * ast.StarExpr :
431+ untypedArg = arg .X
432+ case * ast.ParenExpr :
433+ untypedArg = arg .X
434+ case * ast.SelectorExpr :
435+ untypedArg = arg .Sel
436+ case * ast.IndexExpr :
437+ untypedArg = arg .X
438+ default :
439+ break
440+ }
441+ }
442+ }
443+
444+ return nil
445+ }
446+
447+ func isObjectError (obj types.Object ) bool {
448+ if obj == nil {
449+ return false
450+ }
451+
452+ t := obj .Type ()
453+ var underlying types.Type
454+ for t != nil {
455+ if t .String () == "error" {
456+ return true
457+ }
458+ underlying = t .Underlying ()
459+ if underlying == t {
460+ break
461+ }
462+ t = underlying
463+ }
464+
465+ return false
466+ }
0 commit comments