1package source // import "gotest.tools/internal/source"
2
3import (
4 "bytes"
5 "fmt"
6 "go/ast"
7 "go/format"
8 "go/parser"
9 "go/token"
10 "os"
11 "runtime"
12 "strconv"
13 "strings"
14
15 "github.com/pkg/errors"
16)
17
// baseStackIndex is added to the caller-supplied stack index in CallExprArgs
// before calling runtime.Caller, so that a stack index of 0 skips
// CallExprArgs' own frame.
const (
	baseStackIndex = 1
)
19
20// FormattedCallExprArg returns the argument from an ast.CallExpr at the
21// index in the call stack. The argument is formatted using FormatNode.
22func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
23 args, err := CallExprArgs(stackIndex + 1)
24 if err != nil {
25 return "", err
26 }
27 return FormatNode(args[argPos])
28}
29
30func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
31 fileset := token.NewFileSet()
32 astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
33 if err != nil {
34 return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
35 }
36
37 node := scanToLine(fileset, astFile, lineNum)
38 if node == nil {
39 return nil, errors.Errorf(
40 "failed to find an expression on line %d in %s", lineNum, filename)
41 }
42 return node, nil
43}
44
45func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
46 v := &scanToLineVisitor{lineNum: lineNum, fileset: fileset}
47 ast.Walk(v, node)
48 return v.matchedNode
49}
50
51type scanToLineVisitor struct {
52 lineNum int
53 matchedNode ast.Node
54 fileset *token.FileSet
55}
56
57func (v *scanToLineVisitor) Visit(node ast.Node) ast.Visitor {
58 if node == nil || v.matchedNode != nil {
59 return nil
60 }
61 if v.nodePosition(node).Line == v.lineNum {
62 v.matchedNode = node
63 return nil
64 }
65 return v
66}
67
68// In golang 1.9 the line number changed from being the line where the statement
69// ended to the line where the statement began.
70func (v *scanToLineVisitor) nodePosition(node ast.Node) token.Position {
71 if goVersionBefore19 {
72 return v.fileset.Position(node.End())
73 }
74 return v.fileset.Position(node.Pos())
75}
76
// goVersionBefore19 records, once at package initialization, whether the
// running toolchain is a Go release older than 1.9.
var goVersionBefore19 = isGOVersionBefore19()

// isGOVersionBefore19 reports whether runtime.Version names a Go release
// older than 1.9. Non-release builds (versions without a "go" prefix) report
// false.
func isGOVersionBefore19() bool {
	return isVersionBefore19(runtime.Version())
}

// isVersionBefore19 reports whether version, a string in the format returned
// by runtime.Version (e.g. "go1.8.3", "go1.8rc1", "go1.10"), names a Go
// release older than 1.9.
//
// Only the leading digits of the major and minor components are considered,
// so pre-release suffixes such as "rc1" or "beta2" do not prevent a match.
// The bare "go1" release is treated as Go 1.0.
func isVersionBefore19(version string) bool {
	// not a release version
	if !strings.HasPrefix(version, "go") {
		return false
	}
	// leadingDigits returns the longest prefix of s made of ASCII digits.
	leadingDigits := func(s string) string {
		end := 0
		for end < len(s) && s[end] >= '0' && s[end] <= '9' {
			end++
		}
		return s[:end]
	}

	parts := strings.SplitN(strings.TrimPrefix(version, "go"), ".", 3)
	if leadingDigits(parts[0]) != "1" {
		return false
	}
	if len(parts) < 2 {
		// "go1" (Go 1.0 and its pre-releases) had no minor component.
		return true
	}
	minor, err := strconv.Atoi(leadingDigits(parts[1]))
	return err == nil && minor < 9
}
93
94func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
95 visitor := &callExprVisitor{}
96 ast.Walk(visitor, node)
97 if visitor.expr == nil {
98 return nil, errors.New("failed to find call expression")
99 }
100 return visitor.expr.Args, nil
101}
102
103type callExprVisitor struct {
104 expr *ast.CallExpr
105}
106
107func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
108 if v.expr != nil || node == nil {
109 return nil
110 }
111 debug("visit (%T): %s", node, debugFormatNode{node})
112
113 if callExpr, ok := node.(*ast.CallExpr); ok {
114 v.expr = callExpr
115 return nil
116 }
117 return v
118}
119
// FormatNode renders node as Go source using go/format.Node and returns the
// result as a string. A fresh, empty token.FileSet is passed to format.Node
// rather than the one node was parsed with. Whatever output was produced is
// returned together with any error from format.Node.
func FormatNode(node ast.Node) (string, error) {
	var out bytes.Buffer
	if err := format.Node(&out, token.NewFileSet(), node); err != nil {
		return out.String(), err
	}
	return out.String(), nil
}
126
127// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
128// the index in the call stack.
129func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
130 _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
131 if !ok {
132 return nil, errors.New("failed to get call stack")
133 }
134 debug("call stack position: %s:%d", filename, lineNum)
135
136 node, err := getNodeAtLine(filename, lineNum)
137 if err != nil {
138 return nil, err
139 }
140 debug("found node (%T): %s", node, debugFormatNode{node})
141
142 return getCallExprArgs(node)
143}
144
// debugEnabled is true when the GOTESTYOURSELF_DEBUG environment variable is
// set to a non-empty value at package initialization.
var debugEnabled = os.Getenv("GOTESTYOURSELF_DEBUG") != ""

// debug writes a "DEBUG: "-prefixed, newline-terminated message, formatted
// like fmt.Printf, to os.Stderr. It does nothing unless debugEnabled is set.
func debug(format string, args ...interface{}) {
	if !debugEnabled {
		return
	}
	fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
}
152
153type debugFormatNode struct {
154 ast.Node
155}
156
157func (n debugFormatNode) String() string {
158 out, err := FormatNode(n.Node)
159 if err != nil {
160 return fmt.Sprintf("failed to format %s: %s", n.Node, err)
161 }
162 return out
163}