1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package transport
20
21import (
22 "bufio"
23 "encoding/base64"
24 "errors"
25 "fmt"
26 "io"
27 "math"
28 "net"
29 "net/http"
30 "net/url"
31 "strconv"
32 "strings"
33 "sync"
34 "time"
35 "unicode/utf8"
36
37 "golang.org/x/net/http2"
38 "golang.org/x/net/http2/hpack"
39 "google.golang.org/grpc/codes"
40)
41
// http2MaxFrameLen is the largest HTTP/2 frame payload the framer will
// accept when reading: 16 KiB, the protocol's initial SETTINGS_MAX_FRAME_SIZE.
const http2MaxFrameLen = 16 << 10

// http2InitHeaderTableSize is the initial HPACK dynamic table size defined
// for SETTINGS_HEADER_TABLE_SIZE; see
// https://httpwg.org/specs/rfc7540.html#SettingValues
const http2InitHeaderTableSize = 4096
48
49var (
50 clientPreface = []byte(http2.ClientPreface)
51 http2ErrConvTab = map[http2.ErrCode]codes.Code{
52 http2.ErrCodeNo: codes.Internal,
53 http2.ErrCodeProtocol: codes.Internal,
54 http2.ErrCodeInternal: codes.Internal,
55 http2.ErrCodeFlowControl: codes.ResourceExhausted,
56 http2.ErrCodeSettingsTimeout: codes.Internal,
57 http2.ErrCodeStreamClosed: codes.Internal,
58 http2.ErrCodeFrameSize: codes.Internal,
59 http2.ErrCodeRefusedStream: codes.Unavailable,
60 http2.ErrCodeCancel: codes.Canceled,
61 http2.ErrCodeCompression: codes.Internal,
62 http2.ErrCodeConnect: codes.Internal,
63 http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
64 http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
65 http2.ErrCodeHTTP11Required: codes.Internal,
66 }
67 // HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
68 HTTPStatusConvTab = map[int]codes.Code{
69 // 400 Bad Request - INTERNAL.
70 http.StatusBadRequest: codes.Internal,
71 // 401 Unauthorized - UNAUTHENTICATED.
72 http.StatusUnauthorized: codes.Unauthenticated,
73 // 403 Forbidden - PERMISSION_DENIED.
74 http.StatusForbidden: codes.PermissionDenied,
75 // 404 Not Found - UNIMPLEMENTED.
76 http.StatusNotFound: codes.Unimplemented,
77 // 429 Too Many Requests - UNAVAILABLE.
78 http.StatusTooManyRequests: codes.Unavailable,
79 // 502 Bad Gateway - UNAVAILABLE.
80 http.StatusBadGateway: codes.Unavailable,
81 // 503 Service Unavailable - UNAVAILABLE.
82 http.StatusServiceUnavailable: codes.Unavailable,
83 // 504 Gateway timeout - UNAVAILABLE.
84 http.StatusGatewayTimeout: codes.Unavailable,
85 }
86)
87
// grpcStatusDetailsBinHeader is the name of the gRPC status-details header.
// Because it ends in "-bin", its value is base64-encoded on the wire by the
// metadata encode/decode helpers in this file.
var (
	grpcStatusDetailsBinHeader = "grpc-status-details-bin"
)
89
// isReservedHeader reports whether hdr is reserved by the gRPC protocol.
// Every HTTP/2 pseudo-header (a name starting with ':') is reserved, as are
// the fixed names listed below; any other header is user-specified metadata.
func isReservedHeader(hdr string) bool {
	if len(hdr) > 0 && hdr[0] == ':' {
		return true
	}
	switch hdr {
	case "content-type", "user-agent", "te":
		return true
	case "grpc-message-type", "grpc-encoding", "grpc-message", "grpc-status", "grpc-timeout":
		return true
	}
	// grpc-previous-rpc-attempts and grpc-retry-pushback-ms are reserved by
	// the protocol as well, but are deliberately not listed: their API is
	// exposed through metadata.
	return false
}
114
// isWhitelistedHeader reports whether hdr, although classified as reserved
// by isReservedHeader, should still be surfaced to users as metadata.
// Only ":authority" and "user-agent" qualify.
func isWhitelistedHeader(hdr string) bool {
	return hdr == ":authority" || hdr == "user-agent"
}
125
// binHdrSuffix marks a metadata key whose value is binary; values under such
// keys are base64-encoded before being placed in a header.
const (
	binHdrSuffix = "-bin"
)
127
// encodeBinHeader renders v as standard base64 without '=' padding, the form
// used for binary header values.
func encodeBinHeader(v []byte) string {
	enc := base64.RawStdEncoding
	return enc.EncodeToString(v)
}
131
// decodeBinHeader decodes a base64 binary header value, accepting both the
// padded and the unpadded standard alphabet.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		// A length that is a multiple of four is either padded or needs no
		// padding, so the padded decoder handles both cases.
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
139
140func encodeMetadataHeader(k, v string) string {
141 if strings.HasSuffix(k, binHdrSuffix) {
142 return encodeBinHeader(([]byte)(v))
143 }
144 return v
145}
146
147func decodeMetadataHeader(k, v string) (string, error) {
148 if strings.HasSuffix(k, binHdrSuffix) {
149 b, err := decodeBinHeader(v)
150 return string(b), err
151 }
152 return v, nil
153}
154
// timeoutUnit is the one-character unit suffix of an encoded timeout string;
// decodeTimeout reads it from the last byte of its input.
type timeoutUnit uint8

// The recognized timeout units, one per time.Duration granularity.
const (
	hour        = timeoutUnit('H')
	minute      = timeoutUnit('M')
	second      = timeoutUnit('S')
	millisecond = timeoutUnit('m')
	microsecond = timeoutUnit('u')
	nanosecond  = timeoutUnit('n')
)
165
166func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
167 switch u {
168 case hour:
169 return time.Hour, true
170 case minute:
171 return time.Minute, true
172 case second:
173 return time.Second, true
174 case millisecond:
175 return time.Millisecond, true
176 case microsecond:
177 return time.Microsecond, true
178 case nanosecond:
179 return time.Nanosecond, true
180 default:
181 }
182 return
183}
184
185func decodeTimeout(s string) (time.Duration, error) {
186 size := len(s)
187 if size < 2 {
188 return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
189 }
190 if size > 9 {
191 // Spec allows for 8 digits plus the unit.
192 return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
193 }
194 unit := timeoutUnit(s[size-1])
195 d, ok := timeoutUnitToDuration(unit)
196 if !ok {
197 return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
198 }
199 t, err := strconv.ParseInt(s[:size-1], 10, 64)
200 if err != nil {
201 return 0, err
202 }
203 const maxHours = math.MaxInt64 / int64(time.Hour)
204 if d == time.Hour && t > maxHours {
205 // This timeout would overflow math.MaxInt64; clamp it.
206 return time.Duration(math.MaxInt64), nil
207 }
208 return d * time.Duration(t), nil
209}
210
// spaceByte and tildeByte bound the printable-ASCII range that a grpc-message
// value may carry without escaping.
const spaceByte, tildeByte = ' ', '~'

// percentByte introduces a percent-encoded byte, so it must itself be
// escaped whenever it appears in a message.
const percentByte = '%'
216
217// encodeGrpcMessage is used to encode status code in header field
218// "grpc-message". It does percent encoding and also replaces invalid utf-8
219// characters with Unicode replacement character.
220//
221// It checks to see if each individual byte in msg is an allowable byte, and
222// then either percent encoding or passing it through. When percent encoding,
223// the byte is converted into hexadecimal notation with a '%' prepended.
224func encodeGrpcMessage(msg string) string {
225 if msg == "" {
226 return ""
227 }
228 lenMsg := len(msg)
229 for i := 0; i < lenMsg; i++ {
230 c := msg[i]
231 if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
232 return encodeGrpcMessageUnchecked(msg)
233 }
234 }
235 return msg
236}
237
238func encodeGrpcMessageUnchecked(msg string) string {
239 var sb strings.Builder
240 for len(msg) > 0 {
241 r, size := utf8.DecodeRuneInString(msg)
242 for _, b := range []byte(string(r)) {
243 if size > 1 {
244 // If size > 1, r is not ascii. Always do percent encoding.
245 fmt.Fprintf(&sb, "%%%02X", b)
246 continue
247 }
248
249 // The for loop is necessary even if size == 1. r could be
250 // utf8.RuneError.
251 //
252 // fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
253 if b >= spaceByte && b <= tildeByte && b != percentByte {
254 sb.WriteByte(b)
255 } else {
256 fmt.Fprintf(&sb, "%%%02X", b)
257 }
258 }
259 msg = msg[size:]
260 }
261 return sb.String()
262}
263
264// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
265func decodeGrpcMessage(msg string) string {
266 if msg == "" {
267 return ""
268 }
269 lenMsg := len(msg)
270 for i := 0; i < lenMsg; i++ {
271 if msg[i] == percentByte && i+2 < lenMsg {
272 return decodeGrpcMessageUnchecked(msg)
273 }
274 }
275 return msg
276}
277
278func decodeGrpcMessageUnchecked(msg string) string {
279 var sb strings.Builder
280 lenMsg := len(msg)
281 for i := 0; i < lenMsg; i++ {
282 c := msg[i]
283 if c == percentByte && i+2 < lenMsg {
284 parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
285 if err != nil {
286 sb.WriteByte(c)
287 } else {
288 sb.WriteByte(byte(parsed))
289 i += 2
290 }
291 } else {
292 sb.WriteByte(c)
293 }
294 }
295 return sb.String()
296}
297
298type bufWriter struct {
299 pool *sync.Pool
300 buf []byte
301 offset int
302 batchSize int
303 conn net.Conn
304 err error
305}
306
307func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
308 w := &bufWriter{
309 batchSize: batchSize,
310 conn: conn,
311 pool: pool,
312 }
313 // this indicates that we should use non shared buf
314 if pool == nil {
315 w.buf = make([]byte, batchSize)
316 }
317 return w
318}
319
320func (w *bufWriter) Write(b []byte) (int, error) {
321 if w.err != nil {
322 return 0, w.err
323 }
324 if w.batchSize == 0 { // Buffer has been disabled.
325 n, err := w.conn.Write(b)
326 return n, toIOError(err)
327 }
328 if w.buf == nil {
329 b := w.pool.Get().(*[]byte)
330 w.buf = *b
331 }
332 written := 0
333 for len(b) > 0 {
334 copied := copy(w.buf[w.offset:], b)
335 b = b[copied:]
336 written += copied
337 w.offset += copied
338 if w.offset < w.batchSize {
339 continue
340 }
341 if err := w.flushKeepBuffer(); err != nil {
342 return written, err
343 }
344 }
345 return written, nil
346}
347
348func (w *bufWriter) Flush() error {
349 err := w.flushKeepBuffer()
350 // Only release the buffer if we are in a "shared" mode
351 if w.buf != nil && w.pool != nil {
352 b := w.buf
353 w.pool.Put(&b)
354 w.buf = nil
355 }
356 return err
357}
358
359func (w *bufWriter) flushKeepBuffer() error {
360 if w.err != nil {
361 return w.err
362 }
363 if w.offset == 0 {
364 return nil
365 }
366 _, w.err = w.conn.Write(w.buf[:w.offset])
367 w.err = toIOError(w.err)
368 w.offset = 0
369 return w.err
370}
371
// ioError tags an error as having come from I/O on the connection (bufWriter
// wraps conn.Write errors with it) while keeping the original error
// reachable through Unwrap.
type ioError struct {
	error
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e ioError) Unwrap() error { return e.error }

// isIOError reports whether err, or any error in its chain, is an ioError.
func isIOError(err error) bool {
	var target ioError
	return errors.As(err, &target)
}

// toIOError wraps a non-nil err in an ioError and passes nil through.
func toIOError(err error) error {
	if err != nil {
		return ioError{error: err}
	}
	return nil
}
390
// framer pairs an HTTP/2 framer with the bufWriter it writes through, so the
// owner can both read/write frames (fr) and flush buffered output to the
// connection (writer).
type framer struct {
	// writer buffers outgoing bytes for the connection; as built by
	// newFramer, fr writes its frames into it.
	writer *bufWriter
	// fr reads and writes HTTP/2 frames; as built by newFramer, it reads
	// from the connection, optionally through a bufio.Reader.
	fr *http2.Framer
}
395
// writeBufferPoolMap caches one pool of write buffers per buffer size, so
// writers configured with the same size share buffers. In this file it is
// only accessed by getWriteBufferPool, while holding writeBufferMutex.
var (
	writeBufferPoolMap = map[int]*sync.Pool{}
	writeBufferMutex   sync.Mutex
)
398
399func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer {
400 if writeBufferSize < 0 {
401 writeBufferSize = 0
402 }
403 var r io.Reader = conn
404 if readBufferSize > 0 {
405 r = bufio.NewReaderSize(r, readBufferSize)
406 }
407 var pool *sync.Pool
408 if sharedWriteBuffer {
409 pool = getWriteBufferPool(writeBufferSize)
410 }
411 w := newBufWriter(conn, writeBufferSize, pool)
412 f := &framer{
413 writer: w,
414 fr: http2.NewFramer(w, r),
415 }
416 f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
417 // Opt-in to Frame reuse API on framer to reduce garbage.
418 // Frames aren't safe to read from after a subsequent call to ReadFrame.
419 f.fr.SetReuseFrames()
420 f.fr.MaxHeaderListSize = maxHeaderListSize
421 f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
422 return f
423}
424
425func getWriteBufferPool(size int) *sync.Pool {
426 writeBufferMutex.Lock()
427 defer writeBufferMutex.Unlock()
428 pool, ok := writeBufferPoolMap[size]
429 if ok {
430 return pool
431 }
432 pool = &sync.Pool{
433 New: func() any {
434 b := make([]byte, size)
435 return &b
436 },
437 }
438 writeBufferPoolMap[size] = pool
439 return pool
440}
441
// parseDialTarget splits target into the network and address to pass to the
// dialer. Only the "unix" scheme is special. It is accepted both in the
// short "unix:addr" form and in URL form ("unix:///path", or "unix://host"
// when the URL has no path). Every other target is returned unchanged with
// network "tcp".
func parseDialTarget(target string) (string, string) {
	const defaultNetwork = "tcp"

	if !strings.Contains(target, ":/") {
		// "unix:addr" is not a valid URL, so handle it before url.Parse.
		if scheme, rest, found := strings.Cut(target, ":"); found && scheme == "unix" {
			return scheme, rest
		}
		return defaultNetwork, target
	}

	u, err := url.Parse(target)
	if err != nil || u.Scheme != "unix" {
		return defaultNetwork, target
	}
	if u.Path == "" {
		return u.Scheme, u.Host
	}
	return u.Scheme, u.Path
}