deep.go

  1// Package deep provides function deep.Equal which is like reflect.DeepEqual but
  2// returns a list of differences. This is helpful when comparing complex types
  3// like structures and maps.
  4package deep
  5
  6import (
  7	"errors"
  8	"fmt"
  9	"log"
 10	"reflect"
 11	"strings"
 12)
 13
 14var (
 15	// FloatPrecision is the number of decimal places to round float values
 16	// to when comparing.
 17	FloatPrecision = 10
 18
 19	// MaxDiff specifies the maximum number of differences to return.
 20	MaxDiff = 10
 21
 22	// MaxDepth specifies the maximum levels of a struct to recurse into.
 23	MaxDepth = 10
 24
 25	// LogErrors causes errors to be logged to STDERR when true.
 26	LogErrors = false
 27
 28	// CompareUnexportedFields causes unexported struct fields, like s in
 29	// T{s int}, to be comparsed when true.
 30	CompareUnexportedFields = false
 31)
 32
 33var (
 34	// ErrMaxRecursion is logged when MaxDepth is reached.
 35	ErrMaxRecursion = errors.New("recursed to MaxDepth")
 36
 37	// ErrTypeMismatch is logged when Equal passed two different types of values.
 38	ErrTypeMismatch = errors.New("variables are different reflect.Type")
 39
 40	// ErrNotHandled is logged when a primitive Go kind is not handled.
 41	ErrNotHandled = errors.New("cannot compare the reflect.Kind")
 42)
 43
 44type cmp struct {
 45	diff        []string
 46	buff        []string
 47	floatFormat string
 48}
 49
 50var errorType = reflect.TypeOf((*error)(nil)).Elem()
 51
 52// Equal compares variables a and b, recursing into their structure up to
 53// MaxDepth levels deep, and returns a list of differences, or nil if there are
 54// none. Some differences may not be found if an error is also returned.
 55//
 56// If a type has an Equal method, like time.Equal, it is called to check for
 57// equality.
 58func Equal(a, b interface{}) []string {
 59	aVal := reflect.ValueOf(a)
 60	bVal := reflect.ValueOf(b)
 61	c := &cmp{
 62		diff:        []string{},
 63		buff:        []string{},
 64		floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
 65	}
 66	if a == nil && b == nil {
 67		return nil
 68	} else if a == nil && b != nil {
 69		c.saveDiff(b, "<nil pointer>")
 70	} else if a != nil && b == nil {
 71		c.saveDiff(a, "<nil pointer>")
 72	}
 73	if len(c.diff) > 0 {
 74		return c.diff
 75	}
 76
 77	c.equals(aVal, bVal, 0)
 78	if len(c.diff) > 0 {
 79		return c.diff // diffs
 80	}
 81	return nil // no diffs
 82}
 83
 84func (c *cmp) equals(a, b reflect.Value, level int) {
 85	if level > MaxDepth {
 86		logError(ErrMaxRecursion)
 87		return
 88	}
 89
 90	// Check if one value is nil, e.g. T{x: *X} and T.x is nil
 91	if !a.IsValid() || !b.IsValid() {
 92		if a.IsValid() && !b.IsValid() {
 93			c.saveDiff(a.Type(), "<nil pointer>")
 94		} else if !a.IsValid() && b.IsValid() {
 95			c.saveDiff("<nil pointer>", b.Type())
 96		}
 97		return
 98	}
 99
100	// If differenet types, they can't be equal
101	aType := a.Type()
102	bType := b.Type()
103	if aType != bType {
104		c.saveDiff(aType, bType)
105		logError(ErrTypeMismatch)
106		return
107	}
108
109	// Primitive https://golang.org/pkg/reflect/#Kind
110	aKind := a.Kind()
111	bKind := b.Kind()
112
113	// If both types implement the error interface, compare the error strings.
114	// This must be done before dereferencing because the interface is on a
115	// pointer receiver.
116	if aType.Implements(errorType) && bType.Implements(errorType) {
117		if a.Elem().IsValid() && b.Elem().IsValid() { // both err != nil
118			aString := a.MethodByName("Error").Call(nil)[0].String()
119			bString := b.MethodByName("Error").Call(nil)[0].String()
120			if aString != bString {
121				c.saveDiff(aString, bString)
122			}
123			return
124		}
125	}
126
127	// Dereference pointers and interface{}
128	if aElem, bElem := (aKind == reflect.Ptr || aKind == reflect.Interface),
129		(bKind == reflect.Ptr || bKind == reflect.Interface); aElem || bElem {
130
131		if aElem {
132			a = a.Elem()
133		}
134
135		if bElem {
136			b = b.Elem()
137		}
138
139		c.equals(a, b, level+1)
140		return
141	}
142
143	// Types with an Equal(), like time.Time.
144	eqFunc := a.MethodByName("Equal")
145	if eqFunc.IsValid() {
146		retVals := eqFunc.Call([]reflect.Value{b})
147		if !retVals[0].Bool() {
148			c.saveDiff(a, b)
149		}
150		return
151	}
152
153	switch aKind {
154
155	/////////////////////////////////////////////////////////////////////
156	// Iterable kinds
157	/////////////////////////////////////////////////////////////////////
158
159	case reflect.Struct:
160		/*
161			The variables are structs like:
162				type T struct {
163					FirstName string
164					LastName  string
165				}
166			Type = <pkg>.T, Kind = reflect.Struct
167
168			Iterate through the fields (FirstName, LastName), recurse into their values.
169		*/
170		for i := 0; i < a.NumField(); i++ {
171			if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
172				continue // skip unexported field, e.g. s in type T struct {s string}
173			}
174
175			c.push(aType.Field(i).Name) // push field name to buff
176
177			// Get the Value for each field, e.g. FirstName has Type = string,
178			// Kind = reflect.String.
179			af := a.Field(i)
180			bf := b.Field(i)
181
182			// Recurse to compare the field values
183			c.equals(af, bf, level+1)
184
185			c.pop() // pop field name from buff
186
187			if len(c.diff) >= MaxDiff {
188				break
189			}
190		}
191	case reflect.Map:
192		/*
193			The variables are maps like:
194				map[string]int{
195					"foo": 1,
196					"bar": 2,
197				}
198			Type = map[string]int, Kind = reflect.Map
199
200			Or:
201				type T map[string]int{}
202			Type = <pkg>.T, Kind = reflect.Map
203
204			Iterate through the map keys (foo, bar), recurse into their values.
205		*/
206
207		if a.IsNil() || b.IsNil() {
208			if a.IsNil() && !b.IsNil() {
209				c.saveDiff("<nil map>", b)
210			} else if !a.IsNil() && b.IsNil() {
211				c.saveDiff(a, "<nil map>")
212			}
213			return
214		}
215
216		if a.Pointer() == b.Pointer() {
217			return
218		}
219
220		for _, key := range a.MapKeys() {
221			c.push(fmt.Sprintf("map[%s]", key))
222
223			aVal := a.MapIndex(key)
224			bVal := b.MapIndex(key)
225			if bVal.IsValid() {
226				c.equals(aVal, bVal, level+1)
227			} else {
228				c.saveDiff(aVal, "<does not have key>")
229			}
230
231			c.pop()
232
233			if len(c.diff) >= MaxDiff {
234				return
235			}
236		}
237
238		for _, key := range b.MapKeys() {
239			if aVal := a.MapIndex(key); aVal.IsValid() {
240				continue
241			}
242
243			c.push(fmt.Sprintf("map[%s]", key))
244			c.saveDiff("<does not have key>", b.MapIndex(key))
245			c.pop()
246			if len(c.diff) >= MaxDiff {
247				return
248			}
249		}
250	case reflect.Array:
251		n := a.Len()
252		for i := 0; i < n; i++ {
253			c.push(fmt.Sprintf("array[%d]", i))
254			c.equals(a.Index(i), b.Index(i), level+1)
255			c.pop()
256			if len(c.diff) >= MaxDiff {
257				break
258			}
259		}
260	case reflect.Slice:
261		if a.IsNil() || b.IsNil() {
262			if a.IsNil() && !b.IsNil() {
263				c.saveDiff("<nil slice>", b)
264			} else if !a.IsNil() && b.IsNil() {
265				c.saveDiff(a, "<nil slice>")
266			}
267			return
268		}
269
270		if a.Pointer() == b.Pointer() {
271			return
272		}
273
274		aLen := a.Len()
275		bLen := b.Len()
276		n := aLen
277		if bLen > aLen {
278			n = bLen
279		}
280		for i := 0; i < n; i++ {
281			c.push(fmt.Sprintf("slice[%d]", i))
282			if i < aLen && i < bLen {
283				c.equals(a.Index(i), b.Index(i), level+1)
284			} else if i < aLen {
285				c.saveDiff(a.Index(i), "<no value>")
286			} else {
287				c.saveDiff("<no value>", b.Index(i))
288			}
289			c.pop()
290			if len(c.diff) >= MaxDiff {
291				break
292			}
293		}
294
295	/////////////////////////////////////////////////////////////////////
296	// Primitive kinds
297	/////////////////////////////////////////////////////////////////////
298
299	case reflect.Float32, reflect.Float64:
300		// Avoid 0.04147685731961082 != 0.041476857319611
301		// 6 decimal places is close enough
302		aval := fmt.Sprintf(c.floatFormat, a.Float())
303		bval := fmt.Sprintf(c.floatFormat, b.Float())
304		if aval != bval {
305			c.saveDiff(a.Float(), b.Float())
306		}
307	case reflect.Bool:
308		if a.Bool() != b.Bool() {
309			c.saveDiff(a.Bool(), b.Bool())
310		}
311	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
312		if a.Int() != b.Int() {
313			c.saveDiff(a.Int(), b.Int())
314		}
315	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
316		if a.Uint() != b.Uint() {
317			c.saveDiff(a.Uint(), b.Uint())
318		}
319	case reflect.String:
320		if a.String() != b.String() {
321			c.saveDiff(a.String(), b.String())
322		}
323
324	default:
325		logError(ErrNotHandled)
326	}
327}
328
329func (c *cmp) push(name string) {
330	c.buff = append(c.buff, name)
331}
332
333func (c *cmp) pop() {
334	if len(c.buff) > 0 {
335		c.buff = c.buff[0 : len(c.buff)-1]
336	}
337}
338
339func (c *cmp) saveDiff(aval, bval interface{}) {
340	if len(c.buff) > 0 {
341		varName := strings.Join(c.buff, ".")
342		c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval))
343	} else {
344		c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval))
345	}
346}
347
348func logError(err error) {
349	if LogErrors {
350		log.Println(err)
351	}
352}