reflect.go

  1// Copyright 2013 Google Inc.  All rights reserved.
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//     http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package pretty
 16
 17import (
 18	"encoding"
 19	"fmt"
 20	"reflect"
 21	"sort"
 22)
 23
 24func isZeroVal(val reflect.Value) bool {
 25	if !val.CanInterface() {
 26		return false
 27	}
 28	z := reflect.Zero(val.Type()).Interface()
 29	return reflect.DeepEqual(val.Interface(), z)
 30}
 31
// pointerTracker is a helper for tracking pointer chasing to detect cycles.
//
// track and keep allocate their maps lazily, so a zero pointerTracker is
// ready to use.
type pointerTracker struct {
	// addrs counts how many times each address is currently being followed
	// on the active recursion path: track increments the count, untrack
	// decrements it and deletes the entry when it reaches zero.
	addrs map[uintptr]int // addrs[address] = seen count

	// lastID is the most recently allocated ID (0 before any allocation);
	// ids maps an address to the ID that keep allocated for it, starting at 1.
	lastID int
	ids    map[uintptr]int // ids[address] = id
}
 39
 40// track tracks following a reference (pointer, slice, map, etc).  Every call to
 41// track should be paired with a call to untrack.
 42func (p *pointerTracker) track(ptr uintptr) {
 43	if p.addrs == nil {
 44		p.addrs = make(map[uintptr]int)
 45	}
 46	p.addrs[ptr]++
 47}
 48
 49// untrack registers that we have backtracked over the reference to the pointer.
 50func (p *pointerTracker) untrack(ptr uintptr) {
 51	p.addrs[ptr]--
 52	if p.addrs[ptr] == 0 {
 53		delete(p.addrs, ptr)
 54	}
 55}
 56
 57// seen returns whether the pointer was previously seen along this path.
 58func (p *pointerTracker) seen(ptr uintptr) bool {
 59	_, ok := p.addrs[ptr]
 60	return ok
 61}
 62
 63// keep allocates an ID for the given address and returns it.
 64func (p *pointerTracker) keep(ptr uintptr) int {
 65	if p.ids == nil {
 66		p.ids = make(map[uintptr]int)
 67	}
 68	if _, ok := p.ids[ptr]; !ok {
 69		p.lastID++
 70		p.ids[ptr] = p.lastID
 71	}
 72	return p.ids[ptr]
 73}
 74
 75// id returns the ID for the given address.
 76func (p *pointerTracker) id(ptr uintptr) (int, bool) {
 77	if p.ids == nil {
 78		p.ids = make(map[uintptr]int)
 79	}
 80	id, ok := p.ids[ptr]
 81	return id, ok
 82}
 83
// reflector adds local state to the recursive reflection logic.
//
// The embedded *Config supplies the options that val2node reads (Formatter,
// PrintStringers, PrintTextMarshalers, IncludeUnexported, SkipZeroFields).
// The embedded *pointerTracker may be nil. In that case follow skips cycle
// tracking entirely.
type reflector struct {
	*Config
	*pointerTracker
}
 89
 90// follow handles following a possiblly-recursive reference to the given value
 91// from the given ptr address.
 92func (r *reflector) follow(ptr uintptr, val reflect.Value) node {
 93	if r.pointerTracker == nil {
 94		// Tracking disabled
 95		return r.val2node(val)
 96	}
 97
 98	// If a parent already followed this, emit a reference marker
 99	if r.seen(ptr) {
100		id := r.keep(ptr)
101		return ref{id}
102	}
103
104	// Track the pointer we're following while on this recursive branch
105	r.track(ptr)
106	defer r.untrack(ptr)
107	n := r.val2node(val)
108
109	// If the recursion used this ptr, wrap it with a target marker
110	if id, ok := r.id(ptr); ok {
111		return target{id, n}
112	}
113
114	// Otherwise, return the node unadulterated
115	return n
116}
117
118func (r *reflector) val2node(val reflect.Value) node {
119	if !val.IsValid() {
120		return rawVal("nil")
121	}
122
123	if val.CanInterface() {
124		v := val.Interface()
125		if formatter, ok := r.Formatter[val.Type()]; ok {
126			if formatter != nil {
127				res := reflect.ValueOf(formatter).Call([]reflect.Value{val})
128				return rawVal(res[0].Interface().(string))
129			}
130		} else {
131			if s, ok := v.(fmt.Stringer); ok && r.PrintStringers {
132				return stringVal(s.String())
133			}
134			if t, ok := v.(encoding.TextMarshaler); ok && r.PrintTextMarshalers {
135				if raw, err := t.MarshalText(); err == nil { // if NOT an error
136					return stringVal(string(raw))
137				}
138			}
139		}
140	}
141
142	switch kind := val.Kind(); kind {
143	case reflect.Ptr:
144		if val.IsNil() {
145			return rawVal("nil")
146		}
147		return r.follow(val.Pointer(), val.Elem())
148	case reflect.Interface:
149		if val.IsNil() {
150			return rawVal("nil")
151		}
152		return r.val2node(val.Elem())
153	case reflect.String:
154		return stringVal(val.String())
155	case reflect.Slice:
156		n := list{}
157		length := val.Len()
158		ptr := val.Pointer()
159		for i := 0; i < length; i++ {
160			n = append(n, r.follow(ptr, val.Index(i)))
161		}
162		return n
163	case reflect.Array:
164		n := list{}
165		length := val.Len()
166		for i := 0; i < length; i++ {
167			n = append(n, r.val2node(val.Index(i)))
168		}
169		return n
170	case reflect.Map:
171		// Extract the keys and sort them for stable iteration
172		keys := val.MapKeys()
173		pairs := make([]mapPair, 0, len(keys))
174		for _, key := range keys {
175			pairs = append(pairs, mapPair{
176				key:   new(formatter).compactString(r.val2node(key)), // can't be cyclic
177				value: val.MapIndex(key),
178			})
179		}
180		sort.Sort(byKey(pairs))
181
182		// Process the keys into the final representation
183		ptr, n := val.Pointer(), keyvals{}
184		for _, pair := range pairs {
185			n = append(n, keyval{
186				key: pair.key,
187				val: r.follow(ptr, pair.value),
188			})
189		}
190		return n
191	case reflect.Struct:
192		n := keyvals{}
193		typ := val.Type()
194		fields := typ.NumField()
195		for i := 0; i < fields; i++ {
196			sf := typ.Field(i)
197			if !r.IncludeUnexported && sf.PkgPath != "" {
198				continue
199			}
200			field := val.Field(i)
201			if r.SkipZeroFields && isZeroVal(field) {
202				continue
203			}
204			n = append(n, keyval{sf.Name, r.val2node(field)})
205		}
206		return n
207	case reflect.Bool:
208		if val.Bool() {
209			return rawVal("true")
210		}
211		return rawVal("false")
212	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
213		return rawVal(fmt.Sprintf("%d", val.Int()))
214	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
215		return rawVal(fmt.Sprintf("%d", val.Uint()))
216	case reflect.Uintptr:
217		return rawVal(fmt.Sprintf("0x%X", val.Uint()))
218	case reflect.Float32, reflect.Float64:
219		return rawVal(fmt.Sprintf("%v", val.Float()))
220	case reflect.Complex64, reflect.Complex128:
221		return rawVal(fmt.Sprintf("%v", val.Complex()))
222	}
223
224	// Fall back to the default %#v if we can
225	if val.CanInterface() {
226		return rawVal(fmt.Sprintf("%#v", val.Interface()))
227	}
228
229	return rawVal(val.String())
230}
231
// mapPair is one map entry prepared for sorted output. key is the entry's key
// already rendered to a compact string, and value is the raw entry value.
type mapPair struct {
	key   string
	value reflect.Value
}

// byKey orders mapPairs by their rendered key using plain string comparison.
type byKey []mapPair

// byKey is handed to sort.Sort, so make the interface contract explicit.
var _ sort.Interface = byKey(nil)

func (ps byKey) Len() int {
	return len(ps)
}

func (ps byKey) Swap(i, j int) {
	ps[i], ps[j] = ps[j], ps[i]
}

func (ps byKey) Less(i, j int) bool {
	left, right := ps[i].key, ps[j].key
	return left < right
}