complexity.go

  1package complexity
  2
  3import (
  4	"github.com/99designs/gqlgen/graphql"
  5	"github.com/vektah/gqlparser/ast"
  6)
  7
  8func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
  9	walker := complexityWalker{
 10		es:     es,
 11		schema: es.Schema(),
 12		vars:   vars,
 13	}
 14	return walker.selectionSetComplexity(op.SelectionSet)
 15}
 16
 17type complexityWalker struct {
 18	es     graphql.ExecutableSchema
 19	schema *ast.Schema
 20	vars   map[string]interface{}
 21}
 22
 23func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
 24	var complexity int
 25	for _, selection := range selectionSet {
 26		switch s := selection.(type) {
 27		case *ast.Field:
 28			fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
 29			var childComplexity int
 30			switch fieldDefinition.Kind {
 31			case ast.Object, ast.Interface, ast.Union:
 32				childComplexity = cw.selectionSetComplexity(s.SelectionSet)
 33			}
 34
 35			args := s.ArgumentMap(cw.vars)
 36			var fieldComplexity int
 37			if s.ObjectDefinition.Kind == ast.Interface {
 38				fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
 39			} else {
 40				fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
 41			}
 42			complexity = safeAdd(complexity, fieldComplexity)
 43
 44		case *ast.FragmentSpread:
 45			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
 46
 47		case *ast.InlineFragment:
 48			complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
 49		}
 50	}
 51	return complexity
 52}
 53
 54func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
 55	// Interfaces don't have their own separate field costs, so they have to assume the worst case.
 56	// We iterate over all implementors and choose the most expensive one.
 57	maxComplexity := 0
 58	implementors := cw.schema.GetPossibleTypes(def)
 59	for _, t := range implementors {
 60		fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
 61		if fieldComplexity > maxComplexity {
 62			maxComplexity = fieldComplexity
 63		}
 64	}
 65	return maxComplexity
 66}
 67
 68func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
 69	if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
 70		return customComplexity
 71	}
 72	// default complexity calculation
 73	return safeAdd(1, childComplexity)
 74}
 75
 // maxInt is the largest value of the platform's int type.
const maxInt = int(^uint(0) >> 1)

// safeAdd adds two complexity values without wrapping around.
//
// Negative operands are discarded:
//   - If exactly one operand is negative, the other is returned unchanged.
//   - If both are negative, the result is 1.
//
// For two non-negative operands whose true sum would exceed maxInt, the result
// saturates at maxInt.
//
// Saturation stops an attacker from deliberately overflowing the complexity
// total so that an overly complex query slips under the limit. Discarding
// negatives limits the effect of custom complexity functions that mistakenly
// return negative values.
func safeAdd(a, b int) int {
	switch {
	case a < 0 && b < 0:
		return 1
	case a < 0:
		return b
	case b < 0:
		return a
	case b > maxInt-a:
		// a >= 0 here, so maxInt-a cannot overflow; a+b would.
		return maxInt
	default:
		return a + b
	}
}