1package assert
2
3import (
4 "fmt"
5 "go/ast"
6
7 "gotest.tools/assert/cmp"
8 "gotest.tools/internal/format"
9 "gotest.tools/internal/source"
10)
11
12func runComparison(
13 t TestingT,
14 argSelector argSelector,
15 f cmp.Comparison,
16 msgAndArgs ...interface{},
17) bool {
18 if ht, ok := t.(helperT); ok {
19 ht.Helper()
20 }
21 result := f()
22 if result.Success() {
23 return true
24 }
25
26 var message string
27 switch typed := result.(type) {
28 case resultWithComparisonArgs:
29 const stackIndex = 3 // Assert/Check, assert, runComparison
30 args, err := source.CallExprArgs(stackIndex)
31 if err != nil {
32 t.Log(err.Error())
33 }
34 message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
35 case resultBasic:
36 message = typed.FailureMessage()
37 default:
38 message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
39 }
40
41 t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
42 return false
43}
44
// Failure-message interfaces for comparison results. runComparison checks a
// failed Result against these, in this order, to decide how to build its
// failure message; a Result implementing neither is reported as invalid.
type (
	// resultWithComparisonArgs is satisfied by results whose failure message
	// is rendered from the source expressions of the comparison arguments.
	resultWithComparisonArgs interface {
		FailureMessage(args []ast.Expr) string
	}

	// resultBasic is satisfied by results whose failure message needs no
	// information beyond the result itself.
	resultBasic interface {
		FailureMessage() string
	}
)
52
// filterPrintableExpr returns a slice with the same length as args, where
// result[i] corresponds to args[i] and is either an expression that prints
// compactly in a failure message or nil.
//
// An argument is kept as-is when isShortPrintableExpr reports true for it
// (identifiers, selectors, index and slice expressions, binary and unary
// expressions): these print nicely and the names in them may give additional
// context about the value.
// A StarExpr is replaced by its operand, so *p is reported as p.
// Every other expression becomes nil. This includes BasicLit and CompositeLit,
// whose source is equivalent to their value (which is already printed), as
// well as calls, parenthesized expressions and type assertions.
func filterPrintableExpr(args []ast.Expr) []ast.Expr {
	result := make([]ast.Expr, len(args))
	for i, arg := range args {
		if isShortPrintableExpr(arg) {
			result[i] = arg
			continue
		}
		// A dereference is reported as the expression being dereferenced.
		// NOTE(review): the operand itself is not passed through
		// isShortPrintableExpr, so *f() is reported as f().
		if starExpr, ok := arg.(*ast.StarExpr); ok {
			result[i] = starExpr.X
		}
		// Any other expression is left as the nil zero value from make.
	}
	return result
}

// isShortPrintableExpr reports whether expr is one of the expression kinds
// that filterPrintableExpr keeps verbatim in failure messages.
func isShortPrintableExpr(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr,
		*ast.BinaryExpr, *ast.UnaryExpr:
		return true
	default:
		// CallExpr, ParenExpr, TypeAssertExpr, KeyValueExpr, literals and
		// StarExpr end up here; filterPrintableExpr handles StarExpr itself.
		return false
	}
}
89
// argSelector narrows the argument expressions of an assertion call (as read
// by source.CallExprArgs in runComparison) down to the expressions handed to a
// result's FailureMessage.
type argSelector func(args []ast.Expr) []ast.Expr
91
// argsAfterT is an argSelector that discards the first argument expression
// and returns the rest. It returns nil when args is empty.
func argsAfterT(args []ast.Expr) []ast.Expr {
	if len(args) == 0 {
		return nil
	}
	// Everything after the leading argument (presumably the TestingT).
	return args[1:]
}
98
// argsFromComparisonCall is an argSelector that returns the arguments of the
// call expression found at args[1], i.e. the argument following the first
// one (presumably the TestingT, as skipped by argsAfterT).
//
// It returns nil when there are fewer than two arguments or when args[1] is
// not a call expression.
func argsFromComparisonCall(args []ast.Expr) []ast.Expr {
	// args[1] is indexed below, so two elements are required; the previous
	// `len(args) < 1` guard let a single-element slice panic.
	if len(args) < 2 {
		return nil
	}
	if callExpr, ok := args[1].(*ast.CallExpr); ok {
		return callExpr.Args
	}
	return nil
}