source.go

  1package source // import "gotest.tools/internal/source"
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"go/ast"
  7	"go/format"
  8	"go/parser"
  9	"go/token"
 10	"os"
 11	"runtime"
 12	"strconv"
 13	"strings"
 14
 15	"github.com/pkg/errors"
 16)
 17
 18const baseStackIndex = 1
 19
 20// FormattedCallExprArg returns the argument from an ast.CallExpr at the
 21// index in the call stack. The argument is formatted using FormatNode.
 22func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
 23	args, err := CallExprArgs(stackIndex + 1)
 24	if err != nil {
 25		return "", err
 26	}
 27	return FormatNode(args[argPos])
 28}
 29
 30func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
 31	fileset := token.NewFileSet()
 32	astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
 33	if err != nil {
 34		return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
 35	}
 36
 37	node := scanToLine(fileset, astFile, lineNum)
 38	if node == nil {
 39		return nil, errors.Errorf(
 40			"failed to find an expression on line %d in %s", lineNum, filename)
 41	}
 42	return node, nil
 43}
 44
 45func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
 46	v := &scanToLineVisitor{lineNum: lineNum, fileset: fileset}
 47	ast.Walk(v, node)
 48	return v.matchedNode
 49}
 50
 51type scanToLineVisitor struct {
 52	lineNum     int
 53	matchedNode ast.Node
 54	fileset     *token.FileSet
 55}
 56
 57func (v *scanToLineVisitor) Visit(node ast.Node) ast.Visitor {
 58	if node == nil || v.matchedNode != nil {
 59		return nil
 60	}
 61	if v.nodePosition(node).Line == v.lineNum {
 62		v.matchedNode = node
 63		return nil
 64	}
 65	return v
 66}
 67
 68// In golang 1.9 the line number changed from being the line where the statement
 69// ended to the line where the statement began.
 70func (v *scanToLineVisitor) nodePosition(node ast.Node) token.Position {
 71	if goVersionBefore19 {
 72		return v.fileset.Position(node.End())
 73	}
 74	return v.fileset.Position(node.Pos())
 75}
 76
 77var goVersionBefore19 = isGOVersionBefore19()
 78
 79func isGOVersionBefore19() bool {
 80	version := runtime.Version()
 81	// not a release version
 82	if !strings.HasPrefix(version, "go") {
 83		return false
 84	}
 85	version = strings.TrimPrefix(version, "go")
 86	parts := strings.Split(version, ".")
 87	if len(parts) < 2 {
 88		return false
 89	}
 90	minor, err := strconv.ParseInt(parts[1], 10, 32)
 91	return err == nil && parts[0] == "1" && minor < 9
 92}
 93
 94func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
 95	visitor := &callExprVisitor{}
 96	ast.Walk(visitor, node)
 97	if visitor.expr == nil {
 98		return nil, errors.New("failed to find call expression")
 99	}
100	return visitor.expr.Args, nil
101}
102
// callExprVisitor is an ast.Visitor that captures the first call expression
// reached during a walk.
type callExprVisitor struct {
	// expr is the first call expression visited; nil until one is found.
	expr *ast.CallExpr
}
106
107func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
108	if v.expr != nil || node == nil {
109		return nil
110	}
111	debug("visit (%T): %s", node, debugFormatNode{node})
112
113	if callExpr, ok := node.(*ast.CallExpr); ok {
114		v.expr = callExpr
115		return nil
116	}
117	return v
118}
119
// FormatNode renders node as Go source using go/format.Node and returns the
// result as a string.
//
// A new, empty token.FileSet is passed to the printer, so position data from
// the FileSet the node was parsed with is not available to it. If formatting
// fails, whatever was already written is returned along with the error.
func FormatNode(node ast.Node) (string, error) {
	var out bytes.Buffer
	if err := format.Node(&out, token.NewFileSet(), node); err != nil {
		return out.String(), err
	}
	return out.String(), nil
}
126
127// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
128// the index in the call stack.
129func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
130	_, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
131	if !ok {
132		return nil, errors.New("failed to get call stack")
133	}
134	debug("call stack position: %s:%d", filename, lineNum)
135
136	node, err := getNodeAtLine(filename, lineNum)
137	if err != nil {
138		return nil, err
139	}
140	debug("found node (%T): %s", node, debugFormatNode{node})
141
142	return getCallExprArgs(node)
143}
144
// debugEnabled is set at package initialization when the GOTESTYOURSELF_DEBUG
// environment variable is non-empty.
var debugEnabled = len(os.Getenv("GOTESTYOURSELF_DEBUG")) > 0

// debug writes a "DEBUG: "-prefixed, newline-terminated message to stderr
// when debugEnabled is set. format and args follow fmt.Printf conventions.
func debug(format string, args ...interface{}) {
	if !debugEnabled {
		return
	}
	fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
}
152
153type debugFormatNode struct {
154	ast.Node
155}
156
157func (n debugFormatNode) String() string {
158	out, err := FormatNode(n.Node)
159	if err != nil {
160		return fmt.Sprintf("failed to format %s: %s", n.Node, err)
161	}
162	return out
163}