compare.go

  1/*Package cmp provides Comparisons for Assert and Check*/
  2package cmp // import "gotest.tools/assert/cmp"
  3
  4import (
  5	"fmt"
  6	"reflect"
  7	"strings"
  8
  9	"github.com/google/go-cmp/cmp"
 10	"gotest.tools/internal/format"
 11)
 12
// Comparison is a function which compares values and returns ResultSuccess if
// the actual value matches the expected value. If the values do not match the
// Result will contain a message about why it failed.
//
// The constructors in this package (Equal, DeepEqual, Len, ...) capture their
// arguments in a closure and return it, so no comparison work happens until the
// Comparison itself is called.
type Comparison func() Result
 17
 18// DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp)
 19// and succeeds if the values are equal.
 20//
 21// The comparison can be customized using comparison Options.
 22// Package https://godoc.org/gotest.tools/assert/opt provides some additional
 23// commonly used Options.
 24func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
 25	return func() (result Result) {
 26		defer func() {
 27			if panicmsg, handled := handleCmpPanic(recover()); handled {
 28				result = ResultFailure(panicmsg)
 29			}
 30		}()
 31		diff := cmp.Diff(x, y, opts...)
 32		if diff == "" {
 33			return ResultSuccess
 34		}
 35		return multiLineDiffResult(diff)
 36	}
 37}
 38
 39func handleCmpPanic(r interface{}) (string, bool) {
 40	if r == nil {
 41		return "", false
 42	}
 43	panicmsg, ok := r.(string)
 44	if !ok {
 45		panic(r)
 46	}
 47	switch {
 48	case strings.HasPrefix(panicmsg, "cannot handle unexported field"):
 49		return panicmsg, true
 50	}
 51	panic(r)
 52}
 53
 54func toResult(success bool, msg string) Result {
 55	if success {
 56		return ResultSuccess
 57	}
 58	return ResultFailure(msg)
 59}
 60
 61// Equal succeeds if x == y. See assert.Equal for full documentation.
 62func Equal(x, y interface{}) Comparison {
 63	return func() Result {
 64		switch {
 65		case x == y:
 66			return ResultSuccess
 67		case isMultiLineStringCompare(x, y):
 68			diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
 69			return multiLineDiffResult(diff)
 70		}
 71		return ResultFailureTemplate(`
 72			{{- .Data.x}} (
 73				{{- with callArg 0 }}{{ formatNode . }} {{end -}}
 74				{{- printf "%T" .Data.x -}}
 75			) != {{ .Data.y}} (
 76				{{- with callArg 1 }}{{ formatNode . }} {{end -}}
 77				{{- printf "%T" .Data.y -}}
 78			)`,
 79			map[string]interface{}{"x": x, "y": y})
 80	}
 81}
 82
 83func isMultiLineStringCompare(x, y interface{}) bool {
 84	strX, ok := x.(string)
 85	if !ok {
 86		return false
 87	}
 88	strY, ok := y.(string)
 89	if !ok {
 90		return false
 91	}
 92	return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
 93}
 94
 95func multiLineDiffResult(diff string) Result {
 96	return ResultFailureTemplate(`
 97--- {{ with callArg 0 }}{{ formatNode . }}{{else}}{{end}}
 98+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}{{end}}
 99{{ .Data.diff }}`,
100		map[string]interface{}{"diff": diff})
101}
102
103// Len succeeds if the sequence has the expected length.
104func Len(seq interface{}, expected int) Comparison {
105	return func() (result Result) {
106		defer func() {
107			if e := recover(); e != nil {
108				result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
109			}
110		}()
111		value := reflect.ValueOf(seq)
112		length := value.Len()
113		if length == expected {
114			return ResultSuccess
115		}
116		msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
117		return ResultFailure(msg)
118	}
119}
120
121// Contains succeeds if item is in collection. Collection may be a string, map,
122// slice, or array.
123//
124// If collection is a string, item must also be a string, and is compared using
125// strings.Contains().
126// If collection is a Map, contains will succeed if item is a key in the map.
127// If collection is a slice or array, item is compared to each item in the
128// sequence using reflect.DeepEqual().
129func Contains(collection interface{}, item interface{}) Comparison {
130	return func() Result {
131		colValue := reflect.ValueOf(collection)
132		if !colValue.IsValid() {
133			return ResultFailure(fmt.Sprintf("nil does not contain items"))
134		}
135		msg := fmt.Sprintf("%v does not contain %v", collection, item)
136
137		itemValue := reflect.ValueOf(item)
138		switch colValue.Type().Kind() {
139		case reflect.String:
140			if itemValue.Type().Kind() != reflect.String {
141				return ResultFailure("string may only contain strings")
142			}
143			return toResult(
144				strings.Contains(colValue.String(), itemValue.String()),
145				fmt.Sprintf("string %q does not contain %q", collection, item))
146
147		case reflect.Map:
148			if itemValue.Type() != colValue.Type().Key() {
149				return ResultFailure(fmt.Sprintf(
150					"%v can not contain a %v key", colValue.Type(), itemValue.Type()))
151			}
152			return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
153
154		case reflect.Slice, reflect.Array:
155			for i := 0; i < colValue.Len(); i++ {
156				if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
157					return ResultSuccess
158				}
159			}
160			return ResultFailure(msg)
161		default:
162			return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
163		}
164	}
165}
166
167// Panics succeeds if f() panics.
168func Panics(f func()) Comparison {
169	return func() (result Result) {
170		defer func() {
171			if err := recover(); err != nil {
172				result = ResultSuccess
173			}
174		}()
175		f()
176		return ResultFailure("did not panic")
177	}
178}
179
180// Error succeeds if err is a non-nil error, and the error message equals the
181// expected message.
182func Error(err error, message string) Comparison {
183	return func() Result {
184		switch {
185		case err == nil:
186			return ResultFailure("expected an error, got nil")
187		case err.Error() != message:
188			return ResultFailure(fmt.Sprintf(
189				"expected error %q, got %+v", message, err))
190		}
191		return ResultSuccess
192	}
193}
194
195// ErrorContains succeeds if err is a non-nil error, and the error message contains
196// the expected substring.
197func ErrorContains(err error, substring string) Comparison {
198	return func() Result {
199		switch {
200		case err == nil:
201			return ResultFailure("expected an error, got nil")
202		case !strings.Contains(err.Error(), substring):
203			return ResultFailure(fmt.Sprintf(
204				"expected error to contain %q, got %+v", substring, err))
205		}
206		return ResultSuccess
207	}
208}
209
210// Nil succeeds if obj is a nil interface, pointer, or function.
211//
212// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
213// maps, and channels.
214func Nil(obj interface{}) Comparison {
215	msgFunc := func(value reflect.Value) string {
216		return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
217	}
218	return isNil(obj, msgFunc)
219}
220
221func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
222	return func() Result {
223		if obj == nil {
224			return ResultSuccess
225		}
226		value := reflect.ValueOf(obj)
227		kind := value.Type().Kind()
228		if kind >= reflect.Chan && kind <= reflect.Slice {
229			if value.IsNil() {
230				return ResultSuccess
231			}
232			return ResultFailure(msgFunc(value))
233		}
234
235		return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
236	}
237}
238
239// ErrorType succeeds if err is not nil and is of the expected type.
240//
241// Expected can be one of:
242// a func(error) bool which returns true if the error is the expected type,
243// an instance of (or a pointer to) a struct of the expected type,
244// a pointer to an interface the error is expected to implement,
245// a reflect.Type of the expected struct or interface.
246func ErrorType(err error, expected interface{}) Comparison {
247	return func() Result {
248		switch expectedType := expected.(type) {
249		case func(error) bool:
250			return cmpErrorTypeFunc(err, expectedType)
251		case reflect.Type:
252			if expectedType.Kind() == reflect.Interface {
253				return cmpErrorTypeImplementsType(err, expectedType)
254			}
255			return cmpErrorTypeEqualType(err, expectedType)
256		case nil:
257			return ResultFailure(fmt.Sprintf("invalid type for expected: nil"))
258		}
259
260		expectedType := reflect.TypeOf(expected)
261		switch {
262		case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
263			return cmpErrorTypeEqualType(err, expectedType)
264		case isPtrToInterface(expectedType):
265			return cmpErrorTypeImplementsType(err, expectedType.Elem())
266		}
267		return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
268	}
269}
270
271func cmpErrorTypeFunc(err error, f func(error) bool) Result {
272	if f(err) {
273		return ResultSuccess
274	}
275	actual := "nil"
276	if err != nil {
277		actual = fmt.Sprintf("%s (%T)", err, err)
278	}
279	return ResultFailureTemplate(`error is {{ .Data.actual }}
280		{{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
281		map[string]interface{}{"actual": actual})
282}
283
284func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
285	if err == nil {
286		return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
287	}
288	errValue := reflect.ValueOf(err)
289	if errValue.Type() == expectedType {
290		return ResultSuccess
291	}
292	return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
293}
294
295func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
296	if err == nil {
297		return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
298	}
299	errValue := reflect.ValueOf(err)
300	if errValue.Type().Implements(expectedType) {
301		return ResultSuccess
302	}
303	return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
304}
305
306func isPtrToInterface(typ reflect.Type) bool {
307	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface
308}
309
310func isPtrToStruct(typ reflect.Type) bool {
311	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
312}