1// Copyright 2013 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package pretty
16
17import (
18 "encoding"
19 "fmt"
20 "reflect"
21 "sort"
22)
23
// isZeroVal reports whether val equals the zero value of its own type, as
// judged by reflect.DeepEqual against reflect.Zero.
//
// A value that cannot be extracted with Interface is always reported as
// non-zero, because there is nothing that can be compared.
func isZeroVal(val reflect.Value) bool {
	if val.CanInterface() {
		zero := reflect.Zero(val.Type())
		return reflect.DeepEqual(val.Interface(), zero.Interface())
	}
	return false
}
31
// pointerTracker is a helper for tracking pointer chasing to detect cycles.
//
// It keeps two independent pieces of state: which addresses are currently
// being followed (addrs), and which addresses have been handed an ID by keep
// (ids).
type pointerTracker struct {
	// addrs[address] = seen count. track increments the count and untrack
	// decrements it, removing the entry once it reaches zero.
	addrs map[uintptr]int

	lastID int             // most recently allocated ID; keep increments it before assigning
	ids    map[uintptr]int // ids[address] = id
}
39
40// track tracks following a reference (pointer, slice, map, etc). Every call to
41// track should be paired with a call to untrack.
42func (p *pointerTracker) track(ptr uintptr) {
43 if p.addrs == nil {
44 p.addrs = make(map[uintptr]int)
45 }
46 p.addrs[ptr]++
47}
48
49// untrack registers that we have backtracked over the reference to the pointer.
50func (p *pointerTracker) untrack(ptr uintptr) {
51 p.addrs[ptr]--
52 if p.addrs[ptr] == 0 {
53 delete(p.addrs, ptr)
54 }
55}
56
57// seen returns whether the pointer was previously seen along this path.
58func (p *pointerTracker) seen(ptr uintptr) bool {
59 _, ok := p.addrs[ptr]
60 return ok
61}
62
63// keep allocates an ID for the given address and returns it.
64func (p *pointerTracker) keep(ptr uintptr) int {
65 if p.ids == nil {
66 p.ids = make(map[uintptr]int)
67 }
68 if _, ok := p.ids[ptr]; !ok {
69 p.lastID++
70 p.ids[ptr] = p.lastID
71 }
72 return p.ids[ptr]
73}
74
75// id returns the ID for the given address.
76func (p *pointerTracker) id(ptr uintptr) (int, bool) {
77 if p.ids == nil {
78 p.ids = make(map[uintptr]int)
79 }
80 id, ok := p.ids[ptr]
81 return id, ok
82}
83
// reflector adds local state to the recursive reflection logic.
//
// It embeds the Config whose options steer the conversion and the
// pointerTracker used to detect cycles. follow treats a nil pointerTracker as
// "tracking disabled".
type reflector struct {
	*Config
	*pointerTracker
}
89
90// follow handles following a possiblly-recursive reference to the given value
91// from the given ptr address.
92func (r *reflector) follow(ptr uintptr, val reflect.Value) node {
93 if r.pointerTracker == nil {
94 // Tracking disabled
95 return r.val2node(val)
96 }
97
98 // If a parent already followed this, emit a reference marker
99 if r.seen(ptr) {
100 id := r.keep(ptr)
101 return ref{id}
102 }
103
104 // Track the pointer we're following while on this recursive branch
105 r.track(ptr)
106 defer r.untrack(ptr)
107 n := r.val2node(val)
108
109 // If the recursion used this ptr, wrap it with a target marker
110 if id, ok := r.id(ptr); ok {
111 return target{id, n}
112 }
113
114 // Otherwise, return the node unadulterated
115 return n
116}
117
118func (r *reflector) val2node(val reflect.Value) node {
119 if !val.IsValid() {
120 return rawVal("nil")
121 }
122
123 if val.CanInterface() {
124 v := val.Interface()
125 if formatter, ok := r.Formatter[val.Type()]; ok {
126 if formatter != nil {
127 res := reflect.ValueOf(formatter).Call([]reflect.Value{val})
128 return rawVal(res[0].Interface().(string))
129 }
130 } else {
131 if s, ok := v.(fmt.Stringer); ok && r.PrintStringers {
132 return stringVal(s.String())
133 }
134 if t, ok := v.(encoding.TextMarshaler); ok && r.PrintTextMarshalers {
135 if raw, err := t.MarshalText(); err == nil { // if NOT an error
136 return stringVal(string(raw))
137 }
138 }
139 }
140 }
141
142 switch kind := val.Kind(); kind {
143 case reflect.Ptr:
144 if val.IsNil() {
145 return rawVal("nil")
146 }
147 return r.follow(val.Pointer(), val.Elem())
148 case reflect.Interface:
149 if val.IsNil() {
150 return rawVal("nil")
151 }
152 return r.val2node(val.Elem())
153 case reflect.String:
154 return stringVal(val.String())
155 case reflect.Slice:
156 n := list{}
157 length := val.Len()
158 ptr := val.Pointer()
159 for i := 0; i < length; i++ {
160 n = append(n, r.follow(ptr, val.Index(i)))
161 }
162 return n
163 case reflect.Array:
164 n := list{}
165 length := val.Len()
166 for i := 0; i < length; i++ {
167 n = append(n, r.val2node(val.Index(i)))
168 }
169 return n
170 case reflect.Map:
171 // Extract the keys and sort them for stable iteration
172 keys := val.MapKeys()
173 pairs := make([]mapPair, 0, len(keys))
174 for _, key := range keys {
175 pairs = append(pairs, mapPair{
176 key: new(formatter).compactString(r.val2node(key)), // can't be cyclic
177 value: val.MapIndex(key),
178 })
179 }
180 sort.Sort(byKey(pairs))
181
182 // Process the keys into the final representation
183 ptr, n := val.Pointer(), keyvals{}
184 for _, pair := range pairs {
185 n = append(n, keyval{
186 key: pair.key,
187 val: r.follow(ptr, pair.value),
188 })
189 }
190 return n
191 case reflect.Struct:
192 n := keyvals{}
193 typ := val.Type()
194 fields := typ.NumField()
195 for i := 0; i < fields; i++ {
196 sf := typ.Field(i)
197 if !r.IncludeUnexported && sf.PkgPath != "" {
198 continue
199 }
200 field := val.Field(i)
201 if r.SkipZeroFields && isZeroVal(field) {
202 continue
203 }
204 n = append(n, keyval{sf.Name, r.val2node(field)})
205 }
206 return n
207 case reflect.Bool:
208 if val.Bool() {
209 return rawVal("true")
210 }
211 return rawVal("false")
212 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
213 return rawVal(fmt.Sprintf("%d", val.Int()))
214 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
215 return rawVal(fmt.Sprintf("%d", val.Uint()))
216 case reflect.Uintptr:
217 return rawVal(fmt.Sprintf("0x%X", val.Uint()))
218 case reflect.Float32, reflect.Float64:
219 return rawVal(fmt.Sprintf("%v", val.Float()))
220 case reflect.Complex64, reflect.Complex128:
221 return rawVal(fmt.Sprintf("%v", val.Complex()))
222 }
223
224 // Fall back to the default %#v if we can
225 if val.CanInterface() {
226 return rawVal(fmt.Sprintf("%#v", val.Interface()))
227 }
228
229 return rawVal(val.String())
230}
231
// mapPair holds one map entry while map keys are being sorted: the key
// already rendered to a compact string, and the value still in reflected form.
type mapPair struct {
	key   string        // compact string rendering of the map key; what byKey sorts on
	value reflect.Value // the map element for that key, converted to a node after sorting
}
236
237type byKey []mapPair
238
239func (v byKey) Len() int { return len(v) }
240func (v byKey) Swap(i, j int) { v[i], v[j] = v[j], v[i] }
241func (v byKey) Less(i, j int) bool { return v[i].key < v[j].key }