compare.go

  1package code
  2
  3import (
  4	"fmt"
  5	"go/types"
  6)
  7
  8// CompatibleTypes isnt a strict comparison, it allows for pointer differences
  9func CompatibleTypes(expected types.Type, actual types.Type) error {
 10	//fmt.Println("Comparing ", expected.String(), actual.String())
 11
 12	// Special case to deal with pointer mismatches
 13	{
 14		expectedPtr, expectedIsPtr := expected.(*types.Pointer)
 15		actualPtr, actualIsPtr := actual.(*types.Pointer)
 16
 17		if expectedIsPtr && actualIsPtr {
 18			return CompatibleTypes(expectedPtr.Elem(), actualPtr.Elem())
 19		}
 20		if expectedIsPtr && !actualIsPtr {
 21			return CompatibleTypes(expectedPtr.Elem(), actual)
 22		}
 23		if !expectedIsPtr && actualIsPtr {
 24			return CompatibleTypes(expected, actualPtr.Elem())
 25		}
 26	}
 27
 28	switch expected := expected.(type) {
 29	case *types.Slice:
 30		if actual, ok := actual.(*types.Slice); ok {
 31			return CompatibleTypes(expected.Elem(), actual.Elem())
 32		}
 33
 34	case *types.Array:
 35		if actual, ok := actual.(*types.Array); ok {
 36			if expected.Len() != actual.Len() {
 37				return fmt.Errorf("array length differs")
 38			}
 39
 40			return CompatibleTypes(expected.Elem(), actual.Elem())
 41		}
 42
 43	case *types.Basic:
 44		if actual, ok := actual.(*types.Basic); ok {
 45			if actual.Kind() != expected.Kind() {
 46				return fmt.Errorf("basic kind differs, %s != %s", expected.Name(), actual.Name())
 47			}
 48
 49			return nil
 50		}
 51
 52	case *types.Struct:
 53		if actual, ok := actual.(*types.Struct); ok {
 54			if expected.NumFields() != actual.NumFields() {
 55				return fmt.Errorf("number of struct fields differ")
 56			}
 57
 58			for i := 0; i < expected.NumFields(); i++ {
 59				if expected.Field(i).Name() != actual.Field(i).Name() {
 60					return fmt.Errorf("struct field %d name differs, %s != %s", i, expected.Field(i).Name(), actual.Field(i).Name())
 61				}
 62				if err := CompatibleTypes(expected.Field(i).Type(), actual.Field(i).Type()); err != nil {
 63					return err
 64				}
 65			}
 66			return nil
 67		}
 68
 69	case *types.Tuple:
 70		if actual, ok := actual.(*types.Tuple); ok {
 71			if expected.Len() != actual.Len() {
 72				return fmt.Errorf("tuple length differs, %d != %d", expected.Len(), actual.Len())
 73			}
 74
 75			for i := 0; i < expected.Len(); i++ {
 76				if err := CompatibleTypes(expected.At(i).Type(), actual.At(i).Type()); err != nil {
 77					return err
 78				}
 79			}
 80
 81			return nil
 82		}
 83
 84	case *types.Signature:
 85		if actual, ok := actual.(*types.Signature); ok {
 86			if err := CompatibleTypes(expected.Params(), actual.Params()); err != nil {
 87				return err
 88			}
 89			if err := CompatibleTypes(expected.Results(), actual.Results()); err != nil {
 90				return err
 91			}
 92
 93			return nil
 94		}
 95	case *types.Interface:
 96		if actual, ok := actual.(*types.Interface); ok {
 97			if expected.NumMethods() != actual.NumMethods() {
 98				return fmt.Errorf("interface method count differs, %d != %d", expected.NumMethods(), actual.NumMethods())
 99			}
100
101			for i := 0; i < expected.NumMethods(); i++ {
102				if expected.Method(i).Name() != actual.Method(i).Name() {
103					return fmt.Errorf("interface method %d name differs, %s != %s", i, expected.Method(i).Name(), actual.Method(i).Name())
104				}
105				if err := CompatibleTypes(expected.Method(i).Type(), actual.Method(i).Type()); err != nil {
106					return err
107				}
108			}
109
110			return nil
111		}
112
113	case *types.Map:
114		if actual, ok := actual.(*types.Map); ok {
115			if err := CompatibleTypes(expected.Key(), actual.Key()); err != nil {
116				return err
117			}
118
119			if err := CompatibleTypes(expected.Elem(), actual.Elem()); err != nil {
120				return err
121			}
122
123			return nil
124		}
125
126	case *types.Chan:
127		if actual, ok := actual.(*types.Chan); ok {
128			return CompatibleTypes(expected.Elem(), actual.Elem())
129		}
130
131	case *types.Named:
132		if actual, ok := actual.(*types.Named); ok {
133			if NormalizeVendor(expected.Obj().Pkg().Path()) != NormalizeVendor(actual.Obj().Pkg().Path()) {
134				return fmt.Errorf(
135					"package name of named type differs, %s != %s",
136					NormalizeVendor(expected.Obj().Pkg().Path()),
137					NormalizeVendor(actual.Obj().Pkg().Path()),
138				)
139			}
140
141			if expected.Obj().Name() != actual.Obj().Name() {
142				return fmt.Errorf(
143					"named type name differs, %s != %s",
144					NormalizeVendor(expected.Obj().Name()),
145					NormalizeVendor(actual.Obj().Name()),
146				)
147			}
148
149			return nil
150		}
151
152		// Before models are generated all missing references will be Invalid Basic references.
153		// lets assume these are valid too.
154		if actual, ok := actual.(*types.Basic); ok && actual.Kind() == types.Invalid {
155			return nil
156		}
157
158	default:
159		return fmt.Errorf("missing support for %T", expected)
160	}
161
162	return fmt.Errorf("type mismatch %T != %T", expected, actual)
163}