1/*Package cmp provides Comparisons for Assert and Check*/
2package cmp // import "gotest.tools/assert/cmp"
3
4import (
5 "fmt"
6 "reflect"
7 "strings"
8
9 "github.com/google/go-cmp/cmp"
10 "gotest.tools/internal/format"
11)
12
// Comparison is a function which compares values and returns ResultSuccess if
// the actual value matches the expected value. If the values do not match the
// Result will contain a message about why it failed.
//
// A Comparison takes no arguments: the constructors in this package (Equal,
// DeepEqual, Len, ...) return closures that capture the values to compare, so
// the check itself only runs when the Comparison is called.
type Comparison func() Result
17
18// DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp)
19// and succeeds if the values are equal.
20//
21// The comparison can be customized using comparison Options.
22// Package https://godoc.org/gotest.tools/assert/opt provides some additional
23// commonly used Options.
24func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
25 return func() (result Result) {
26 defer func() {
27 if panicmsg, handled := handleCmpPanic(recover()); handled {
28 result = ResultFailure(panicmsg)
29 }
30 }()
31 diff := cmp.Diff(x, y, opts...)
32 if diff == "" {
33 return ResultSuccess
34 }
35 return multiLineDiffResult(diff)
36 }
37}
38
// handleCmpPanic inspects a value obtained from recover() and decides whether
// it should be reported as a comparison failure.
//
// It returns ("", false) when r is nil (nothing panicked) and (message, true)
// when r is a string starting with "cannot handle unexported field". Every
// other value, string or not, is re-panicked unchanged.
func handleCmpPanic(r interface{}) (string, bool) {
	if r == nil {
		return "", false
	}
	if msg, isString := r.(string); isString &&
		strings.HasPrefix(msg, "cannot handle unexported field") {
		return msg, true
	}
	// Not a panic this helper knows how to report: propagate it.
	panic(r)
}
53
54func toResult(success bool, msg string) Result {
55 if success {
56 return ResultSuccess
57 }
58 return ResultFailure(msg)
59}
60
// Equal succeeds if x == y. See assert.Equal for full documentation.
//
// When the values differ and both are strings, at least one of which contains
// a newline, the failure message is built from the output of
// format.UnifiedDiff. Otherwise the failure message shows each value together
// with its type.
//
// NOTE(review): x == y compares interface values, which panics at runtime when
// both hold the same uncomparable dynamic type (e.g. slices or maps).
func Equal(x, y interface{}) Comparison {
	return func() Result {
		switch {
		case x == y:
			return ResultSuccess
		case isMultiLineStringCompare(x, y):
			// isMultiLineStringCompare only returns true when both values are
			// strings, so these type assertions cannot fail.
			diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
			return multiLineDiffResult(diff)
		}
		// The template renders each value followed, in parentheses, by the
		// formatted call argument (when callArg yields one) and the value's
		// %T type.
		return ResultFailureTemplate(`
			{{- .Data.x}} (
				{{- with callArg 0 }}{{ formatNode . }} {{end -}}
				{{- printf "%T" .Data.x -}}
			) != {{ .Data.y}} (
				{{- with callArg 1 }}{{ formatNode . }} {{end -}}
				{{- printf "%T" .Data.y -}}
			)`,
			map[string]interface{}{"x": x, "y": y})
	}
}
82
// isMultiLineStringCompare reports whether x and y are both strings and at
// least one of them contains a newline.
func isMultiLineStringCompare(x, y interface{}) bool {
	left, leftOK := x.(string)
	right, rightOK := y.(string)
	if !leftOK || !rightOK {
		return false
	}
	return strings.ContainsRune(left, '\n') || strings.ContainsRune(right, '\n')
}
94
95func multiLineDiffResult(diff string) Result {
96 return ResultFailureTemplate(`
97--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
98+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
99{{ .Data.diff }}`,
100 map[string]interface{}{"diff": diff})
101}
102
103// Len succeeds if the sequence has the expected length.
104func Len(seq interface{}, expected int) Comparison {
105 return func() (result Result) {
106 defer func() {
107 if e := recover(); e != nil {
108 result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
109 }
110 }()
111 value := reflect.ValueOf(seq)
112 length := value.Len()
113 if length == expected {
114 return ResultSuccess
115 }
116 msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
117 return ResultFailure(msg)
118 }
119}
120
121// Contains succeeds if item is in collection. Collection may be a string, map,
122// slice, or array.
123//
124// If collection is a string, item must also be a string, and is compared using
125// strings.Contains().
126// If collection is a Map, contains will succeed if item is a key in the map.
127// If collection is a slice or array, item is compared to each item in the
128// sequence using reflect.DeepEqual().
129func Contains(collection interface{}, item interface{}) Comparison {
130 return func() Result {
131 colValue := reflect.ValueOf(collection)
132 if !colValue.IsValid() {
133 return ResultFailure(fmt.Sprintf("nil does not contain items"))
134 }
135 msg := fmt.Sprintf("%v does not contain %v", collection, item)
136
137 itemValue := reflect.ValueOf(item)
138 switch colValue.Type().Kind() {
139 case reflect.String:
140 if itemValue.Type().Kind() != reflect.String {
141 return ResultFailure("string may only contain strings")
142 }
143 return toResult(
144 strings.Contains(colValue.String(), itemValue.String()),
145 fmt.Sprintf("string %q does not contain %q", collection, item))
146
147 case reflect.Map:
148 if itemValue.Type() != colValue.Type().Key() {
149 return ResultFailure(fmt.Sprintf(
150 "%v can not contain a %v key", colValue.Type(), itemValue.Type()))
151 }
152 return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
153
154 case reflect.Slice, reflect.Array:
155 for i := 0; i < colValue.Len(); i++ {
156 if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
157 return ResultSuccess
158 }
159 }
160 return ResultFailure(msg)
161 default:
162 return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
163 }
164 }
165}
166
167// Panics succeeds if f() panics.
168func Panics(f func()) Comparison {
169 return func() (result Result) {
170 defer func() {
171 if err := recover(); err != nil {
172 result = ResultSuccess
173 }
174 }()
175 f()
176 return ResultFailure("did not panic")
177 }
178}
179
180// Error succeeds if err is a non-nil error, and the error message equals the
181// expected message.
182func Error(err error, message string) Comparison {
183 return func() Result {
184 switch {
185 case err == nil:
186 return ResultFailure("expected an error, got nil")
187 case err.Error() != message:
188 return ResultFailure(fmt.Sprintf(
189 "expected error %q, got %+v", message, err))
190 }
191 return ResultSuccess
192 }
193}
194
195// ErrorContains succeeds if err is a non-nil error, and the error message contains
196// the expected substring.
197func ErrorContains(err error, substring string) Comparison {
198 return func() Result {
199 switch {
200 case err == nil:
201 return ResultFailure("expected an error, got nil")
202 case !strings.Contains(err.Error(), substring):
203 return ResultFailure(fmt.Sprintf(
204 "expected error to contain %q, got %+v", substring, err))
205 }
206 return ResultSuccess
207 }
208}
209
210// Nil succeeds if obj is a nil interface, pointer, or function.
211//
212// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
213// maps, and channels.
214func Nil(obj interface{}) Comparison {
215 msgFunc := func(value reflect.Value) string {
216 return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
217 }
218 return isNil(obj, msgFunc)
219}
220
221func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
222 return func() Result {
223 if obj == nil {
224 return ResultSuccess
225 }
226 value := reflect.ValueOf(obj)
227 kind := value.Type().Kind()
228 if kind >= reflect.Chan && kind <= reflect.Slice {
229 if value.IsNil() {
230 return ResultSuccess
231 }
232 return ResultFailure(msgFunc(value))
233 }
234
235 return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
236 }
237}
238
239// ErrorType succeeds if err is not nil and is of the expected type.
240//
241// Expected can be one of:
242// a func(error) bool which returns true if the error is the expected type,
243// an instance of (or a pointer to) a struct of the expected type,
244// a pointer to an interface the error is expected to implement,
245// a reflect.Type of the expected struct or interface.
246func ErrorType(err error, expected interface{}) Comparison {
247 return func() Result {
248 switch expectedType := expected.(type) {
249 case func(error) bool:
250 return cmpErrorTypeFunc(err, expectedType)
251 case reflect.Type:
252 if expectedType.Kind() == reflect.Interface {
253 return cmpErrorTypeImplementsType(err, expectedType)
254 }
255 return cmpErrorTypeEqualType(err, expectedType)
256 case nil:
257 return ResultFailure(fmt.Sprintf("invalid type for expected: nil"))
258 }
259
260 expectedType := reflect.TypeOf(expected)
261 switch {
262 case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
263 return cmpErrorTypeEqualType(err, expectedType)
264 case isPtrToInterface(expectedType):
265 return cmpErrorTypeImplementsType(err, expectedType.Elem())
266 }
267 return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
268 }
269}
270
271func cmpErrorTypeFunc(err error, f func(error) bool) Result {
272 if f(err) {
273 return ResultSuccess
274 }
275 actual := "nil"
276 if err != nil {
277 actual = fmt.Sprintf("%s (%T)", err, err)
278 }
279 return ResultFailureTemplate(`error is {{ .Data.actual }}
280 {{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
281 map[string]interface{}{"actual": actual})
282}
283
284func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
285 if err == nil {
286 return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
287 }
288 errValue := reflect.ValueOf(err)
289 if errValue.Type() == expectedType {
290 return ResultSuccess
291 }
292 return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
293}
294
295func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
296 if err == nil {
297 return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
298 }
299 errValue := reflect.ValueOf(err)
300 if errValue.Type().Implements(expectedType) {
301 return ResultSuccess
302 }
303 return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
304}
305
// isPtrToInterface reports whether typ is a pointer whose element type is an
// interface (for example *error).
func isPtrToInterface(typ reflect.Type) bool {
	if typ.Kind() != reflect.Ptr {
		return false
	}
	return typ.Elem().Kind() == reflect.Interface
}
309
// isPtrToStruct reports whether typ is a pointer whose element type is a
// struct.
func isPtrToStruct(typ reflect.Type) bool {
	if typ.Kind() != reflect.Ptr {
		return false
	}
	return typ.Elem().Kind() == reflect.Struct
}