result.go

  1package assert
  2
  3import (
  4	"fmt"
  5	"go/ast"
  6
  7	"gotest.tools/assert/cmp"
  8	"gotest.tools/internal/format"
  9	"gotest.tools/internal/source"
 10)
 11
 12func runComparison(
 13	t TestingT,
 14	argSelector argSelector,
 15	f cmp.Comparison,
 16	msgAndArgs ...interface{},
 17) bool {
 18	if ht, ok := t.(helperT); ok {
 19		ht.Helper()
 20	}
 21	result := f()
 22	if result.Success() {
 23		return true
 24	}
 25
 26	var message string
 27	switch typed := result.(type) {
 28	case resultWithComparisonArgs:
 29		const stackIndex = 3 // Assert/Check, assert, runComparison
 30		args, err := source.CallExprArgs(stackIndex)
 31		if err != nil {
 32			t.Log(err.Error())
 33		}
 34		message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
 35	case resultBasic:
 36		message = typed.FailureMessage()
 37	default:
 38		message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
 39	}
 40
 41	t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
 42	return false
 43}
 44
 45type resultWithComparisonArgs interface {
 46	FailureMessage(args []ast.Expr) string
 47}
 48
 49type resultBasic interface {
 50	FailureMessage() string
 51}
 52
 53// filterPrintableExpr filters the ast.Expr slice to only include Expr that are
 54// easy to read when printed and contain relevant information to an assertion.
 55//
 56// Ident and SelectorExpr are included because they print nicely and the variable
 57// names may provide additional context to their values.
 58// BasicLit and CompositeLit are excluded because their source is equivalent to
 59// their value, which is already available.
 60// Other types are ignored for now, but could be added if they are relevant.
 61func filterPrintableExpr(args []ast.Expr) []ast.Expr {
 62	result := make([]ast.Expr, len(args))
 63	for i, arg := range args {
 64		if isShortPrintableExpr(arg) {
 65			result[i] = arg
 66			continue
 67		}
 68
 69		if starExpr, ok := arg.(*ast.StarExpr); ok {
 70			result[i] = starExpr.X
 71			continue
 72		}
 73		result[i] = nil
 74	}
 75	return result
 76}
 77
 78func isShortPrintableExpr(expr ast.Expr) bool {
 79	switch expr.(type) {
 80	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr:
 81		return true
 82	case *ast.BinaryExpr, *ast.UnaryExpr:
 83		return true
 84	default:
 85		// CallExpr, ParenExpr, TypeAssertExpr, KeyValueExpr, StarExpr
 86		return false
 87	}
 88}
 89
 90type argSelector func([]ast.Expr) []ast.Expr
 91
 92func argsAfterT(args []ast.Expr) []ast.Expr {
 93	if len(args) < 1 {
 94		return nil
 95	}
 96	return args[1:]
 97}
 98
 99func argsFromComparisonCall(args []ast.Expr) []ast.Expr {
100	if len(args) < 1 {
101		return nil
102	}
103	if callExpr, ok := args[1].(*ast.CallExpr); ok {
104		return callExpr.Args
105	}
106	return nil
107}