1package complexity
2
3import (
4 "github.com/99designs/gqlgen/graphql"
5 "github.com/vektah/gqlparser/ast"
6)
7
8func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
9 walker := complexityWalker{
10 es: es,
11 schema: es.Schema(),
12 vars: vars,
13 }
14 return walker.selectionSetComplexity(op.SelectionSet)
15}
16
// complexityWalker holds the state shared by every step of a complexity walk
// over an operation's selection sets.
type complexityWalker struct {
	// es provides per-field custom complexity through its Complexity method.
	es graphql.ExecutableSchema
	// schema is the value of es.Schema(), captured once by Calculate; it is
	// used to look up field types and the possible types of interfaces.
	schema *ast.Schema
	// vars holds the operation's variable values, used when resolving field
	// arguments via ArgumentMap.
	vars map[string]interface{}
}
22
23func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
24 var complexity int
25 for _, selection := range selectionSet {
26 switch s := selection.(type) {
27 case *ast.Field:
28 fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
29 var childComplexity int
30 switch fieldDefinition.Kind {
31 case ast.Object, ast.Interface, ast.Union:
32 childComplexity = cw.selectionSetComplexity(s.SelectionSet)
33 }
34
35 args := s.ArgumentMap(cw.vars)
36 var fieldComplexity int
37 if s.ObjectDefinition.Kind == ast.Interface {
38 fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
39 } else {
40 fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
41 }
42 complexity = safeAdd(complexity, fieldComplexity)
43
44 case *ast.FragmentSpread:
45 complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
46
47 case *ast.InlineFragment:
48 complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
49 }
50 }
51 return complexity
52}
53
54func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
55 // Interfaces don't have their own separate field costs, so they have to assume the worst case.
56 // We iterate over all implementors and choose the most expensive one.
57 maxComplexity := 0
58 implementors := cw.schema.GetPossibleTypes(def)
59 for _, t := range implementors {
60 fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
61 if fieldComplexity > maxComplexity {
62 maxComplexity = fieldComplexity
63 }
64 }
65 return maxComplexity
66}
67
68func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
69 if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
70 return customComplexity
71 }
72 // default complexity calculation
73 return safeAdd(1, childComplexity)
74}
75
// maxInt is the largest value the platform's int type can hold.
const maxInt = int(^uint(0) / 2)

// safeAdd returns a + b, saturating at maxInt instead of wrapping on overflow.
//
// Negative operands are discarded: if exactly one operand is negative, the
// other is returned unchanged, and if both are negative the result is 1.
//
// Summing complexities with this helper stops an attacker from deliberately
// overflowing the total to slip an overly complex query under the limit. It
// also limits the damage from custom complexity functions that mistakenly
// return negative values.
func safeAdd(a, b int) int {
	switch {
	case a < 0 && b < 0:
		return 1
	case a < 0:
		return b
	case b < 0:
		return a
	}

	// Both operands are non-negative here, so the sum can only wrap by passing
	// maxInt, which makes it smaller than a.
	if sum := a + b; sum >= a {
		return sum
	}
	return maxInt
}