Skip to content

Commit 609e200

Browse files
Release v0.3.0
- Fixed default configuration - Improved error type resolution - Added suggested fixes
1 parent 24ec80b commit 609e200

File tree

4 files changed

+167
-32
lines changed

4 files changed

+167
-32
lines changed

README.md

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ linters-settings:
3838
wrapperFunctions:
3939
- pkg: github.com/pkg/errors
4040
names: [ New, Errorf, Wrap, Wrapf, WithStack ]
41+
replaceWith: WithMessage
42+
replaceWithFormat: WithMessagef
4143
cleanFunctions:
4244
- pkg: errors
4345
names: [ New ]
44-
- pkg: fmt
45-
names: [ Errorf ]
4646
- pkg: github.com/pkg/errors
4747
names: [ WithMessage, WithMessagef ]
4848
@@ -63,24 +63,15 @@ You can configure ErrStack using the `.errstack.yaml` file in your project root,
6363
# If you want to ignore some functions, simply don't add them to the list.
6464
wrapperFunctions:
6565
- pkg: github.com/pkg/errors
66-
names:
67-
- New
68-
- Errorf
69-
- Wrap
70-
- Wrapf
71-
- WithStack
66+
names: [ New, Errorf, Wrap, Wrapf, WithStack ]
67+
replaceWith: WithMessage # Optional. Attempts to replace errors.Wrap like functions with errors.WithMessage.
68+
replaceWithFormat: WithMessagef # Optional. Attempts to replace errors.Wrapf like functions with errors.WithMessagef.
7269
# List of functions that are considered to clean errors without stacktrace.
7370
cleanFunctions:
74-
- pkg: errors
75-
names:
76-
- New
77-
- pkg: fmt
78-
names:
79-
- Errorf
80-
- pkg: github.com/pkg/errors
81-
names:
82-
- WithMessage
83-
- WithMessagef
71+
- pkg: errors
72+
names: [ New ]
73+
- pkg: github.com/pkg/errors
74+
names: [ WithMessage, WithMessagef ]
8475
```
8576

8677
## Usage

internal/config/config.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@ package config
22

33
var (
44
DefaultWrapperFunctions = []PkgFunctions{
5-
{Pkg: "github.com/pkg/errors", Names: []string{
6-
"New", "Errorf", "Wrap", "Wrapf", "WithStack",
7-
}},
5+
{
6+
Pkg: "github.com/pkg/errors",
7+
Names: []string{
8+
"New", "Errorf", "Wrap", "Wrapf", "WithStack",
9+
},
10+
ReplaceWith: "WithMessage",
11+
ReplaceWithFormat: "WithMessagef",
12+
},
813
}
914
DefaultCleanFunctions = []PkgFunctions{
1015
{Pkg: "github.com/pkg/errors", Names: []string{
@@ -13,9 +18,6 @@ var (
1318
{Pkg: "errors", Names: []string{
1419
"New", "Wrapf", "WithStack",
1520
}},
16-
{Pkg: "fmt", Names: []string{
17-
"Errorf",
18-
}},
1921
}
2022
)
2123

internal/config/pkg_functions.go

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
package config
22

3-
import "slices"
3+
import (
4+
"slices"
5+
"strings"
6+
)
47

58
type PkgFunctions struct {
6-
Pkg string `mapstructure:"pkg" yaml:"pkg"`
7-
Names []string `mapstructure:"names" yaml:"names"`
9+
Pkg string `mapstructure:"pkg" yaml:"pkg"`
10+
Names []string `mapstructure:"names" yaml:"names"`
11+
ReplaceWith string `mapstructure:"replaceWith" yaml:"replaceWith"`
12+
ReplaceWithFormat string `mapstructure:"replaceWithFormat" yaml:"replaceWithFormat"`
813
}
914

1015
type PkgsFunctions []PkgFunctions
@@ -19,3 +24,31 @@ func (pkgFunctions PkgsFunctions) Match(pkg, name string) bool {
1924

2025
return false
2126
}
27+
28+
// ReplaceWith returns new formatted node with replaced function name.
29+
func (pkgFunctions PkgsFunctions) ReplaceWith(pkg, name, text string) string {
30+
for _, item := range pkgFunctions {
31+
if item.Pkg == pkg && slices.Contains(item.Names, name) {
32+
if item.ReplaceWith == "" {
33+
return ""
34+
}
35+
return strings.Replace(text, name, item.ReplaceWith, 1)
36+
}
37+
}
38+
39+
return ""
40+
}
41+
42+
// ReplaceWithFunction returns new formatted node with replaced function name.
43+
func (pkgFunctions PkgsFunctions) ReplaceWithFunction(pkg, name, text string) string {
44+
for _, item := range pkgFunctions {
45+
if item.Pkg == pkg && slices.Contains(item.Names, name) {
46+
if item.ReplaceWithFormat == "" {
47+
return ""
48+
}
49+
return strings.Replace(text, name, item.ReplaceWithFormat, 1)
50+
}
51+
}
52+
53+
return ""
54+
}

internal/passes/errstack/pass.go

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package errstack
22

33
import (
4+
"fmt"
45
"go/ast"
56
"go/token"
7+
"go/types"
68
"reflect"
79
"slices"
810

@@ -197,6 +199,9 @@ func (res *Result) analyzeOriginalFunctionBlock(
197199
return
198200
}
199201
info := model.NewInfo(pass)
202+
matchWrapping := res.conf.WrapperFunctions.Match
203+
replaceWith := res.conf.WrapperFunctions.ReplaceWith
204+
replaceWithFunction := res.conf.WrapperFunctions.ReplaceWithFunction
200205

201206
visited[block] = true
202207
log.Log("Visiting block %v\n", block)
@@ -210,7 +215,7 @@ func (res *Result) analyzeOriginalFunctionBlock(
210215
switch node := n.(type) {
211216
case *ast.CallExpr:
212217
fn := res.TryAddCallExpr(info, cfgs, node)
213-
if fn == nil || !fn.IsWrapping {
218+
if fn == nil || !matchWrapping(fn.Pkg, fn.Name) {
214219
return true
215220
}
216221
var wrapping bool
@@ -223,7 +228,55 @@ func (res *Result) analyzeOriginalFunctionBlock(
223228
if wrapping {
224229
fn.IsWrapping = true
225230
log.Log("Node unnecessarily wraps error with stacktrace %s\n", info.FormatNode(node))
226-
pass.Reportf(node.Pos(), "%s call unnecessarily wraps error with stacktrace. Replace with errors.WithMessage() or fmt.Errorf()", fn.Name)
231+
errorArgument := res.getErrorArgument(cfgs, info, node)
232+
var fixes []analysis.SuggestedFix
233+
if errorArgument != nil {
234+
if len(node.Args) == 1 {
235+
fixes = []analysis.SuggestedFix{
236+
{
237+
Message: "Remove unnecessary error wrapping",
238+
TextEdits: []analysis.TextEdit{
239+
{
240+
Pos: node.Pos(),
241+
End: node.End(),
242+
NewText: []byte(info.FormatNode(errorArgument)),
243+
},
244+
},
245+
},
246+
}
247+
} else {
248+
message := "Replace unnecessary error wrapping"
249+
newText := info.FormatNode(node)
250+
if len(node.Args) == 2 {
251+
newText = replaceWith(fn.Pkg, fn.Name, newText)
252+
} else {
253+
newText = replaceWithFunction(fn.Pkg, fn.Name, newText)
254+
}
255+
if newText != "" {
256+
fixes = []analysis.SuggestedFix{
257+
{
258+
Message: message,
259+
TextEdits: []analysis.TextEdit{
260+
{
261+
Pos: node.Pos(),
262+
End: node.End(),
263+
NewText: []byte(newText),
264+
},
265+
},
266+
},
267+
}
268+
}
269+
}
270+
}
271+
pass.Report(analysis.Diagnostic{
272+
Pos: node.Pos(),
273+
End: node.End(),
274+
Category: "",
275+
Message: fmt.Sprintf("%s call unnecessarily wraps error with stacktrace. Replace with errors.WithMessage() or fmt.Errorf()", fn.Name),
276+
URL: "",
277+
SuggestedFixes: fixes,
278+
Related: nil,
279+
})
227280
}
228281
return true
229282
}
@@ -243,7 +296,7 @@ func (res *Result) analyzeOriginalFunctionBlock(
243296
for i, expr := range assignStmt.Lhs {
244297
if id, idOk := expr.(*ast.Ident); idOk && id != nil {
245298
obj := info.Types.ObjectOf(id)
246-
if obj == nil || obj.Type().String() != "error" {
299+
if !isObjectError(obj) {
247300
continue
248301
}
249302
objPos := info.Fset.Position(obj.Pos())
@@ -336,8 +389,8 @@ func (res *Result) analyzeCallStack(
336389
case *ast.Ident:
337390
log.Log("Ident %s\n", info.FormatNode(node))
338391
if obj := info.Types.ObjectOf(node); obj != nil {
339-
log.Log("Ident Object %s\n", obj.Type().String())
340-
if obj.Type().String() == "error" {
392+
log.Log("Ident Object error\n")
393+
if isObjectError(obj) {
341394
log.Log("Ident Object is error\n")
342395
if variables[info.Fset.Position(obj.Pos())] {
343396
log.Log("Ident Object is error and variables[%t]\n", variables[info.Fset.Position(obj.Pos())])
@@ -355,3 +408,59 @@ func (res *Result) analyzeCallStack(
355408
}
356409
return nil
357410
}
411+
412+
func (res *Result) getErrorArgument(cfgs *ctrlflow.CFGs, info *model.Info, call *ast.CallExpr) ast.Expr {
413+
if len(call.Args) == 0 {
414+
return nil
415+
}
416+
for _, rootArg := range call.Args {
417+
untypedArg := rootArg
418+
for {
419+
switch arg := untypedArg.(type) {
420+
case *ast.Ident:
421+
obj := info.Types.ObjectOf(arg)
422+
if isObjectError(obj) {
423+
return rootArg
424+
}
425+
case *ast.CallExpr:
426+
fn := res.TryAddCallExpr(info, cfgs, arg)
427+
if fn != nil {
428+
return rootArg
429+
}
430+
case *ast.StarExpr:
431+
untypedArg = arg.X
432+
case *ast.ParenExpr:
433+
untypedArg = arg.X
434+
case *ast.SelectorExpr:
435+
untypedArg = arg.Sel
436+
case *ast.IndexExpr:
437+
untypedArg = arg.X
438+
default:
439+
break
440+
}
441+
}
442+
}
443+
444+
return nil
445+
}
446+
447+
func isObjectError(obj types.Object) bool {
448+
if obj == nil {
449+
return false
450+
}
451+
452+
t := obj.Type()
453+
var underlying types.Type
454+
for t != nil {
455+
if t.String() == "error" {
456+
return true
457+
}
458+
underlying = t.Underlying()
459+
if underlying == t {
460+
break
461+
}
462+
t = underlying
463+
}
464+
465+
return false
466+
}

0 commit comments

Comments
 (0)