http_util.go

  1/*
  2 *
  3 * Copyright 2014 gRPC authors.
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     http://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19package transport
 20
 21import (
 22	"bufio"
 23	"encoding/base64"
 24	"errors"
 25	"fmt"
 26	"io"
 27	"math"
 28	"net"
 29	"net/http"
 30	"net/url"
 31	"strconv"
 32	"strings"
 33	"sync"
 34	"time"
 35	"unicode/utf8"
 36
 37	"golang.org/x/net/http2"
 38	"golang.org/x/net/http2/hpack"
 39	"google.golang.org/grpc/codes"
 40)
 41
 42const (
 43	// http2MaxFrameLen specifies the max length of a HTTP2 frame.
 44	http2MaxFrameLen = 16384 // 16KB frame
 45	// https://httpwg.org/specs/rfc7540.html#SettingValues
 46	http2InitHeaderTableSize = 4096
 47)
 48
 49var (
 50	clientPreface   = []byte(http2.ClientPreface)
 51	http2ErrConvTab = map[http2.ErrCode]codes.Code{
 52		http2.ErrCodeNo:                 codes.Internal,
 53		http2.ErrCodeProtocol:           codes.Internal,
 54		http2.ErrCodeInternal:           codes.Internal,
 55		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
 56		http2.ErrCodeSettingsTimeout:    codes.Internal,
 57		http2.ErrCodeStreamClosed:       codes.Internal,
 58		http2.ErrCodeFrameSize:          codes.Internal,
 59		http2.ErrCodeRefusedStream:      codes.Unavailable,
 60		http2.ErrCodeCancel:             codes.Canceled,
 61		http2.ErrCodeCompression:        codes.Internal,
 62		http2.ErrCodeConnect:            codes.Internal,
 63		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
 64		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
 65		http2.ErrCodeHTTP11Required:     codes.Internal,
 66	}
 67	// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
 68	HTTPStatusConvTab = map[int]codes.Code{
 69		// 400 Bad Request - INTERNAL.
 70		http.StatusBadRequest: codes.Internal,
 71		// 401 Unauthorized  - UNAUTHENTICATED.
 72		http.StatusUnauthorized: codes.Unauthenticated,
 73		// 403 Forbidden - PERMISSION_DENIED.
 74		http.StatusForbidden: codes.PermissionDenied,
 75		// 404 Not Found - UNIMPLEMENTED.
 76		http.StatusNotFound: codes.Unimplemented,
 77		// 429 Too Many Requests - UNAVAILABLE.
 78		http.StatusTooManyRequests: codes.Unavailable,
 79		// 502 Bad Gateway - UNAVAILABLE.
 80		http.StatusBadGateway: codes.Unavailable,
 81		// 503 Service Unavailable - UNAVAILABLE.
 82		http.StatusServiceUnavailable: codes.Unavailable,
 83		// 504 Gateway timeout - UNAVAILABLE.
 84		http.StatusGatewayTimeout: codes.Unavailable,
 85	}
 86)
 87
 88var grpcStatusDetailsBinHeader = "grpc-status-details-bin"
 89
 90// isReservedHeader checks whether hdr belongs to HTTP2 headers
 91// reserved by gRPC protocol. Any other headers are classified as the
 92// user-specified metadata.
 93func isReservedHeader(hdr string) bool {
 94	if hdr != "" && hdr[0] == ':' {
 95		return true
 96	}
 97	switch hdr {
 98	case "content-type",
 99		"user-agent",
100		"grpc-message-type",
101		"grpc-encoding",
102		"grpc-message",
103		"grpc-status",
104		"grpc-timeout",
105		// Intentionally exclude grpc-previous-rpc-attempts and
106		// grpc-retry-pushback-ms, which are "reserved", but their API
107		// intentionally works via metadata.
108		"te":
109		return true
110	default:
111		return false
112	}
113}
114
// isWhitelistedHeader reports whether hdr, although isReservedHeader treats it
// as reserved, should still be exposed to users as metadata. Only
// ":authority" and "user-agent" qualify.
func isWhitelistedHeader(hdr string) bool {
	return hdr == ":authority" || hdr == "user-agent"
}
125
// binHdrSuffix marks a metadata key whose value is binary; such values are
// base64-encoded by encodeMetadataHeader and decoded by decodeMetadataHeader.
const binHdrSuffix = "-bin"

// encodeBinHeader base64-encodes v with the standard alphabet and no padding.
func encodeBinHeader(v []byte) string {
	enc := base64.RawStdEncoding
	return enc.EncodeToString(v)
}

// decodeBinHeader base64-decodes v, accepting both padded and unpadded input:
// a length that is a multiple of four is decoded with the padded standard
// encoding, any other length with the raw (unpadded) one.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		// Either the input carries padding or none was needed.
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}

// encodeMetadataHeader returns the header value to send for metadata key k:
// values of keys ending in binHdrSuffix are base64-encoded, all others pass
// through unchanged.
func encodeMetadataHeader(k, v string) string {
	if !strings.HasSuffix(k, binHdrSuffix) {
		return v
	}
	return encodeBinHeader([]byte(v))
}

// decodeMetadataHeader reverses encodeMetadataHeader for key k. For binary
// keys it returns whatever bytes were decoded along with any decoding error.
func decodeMetadataHeader(k, v string) (string, error) {
	if !strings.HasSuffix(k, binHdrSuffix) {
		return v, nil
	}
	raw, err := decodeBinHeader(v)
	return string(raw), err
}
154
// timeoutUnit is the single-character unit suffix of a timeout string, e.g.
// the 'S' in "10S".
type timeoutUnit uint8

const (
	hour        timeoutUnit = 'H'
	minute      timeoutUnit = 'M'
	second      timeoutUnit = 'S'
	millisecond timeoutUnit = 'm'
	microsecond timeoutUnit = 'u'
	nanosecond  timeoutUnit = 'n'
)

// timeoutUnitToDuration returns the duration of one u, with ok == false (and
// d == 0) when u is not a recognized unit.
func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
	switch u {
	case hour:
		return time.Hour, true
	case minute:
		return time.Minute, true
	case second:
		return time.Second, true
	case millisecond:
		return time.Millisecond, true
	case microsecond:
		return time.Microsecond, true
	case nanosecond:
		return time.Nanosecond, true
	default:
		return 0, false
	}
}

// decodeTimeout parses a timeout string made of 1 to 8 ASCII decimal digits
// followed by one unit character (see timeoutUnit), e.g. "100m". Hour values
// that would overflow time.Duration are clamped to the maximum Duration.
//
// Only the digits '0'-'9' are accepted before the unit. strconv.ParseInt on
// its own would also accept a leading sign, letting "-1S" decode to a
// negative timeout, which the spec's unsigned-digit grammar does not allow.
func decodeTimeout(s string) (time.Duration, error) {
	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
	}
	if size > 9 {
		// Spec allows for 8 digits plus the unit.
		return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
	}
	unit := timeoutUnit(s[size-1])
	d, ok := timeoutUnitToDuration(unit)
	if !ok {
		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
	}
	digits := s[:size-1]
	for i := 0; i < len(digits); i++ {
		if c := digits[i]; c < '0' || c > '9' {
			return 0, fmt.Errorf("transport: timeout value is not an unsigned decimal integer: %q", s)
		}
	}
	// At most 8 digits, so this cannot overflow int64; the check is kept
	// purely defensively.
	t, err := strconv.ParseInt(digits, 10, 64)
	if err != nil {
		return 0, err
	}
	const maxHours = math.MaxInt64 / int64(time.Hour)
	if d == time.Hour && t > maxHours {
		// This timeout would overflow math.MaxInt64; clamp it.
		return time.Duration(math.MaxInt64), nil
	}
	return d * time.Duration(t), nil
}
210
// Bytes from spaceByte through tildeByte (printable ASCII), except
// percentByte, are carried verbatim in a grpc-message value; every other
// byte is percent-encoded.
const (
	spaceByte   = ' '
	tildeByte   = '~'
	percentByte = '%'
)

// encodeGrpcMessage percent-encodes msg for the "grpc-message" header field.
// Each byte outside the printable-ASCII range, and '%' itself, becomes '%'
// followed by two upper-case hex digits. Invalid UTF-8 bytes are first
// replaced by the Unicode replacement character, which is then encoded.
func encodeGrpcMessage(msg string) string {
	// Fast path: a message made only of pass-through bytes is its own encoding.
	for i := 0; i < len(msg); i++ {
		if c := msg[i]; c < spaceByte || c > tildeByte || c == percentByte {
			return encodeGrpcMessageUnchecked(msg)
		}
	}
	return msg
}

func encodeGrpcMessageUnchecked(msg string) string {
	const hexDigits = "0123456789ABCDEF"
	var sb strings.Builder
	sb.Grow(len(msg))
	for rest := msg; rest != ""; {
		r, n := utf8.DecodeRuneInString(rest)
		rest = rest[n:]
		// Re-encoding r matters even when n == 1: an invalid byte decodes to
		// utf8.RuneError, whose UTF-8 form is the three bytes EF BF BD, and
		// those (not the original byte) are what gets escaped.
		enc := string(r)
		for i := 0; i < len(enc); i++ {
			b := enc[i]
			// Only single-byte runes may pass through; every byte of a
			// multi-byte rune is non-ASCII and is always escaped.
			if n == 1 && b >= spaceByte && b <= tildeByte && b != percentByte {
				sb.WriteByte(b)
				continue
			}
			sb.WriteByte(percentByte)
			sb.WriteByte(hexDigits[b>>4])
			sb.WriteByte(hexDigits[b&0x0F])
		}
	}
	return sb.String()
}

// decodeGrpcMessage reverses encodeGrpcMessage. A '%' that is not followed by
// two hex digits is kept literally.
func decodeGrpcMessage(msg string) string {
	// Only a '%' with at least two bytes after it can begin an escape; with
	// none present, msg decodes to itself.
	for i := 0; i+2 < len(msg); i++ {
		if msg[i] == percentByte {
			return decodeGrpcMessageUnchecked(msg)
		}
	}
	return msg
}

func decodeGrpcMessageUnchecked(msg string) string {
	var sb strings.Builder
	i := 0
	for i < len(msg) {
		c := msg[i]
		if c == percentByte && i+2 < len(msg) {
			if v, err := strconv.ParseUint(msg[i+1:i+3], 16, 8); err == nil {
				sb.WriteByte(byte(v))
				i += 3
				continue
			}
		}
		sb.WriteByte(c)
		i++
	}
	return sb.String()
}
297
// bufWriter batches writes to conn. Bytes accumulate in buf until batchSize
// of them are pending or Flush is called. A batchSize of 0 disables batching
// and every Write goes straight to conn. With a non-nil pool the buffer is
// borrowed lazily on the first buffered Write and returned on Flush;
// otherwise a dedicated buffer is allocated up front. An error from writing
// buffered data to conn is kept in err and returned by every later call.
type bufWriter struct {
	pool      *sync.Pool
	buf       []byte
	offset    int
	batchSize int
	conn      net.Conn
	err       error
}

// newBufWriter returns a bufWriter over conn that flushes every batchSize
// bytes. A nil pool selects a dedicated (non-shared) buffer.
func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
	w := &bufWriter{conn: conn, batchSize: batchSize, pool: pool}
	if pool == nil {
		// Not sharing buffers: this writer owns its own from the start.
		w.buf = make([]byte, batchSize)
	}
	return w
}

// Write buffers b, flushing to conn each time batchSize bytes are pending,
// and returns how many bytes of b were taken. With batching disabled it
// writes b directly to conn. Connection errors are wrapped as ioError.
func (w *bufWriter) Write(b []byte) (int, error) {
	if w.err != nil {
		return 0, w.err
	}
	if w.batchSize == 0 {
		// Batching disabled: pass the bytes straight through.
		n, err := w.conn.Write(b)
		return n, toIOError(err)
	}
	if w.buf == nil {
		// Shared mode: borrow a buffer until the next Flush.
		w.buf = *w.pool.Get().(*[]byte)
	}
	total := 0
	for len(b) != 0 {
		n := copy(w.buf[w.offset:], b)
		b = b[n:]
		total += n
		w.offset += n
		if w.offset >= w.batchSize {
			if err := w.flushKeepBuffer(); err != nil {
				return total, err
			}
		}
	}
	return total, nil
}

// Flush writes any pending bytes to conn and, in shared mode, hands the
// borrowed buffer back to the pool.
func (w *bufWriter) Flush() error {
	err := w.flushKeepBuffer()
	if w.pool == nil || w.buf == nil {
		// Dedicated buffer, or nothing borrowed: nothing to release.
		return err
	}
	released := w.buf
	w.buf = nil
	w.pool.Put(&released)
	return err
}

// flushKeepBuffer writes pending bytes to conn without releasing the buffer.
// Any write error becomes sticky in w.err.
func (w *bufWriter) flushKeepBuffer() error {
	switch {
	case w.err != nil:
		return w.err
	case w.offset == 0:
		return nil
	}
	_, err := w.conn.Write(w.buf[:w.offset])
	w.err = toIOError(err)
	w.offset = 0
	return w.err
}

// ioError tags an error as coming from connection I/O; in this file it wraps
// errors returned by conn.Write. The wrapped error stays reachable through
// errors.Is / errors.As.
type ioError struct {
	error
}

// Unwrap returns the wrapped error.
func (e ioError) Unwrap() error {
	return e.error
}

// isIOError reports whether err, or any error it wraps, is an ioError.
func isIOError(err error) bool {
	var target ioError
	return errors.As(err, &target)
}

// toIOError wraps a non-nil err in ioError and passes nil through.
func toIOError(err error) error {
	if err != nil {
		return ioError{error: err}
	}
	return nil
}
390
391type framer struct {
392	writer *bufWriter
393	fr     *http2.Framer
394}
395
396var writeBufferPoolMap = make(map[int]*sync.Pool)
397var writeBufferMutex sync.Mutex
398
399func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer {
400	if writeBufferSize < 0 {
401		writeBufferSize = 0
402	}
403	var r io.Reader = conn
404	if readBufferSize > 0 {
405		r = bufio.NewReaderSize(r, readBufferSize)
406	}
407	var pool *sync.Pool
408	if sharedWriteBuffer {
409		pool = getWriteBufferPool(writeBufferSize)
410	}
411	w := newBufWriter(conn, writeBufferSize, pool)
412	f := &framer{
413		writer: w,
414		fr:     http2.NewFramer(w, r),
415	}
416	f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
417	// Opt-in to Frame reuse API on framer to reduce garbage.
418	// Frames aren't safe to read from after a subsequent call to ReadFrame.
419	f.fr.SetReuseFrames()
420	f.fr.MaxHeaderListSize = maxHeaderListSize
421	f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
422	return f
423}
424
425func getWriteBufferPool(size int) *sync.Pool {
426	writeBufferMutex.Lock()
427	defer writeBufferMutex.Unlock()
428	pool, ok := writeBufferPoolMap[size]
429	if ok {
430		return pool
431	}
432	pool = &sync.Pool{
433		New: func() any {
434			b := make([]byte, size)
435			return &b
436		},
437	}
438	writeBufferPoolMap[size] = pool
439	return pool
440}
441
// parseDialTarget returns the network and address to pass to the dialer.
// Targets of the form "unix:addr", or URLs with scheme "unix" (using the URL
// path, or the host when the path is empty), yield network "unix". Everything
// else, including unparsable URLs, is dialed over "tcp" with target unchanged.
func parseDialTarget(target string) (string, string) {
	const defaultNetwork = "tcp"
	if !strings.Contains(target, ":/") {
		// The opaque "unix:addr" form is handled by hand rather than with
		// url.Parse.
		if scheme, addr, found := strings.Cut(target, ":"); found && scheme == "unix" {
			return scheme, addr
		}
		return defaultNetwork, target
	}
	u, err := url.Parse(target)
	if err != nil || u.Scheme != "unix" {
		return defaultNetwork, target
	}
	if u.Path == "" {
		return u.Scheme, u.Host
	}
	return u.Scheme, u.Path
}