interface_build.go

 1package codegen
 2
 3import (
 4	"fmt"
 5	"go/types"
 6	"os"
 7	"sort"
 8	"strings"
 9
10	"github.com/vektah/gqlgen/neelance/schema"
11	"golang.org/x/tools/go/loader"
12)
13
14func (cfg *Config) buildInterfaces(types NamedTypes, prog *loader.Program) []*Interface {
15	var interfaces []*Interface
16	for _, typ := range cfg.schema.Types {
17		switch typ := typ.(type) {
18		case *schema.Union, *schema.Interface:
19			interfaces = append(interfaces, cfg.buildInterface(types, typ, prog))
20		default:
21			continue
22		}
23	}
24
25	sort.Slice(interfaces, func(i, j int) bool {
26		return strings.Compare(interfaces[i].GQLType, interfaces[j].GQLType) == -1
27	})
28
29	return interfaces
30}
31
32func (cfg *Config) buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program) *Interface {
33	switch typ := typ.(type) {
34
35	case *schema.Union:
36		i := &Interface{NamedType: types[typ.TypeName()]}
37
38		for _, implementor := range typ.PossibleTypes {
39			t := types[implementor.TypeName()]
40
41			i.Implementors = append(i.Implementors, InterfaceImplementor{
42				NamedType:     t,
43				ValueReceiver: cfg.isValueReceiver(types[typ.Name], t, prog),
44			})
45		}
46
47		return i
48
49	case *schema.Interface:
50		i := &Interface{NamedType: types[typ.TypeName()]}
51
52		for _, implementor := range typ.PossibleTypes {
53			t := types[implementor.TypeName()]
54
55			i.Implementors = append(i.Implementors, InterfaceImplementor{
56				NamedType:     t,
57				ValueReceiver: cfg.isValueReceiver(types[typ.Name], t, prog),
58			})
59		}
60
61		return i
62	default:
63		panic(fmt.Errorf("unknown interface %#v", typ))
64	}
65}
66
67func (cfg *Config) isValueReceiver(intf *NamedType, implementor *NamedType, prog *loader.Program) bool {
68	interfaceType, err := findGoInterface(prog, intf.Package, intf.GoType)
69	if interfaceType == nil || err != nil {
70		return true
71	}
72
73	implementorType, err := findGoNamedType(prog, implementor.Package, implementor.GoType)
74	if implementorType == nil || err != nil {
75		return true
76	}
77
78	for i := 0; i < interfaceType.NumMethods(); i++ {
79		intfMethod := interfaceType.Method(i)
80
81		implMethod := findMethod(implementorType, intfMethod.Name())
82		if implMethod == nil {
83			fmt.Fprintf(os.Stderr, "missing method %s on %s\n", intfMethod.Name(), implementor.GoType)
84			return false
85		}
86
87		sig := implMethod.Type().(*types.Signature)
88		if _, isPtr := sig.Recv().Type().(*types.Pointer); isPtr {
89			return false
90		}
91	}
92
93	return true
94}