1package code
2
3import (
4 "fmt"
5 "go/types"
6)
7
8// CompatibleTypes isnt a strict comparison, it allows for pointer differences
9func CompatibleTypes(expected types.Type, actual types.Type) error {
10 //fmt.Println("Comparing ", expected.String(), actual.String())
11
12 // Special case to deal with pointer mismatches
13 {
14 expectedPtr, expectedIsPtr := expected.(*types.Pointer)
15 actualPtr, actualIsPtr := actual.(*types.Pointer)
16
17 if expectedIsPtr && actualIsPtr {
18 return CompatibleTypes(expectedPtr.Elem(), actualPtr.Elem())
19 }
20 if expectedIsPtr && !actualIsPtr {
21 return CompatibleTypes(expectedPtr.Elem(), actual)
22 }
23 if !expectedIsPtr && actualIsPtr {
24 return CompatibleTypes(expected, actualPtr.Elem())
25 }
26 }
27
28 switch expected := expected.(type) {
29 case *types.Slice:
30 if actual, ok := actual.(*types.Slice); ok {
31 return CompatibleTypes(expected.Elem(), actual.Elem())
32 }
33
34 case *types.Array:
35 if actual, ok := actual.(*types.Array); ok {
36 if expected.Len() != actual.Len() {
37 return fmt.Errorf("array length differs")
38 }
39
40 return CompatibleTypes(expected.Elem(), actual.Elem())
41 }
42
43 case *types.Basic:
44 if actual, ok := actual.(*types.Basic); ok {
45 if actual.Kind() != expected.Kind() {
46 return fmt.Errorf("basic kind differs, %s != %s", expected.Name(), actual.Name())
47 }
48
49 return nil
50 }
51
52 case *types.Struct:
53 if actual, ok := actual.(*types.Struct); ok {
54 if expected.NumFields() != actual.NumFields() {
55 return fmt.Errorf("number of struct fields differ")
56 }
57
58 for i := 0; i < expected.NumFields(); i++ {
59 if expected.Field(i).Name() != actual.Field(i).Name() {
60 return fmt.Errorf("struct field %d name differs, %s != %s", i, expected.Field(i).Name(), actual.Field(i).Name())
61 }
62 if err := CompatibleTypes(expected.Field(i).Type(), actual.Field(i).Type()); err != nil {
63 return err
64 }
65 }
66 return nil
67 }
68
69 case *types.Tuple:
70 if actual, ok := actual.(*types.Tuple); ok {
71 if expected.Len() != actual.Len() {
72 return fmt.Errorf("tuple length differs, %d != %d", expected.Len(), actual.Len())
73 }
74
75 for i := 0; i < expected.Len(); i++ {
76 if err := CompatibleTypes(expected.At(i).Type(), actual.At(i).Type()); err != nil {
77 return err
78 }
79 }
80
81 return nil
82 }
83
84 case *types.Signature:
85 if actual, ok := actual.(*types.Signature); ok {
86 if err := CompatibleTypes(expected.Params(), actual.Params()); err != nil {
87 return err
88 }
89 if err := CompatibleTypes(expected.Results(), actual.Results()); err != nil {
90 return err
91 }
92
93 return nil
94 }
95 case *types.Interface:
96 if actual, ok := actual.(*types.Interface); ok {
97 if expected.NumMethods() != actual.NumMethods() {
98 return fmt.Errorf("interface method count differs, %d != %d", expected.NumMethods(), actual.NumMethods())
99 }
100
101 for i := 0; i < expected.NumMethods(); i++ {
102 if expected.Method(i).Name() != actual.Method(i).Name() {
103 return fmt.Errorf("interface method %d name differs, %s != %s", i, expected.Method(i).Name(), actual.Method(i).Name())
104 }
105 if err := CompatibleTypes(expected.Method(i).Type(), actual.Method(i).Type()); err != nil {
106 return err
107 }
108 }
109
110 return nil
111 }
112
113 case *types.Map:
114 if actual, ok := actual.(*types.Map); ok {
115 if err := CompatibleTypes(expected.Key(), actual.Key()); err != nil {
116 return err
117 }
118
119 if err := CompatibleTypes(expected.Elem(), actual.Elem()); err != nil {
120 return err
121 }
122
123 return nil
124 }
125
126 case *types.Chan:
127 if actual, ok := actual.(*types.Chan); ok {
128 return CompatibleTypes(expected.Elem(), actual.Elem())
129 }
130
131 case *types.Named:
132 if actual, ok := actual.(*types.Named); ok {
133 if NormalizeVendor(expected.Obj().Pkg().Path()) != NormalizeVendor(actual.Obj().Pkg().Path()) {
134 return fmt.Errorf(
135 "package name of named type differs, %s != %s",
136 NormalizeVendor(expected.Obj().Pkg().Path()),
137 NormalizeVendor(actual.Obj().Pkg().Path()),
138 )
139 }
140
141 if expected.Obj().Name() != actual.Obj().Name() {
142 return fmt.Errorf(
143 "named type name differs, %s != %s",
144 NormalizeVendor(expected.Obj().Name()),
145 NormalizeVendor(actual.Obj().Name()),
146 )
147 }
148
149 return nil
150 }
151
152 // Before models are generated all missing references will be Invalid Basic references.
153 // lets assume these are valid too.
154 if actual, ok := actual.(*types.Basic); ok && actual.Kind() == types.Invalid {
155 return nil
156 }
157
158 default:
159 return fmt.Errorf("missing support for %T", expected)
160 }
161
162 return fmt.Errorf("type mismatch %T != %T", expected, actual)
163}