1// Package deep provides function deep.Equal which is like reflect.DeepEqual but
2// returns a list of differences. This is helpful when comparing complex types
3// like structures and maps.
4package deep
5
6import (
7 "errors"
8 "fmt"
9 "log"
10 "reflect"
11 "strings"
12)
13
var (
	// FloatPrecision is the number of decimal places to round float values
	// to when comparing. Equal reads it at the start of each call to build
	// the format string used for float comparison.
	FloatPrecision = 10

	// MaxDiff specifies the maximum number of differences to return.
	MaxDiff = 10

	// MaxDepth specifies the maximum levels of a struct to recurse into.
	// Dereferencing a pointer or interface also counts as one level.
	MaxDepth = 10

	// LogErrors causes errors to be logged to STDERR when true.
	LogErrors = false

	// CompareUnexportedFields causes unexported struct fields, like s in
	// T{s int}, to be compared when true.
	CompareUnexportedFields = false
)
32
// The errors below are never returned to the caller; they are only written
// to the log, and only when LogErrors is true.
var (
	// ErrMaxRecursion is logged when the recursion depth exceeds MaxDepth.
	ErrMaxRecursion = errors.New("recursed to MaxDepth")

	// ErrTypeMismatch is logged when two values being compared have
	// different types.
	ErrTypeMismatch = errors.New("variables are different reflect.Type")

	// ErrNotHandled is logged when a primitive Go kind is not handled.
	ErrNotHandled = errors.New("cannot compare the reflect.Kind")
)
43
// cmp holds the state of one comparison started by Equal.
type cmp struct {
	// diff collects the differences found so far, one formatted string each.
	diff []string
	// buff is the path to the value currently being compared (field names,
	// map keys, array/slice indexes), maintained with push and pop.
	buff []string
	// floatFormat is the fmt format used to round floats before comparing.
	floatFormat string
}
49
// errorType is the reflect.Type of the built-in error interface, used to
// detect values whose type implements error.
var errorType = reflect.TypeOf((*error)(nil)).Elem()
51
52// Equal compares variables a and b, recursing into their structure up to
53// MaxDepth levels deep, and returns a list of differences, or nil if there are
54// none. Some differences may not be found if an error is also returned.
55//
56// If a type has an Equal method, like time.Equal, it is called to check for
57// equality.
58func Equal(a, b interface{}) []string {
59 aVal := reflect.ValueOf(a)
60 bVal := reflect.ValueOf(b)
61 c := &cmp{
62 diff: []string{},
63 buff: []string{},
64 floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
65 }
66 if a == nil && b == nil {
67 return nil
68 } else if a == nil && b != nil {
69 c.saveDiff(b, "<nil pointer>")
70 } else if a != nil && b == nil {
71 c.saveDiff(a, "<nil pointer>")
72 }
73 if len(c.diff) > 0 {
74 return c.diff
75 }
76
77 c.equals(aVal, bVal, 0)
78 if len(c.diff) > 0 {
79 return c.diff // diffs
80 }
81 return nil // no diffs
82}
83
84func (c *cmp) equals(a, b reflect.Value, level int) {
85 if level > MaxDepth {
86 logError(ErrMaxRecursion)
87 return
88 }
89
90 // Check if one value is nil, e.g. T{x: *X} and T.x is nil
91 if !a.IsValid() || !b.IsValid() {
92 if a.IsValid() && !b.IsValid() {
93 c.saveDiff(a.Type(), "<nil pointer>")
94 } else if !a.IsValid() && b.IsValid() {
95 c.saveDiff("<nil pointer>", b.Type())
96 }
97 return
98 }
99
100 // If differenet types, they can't be equal
101 aType := a.Type()
102 bType := b.Type()
103 if aType != bType {
104 c.saveDiff(aType, bType)
105 logError(ErrTypeMismatch)
106 return
107 }
108
109 // Primitive https://golang.org/pkg/reflect/#Kind
110 aKind := a.Kind()
111 bKind := b.Kind()
112
113 // If both types implement the error interface, compare the error strings.
114 // This must be done before dereferencing because the interface is on a
115 // pointer receiver.
116 if aType.Implements(errorType) && bType.Implements(errorType) {
117 if a.Elem().IsValid() && b.Elem().IsValid() { // both err != nil
118 aString := a.MethodByName("Error").Call(nil)[0].String()
119 bString := b.MethodByName("Error").Call(nil)[0].String()
120 if aString != bString {
121 c.saveDiff(aString, bString)
122 }
123 return
124 }
125 }
126
127 // Dereference pointers and interface{}
128 if aElem, bElem := (aKind == reflect.Ptr || aKind == reflect.Interface),
129 (bKind == reflect.Ptr || bKind == reflect.Interface); aElem || bElem {
130
131 if aElem {
132 a = a.Elem()
133 }
134
135 if bElem {
136 b = b.Elem()
137 }
138
139 c.equals(a, b, level+1)
140 return
141 }
142
143 // Types with an Equal(), like time.Time.
144 eqFunc := a.MethodByName("Equal")
145 if eqFunc.IsValid() {
146 retVals := eqFunc.Call([]reflect.Value{b})
147 if !retVals[0].Bool() {
148 c.saveDiff(a, b)
149 }
150 return
151 }
152
153 switch aKind {
154
155 /////////////////////////////////////////////////////////////////////
156 // Iterable kinds
157 /////////////////////////////////////////////////////////////////////
158
159 case reflect.Struct:
160 /*
161 The variables are structs like:
162 type T struct {
163 FirstName string
164 LastName string
165 }
166 Type = <pkg>.T, Kind = reflect.Struct
167
168 Iterate through the fields (FirstName, LastName), recurse into their values.
169 */
170 for i := 0; i < a.NumField(); i++ {
171 if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
172 continue // skip unexported field, e.g. s in type T struct {s string}
173 }
174
175 c.push(aType.Field(i).Name) // push field name to buff
176
177 // Get the Value for each field, e.g. FirstName has Type = string,
178 // Kind = reflect.String.
179 af := a.Field(i)
180 bf := b.Field(i)
181
182 // Recurse to compare the field values
183 c.equals(af, bf, level+1)
184
185 c.pop() // pop field name from buff
186
187 if len(c.diff) >= MaxDiff {
188 break
189 }
190 }
191 case reflect.Map:
192 /*
193 The variables are maps like:
194 map[string]int{
195 "foo": 1,
196 "bar": 2,
197 }
198 Type = map[string]int, Kind = reflect.Map
199
200 Or:
201 type T map[string]int{}
202 Type = <pkg>.T, Kind = reflect.Map
203
204 Iterate through the map keys (foo, bar), recurse into their values.
205 */
206
207 if a.IsNil() || b.IsNil() {
208 if a.IsNil() && !b.IsNil() {
209 c.saveDiff("<nil map>", b)
210 } else if !a.IsNil() && b.IsNil() {
211 c.saveDiff(a, "<nil map>")
212 }
213 return
214 }
215
216 if a.Pointer() == b.Pointer() {
217 return
218 }
219
220 for _, key := range a.MapKeys() {
221 c.push(fmt.Sprintf("map[%s]", key))
222
223 aVal := a.MapIndex(key)
224 bVal := b.MapIndex(key)
225 if bVal.IsValid() {
226 c.equals(aVal, bVal, level+1)
227 } else {
228 c.saveDiff(aVal, "<does not have key>")
229 }
230
231 c.pop()
232
233 if len(c.diff) >= MaxDiff {
234 return
235 }
236 }
237
238 for _, key := range b.MapKeys() {
239 if aVal := a.MapIndex(key); aVal.IsValid() {
240 continue
241 }
242
243 c.push(fmt.Sprintf("map[%s]", key))
244 c.saveDiff("<does not have key>", b.MapIndex(key))
245 c.pop()
246 if len(c.diff) >= MaxDiff {
247 return
248 }
249 }
250 case reflect.Array:
251 n := a.Len()
252 for i := 0; i < n; i++ {
253 c.push(fmt.Sprintf("array[%d]", i))
254 c.equals(a.Index(i), b.Index(i), level+1)
255 c.pop()
256 if len(c.diff) >= MaxDiff {
257 break
258 }
259 }
260 case reflect.Slice:
261 if a.IsNil() || b.IsNil() {
262 if a.IsNil() && !b.IsNil() {
263 c.saveDiff("<nil slice>", b)
264 } else if !a.IsNil() && b.IsNil() {
265 c.saveDiff(a, "<nil slice>")
266 }
267 return
268 }
269
270 if a.Pointer() == b.Pointer() {
271 return
272 }
273
274 aLen := a.Len()
275 bLen := b.Len()
276 n := aLen
277 if bLen > aLen {
278 n = bLen
279 }
280 for i := 0; i < n; i++ {
281 c.push(fmt.Sprintf("slice[%d]", i))
282 if i < aLen && i < bLen {
283 c.equals(a.Index(i), b.Index(i), level+1)
284 } else if i < aLen {
285 c.saveDiff(a.Index(i), "<no value>")
286 } else {
287 c.saveDiff("<no value>", b.Index(i))
288 }
289 c.pop()
290 if len(c.diff) >= MaxDiff {
291 break
292 }
293 }
294
295 /////////////////////////////////////////////////////////////////////
296 // Primitive kinds
297 /////////////////////////////////////////////////////////////////////
298
299 case reflect.Float32, reflect.Float64:
300 // Avoid 0.04147685731961082 != 0.041476857319611
301 // 6 decimal places is close enough
302 aval := fmt.Sprintf(c.floatFormat, a.Float())
303 bval := fmt.Sprintf(c.floatFormat, b.Float())
304 if aval != bval {
305 c.saveDiff(a.Float(), b.Float())
306 }
307 case reflect.Bool:
308 if a.Bool() != b.Bool() {
309 c.saveDiff(a.Bool(), b.Bool())
310 }
311 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
312 if a.Int() != b.Int() {
313 c.saveDiff(a.Int(), b.Int())
314 }
315 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
316 if a.Uint() != b.Uint() {
317 c.saveDiff(a.Uint(), b.Uint())
318 }
319 case reflect.String:
320 if a.String() != b.String() {
321 c.saveDiff(a.String(), b.String())
322 }
323
324 default:
325 logError(ErrNotHandled)
326 }
327}
328
// push appends name to c.buff, the path of the value being compared; saveDiff
// joins that path with "." to label each difference.
func (c *cmp) push(name string) {
	c.buff = append(c.buff, name)
}
332
333func (c *cmp) pop() {
334 if len(c.buff) > 0 {
335 c.buff = c.buff[0 : len(c.buff)-1]
336 }
337}
338
339func (c *cmp) saveDiff(aval, bval interface{}) {
340 if len(c.buff) > 0 {
341 varName := strings.Join(c.buff, ".")
342 c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval))
343 } else {
344 c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval))
345 }
346}
347
348func logError(err error) {
349 if LogErrors {
350 log.Println(err)
351 }
352}