handler_server.go

  1/*
  2 *
  3 * Copyright 2016 gRPC authors.
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     http://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19// This file is the implementation of a gRPC server using HTTP/2 which
 20// uses the standard Go http2 Server implementation (via the
 21// http.Handler interface), rather than speaking low-level HTTP/2
 22// frames itself. It is the implementation of *grpc.Server.ServeHTTP.
 23
 24package transport
 25
 26import (
 27	"context"
 28	"errors"
 29	"fmt"
 30	"io"
 31	"net"
 32	"net/http"
 33	"strings"
 34	"sync"
 35	"time"
 36
 37	"golang.org/x/net/http2"
 38	"google.golang.org/grpc/codes"
 39	"google.golang.org/grpc/credentials"
 40	"google.golang.org/grpc/internal/grpclog"
 41	"google.golang.org/grpc/internal/grpcutil"
 42	"google.golang.org/grpc/mem"
 43	"google.golang.org/grpc/metadata"
 44	"google.golang.org/grpc/peer"
 45	"google.golang.org/grpc/stats"
 46	"google.golang.org/grpc/status"
 47	"google.golang.org/protobuf/proto"
 48)
 49
 50// NewServerHandlerTransport returns a ServerTransport handling gRPC from
 51// inside an http.Handler, or writes an HTTP error to w and returns an error.
 52// It requires that the http Server supports HTTP/2.
 53func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
 54	if r.Method != http.MethodPost {
 55		w.Header().Set("Allow", http.MethodPost)
 56		msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
 57		http.Error(w, msg, http.StatusMethodNotAllowed)
 58		return nil, errors.New(msg)
 59	}
 60	contentType := r.Header.Get("Content-Type")
 61	// TODO: do we assume contentType is lowercase? we did before
 62	contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
 63	if !validContentType {
 64		msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType)
 65		http.Error(w, msg, http.StatusUnsupportedMediaType)
 66		return nil, errors.New(msg)
 67	}
 68	if r.ProtoMajor != 2 {
 69		msg := "gRPC requires HTTP/2"
 70		http.Error(w, msg, http.StatusHTTPVersionNotSupported)
 71		return nil, errors.New(msg)
 72	}
 73	if _, ok := w.(http.Flusher); !ok {
 74		msg := "gRPC requires a ResponseWriter supporting http.Flusher"
 75		http.Error(w, msg, http.StatusInternalServerError)
 76		return nil, errors.New(msg)
 77	}
 78
 79	var localAddr net.Addr
 80	if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
 81		localAddr, _ = la.(net.Addr)
 82	}
 83	var authInfo credentials.AuthInfo
 84	if r.TLS != nil {
 85		authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
 86	}
 87	p := peer.Peer{
 88		Addr:      strAddr(r.RemoteAddr),
 89		LocalAddr: localAddr,
 90		AuthInfo:  authInfo,
 91	}
 92	st := &serverHandlerTransport{
 93		rw:             w,
 94		req:            r,
 95		closedCh:       make(chan struct{}),
 96		writes:         make(chan func()),
 97		peer:           p,
 98		contentType:    contentType,
 99		contentSubtype: contentSubtype,
100		stats:          stats,
101		bufferPool:     bufferPool,
102	}
103	st.logger = prefixLoggerForServerHandlerTransport(st)
104
105	if v := r.Header.Get("grpc-timeout"); v != "" {
106		to, err := decodeTimeout(v)
107		if err != nil {
108			msg := fmt.Sprintf("malformed grpc-timeout: %v", err)
109			http.Error(w, msg, http.StatusBadRequest)
110			return nil, status.Error(codes.Internal, msg)
111		}
112		st.timeoutSet = true
113		st.timeout = to
114	}
115
116	metakv := []string{"content-type", contentType}
117	if r.Host != "" {
118		metakv = append(metakv, ":authority", r.Host)
119	}
120	for k, vv := range r.Header {
121		k = strings.ToLower(k)
122		if isReservedHeader(k) && !isWhitelistedHeader(k) {
123			continue
124		}
125		for _, v := range vv {
126			v, err := decodeMetadataHeader(k, v)
127			if err != nil {
128				msg := fmt.Sprintf("malformed binary metadata %q in header %q: %v", v, k, err)
129				http.Error(w, msg, http.StatusBadRequest)
130				return nil, status.Error(codes.Internal, msg)
131			}
132			metakv = append(metakv, k, v)
133		}
134	}
135	st.headerMD = metadata.Pairs(metakv...)
136
137	return st, nil
138}
139
// serverHandlerTransport is an implementation of ServerTransport
// which replies to exactly one gRPC request (exactly one HTTP request),
// using the net/http.Handler interface. This http.Handler is guaranteed
// at this point to be speaking over HTTP/2, so it's able to speak valid
// gRPC.
//
// Instances are built by NewServerHandlerTransport.
type serverHandlerTransport struct {
	// rw and req are the response writer and request being served.
	// timeoutSet reports whether the request carried a grpc-timeout header,
	// and timeout holds its decoded value.
	rw         http.ResponseWriter
	req        *http.Request
	timeoutSet bool
	timeout    time.Duration

	// headerMD is the metadata built from the request (content-type,
	// :authority and permitted request headers). HandleStreams installs it
	// as incoming metadata on the stream context.
	headerMD metadata.MD

	// peer describes the remote end; Peer returns a copy of it.
	peer peer.Peer

	// closeOnce makes Close idempotent; closedCh is the close signal.
	closeOnce sync.Once
	closedCh  chan struct{} // closed on Close

	// writes is an unbuffered channel of code to run serialized in the
	// ServeHTTP (HandleStreams) goroutine, which drains it in runStream.
	// The channel itself is never closed in this file: once closedCh is
	// closed, runStream stops receiving from it and do returns
	// ErrConnClosing instead of sending.
	writes chan func()

	// block concurrent WriteStatus calls
	// e.g. grpc/(*serverStream).SendMsg/RecvMsg
	writeStatusMu sync.Mutex

	// we just mirror the request content-type
	contentType string
	// we store both contentType and contentSubtype so we don't keep recreating them
	// TODO make sure this is consistent across handler_server and http2_server
	contentSubtype string

	// stats handlers receive OutHeader/OutTrailer events; logger is the
	// per-transport prefix logger.
	stats  []stats.Handler
	logger *grpclog.PrefixLogger

	// bufferPool supplies the buffers HandleStreams reads the request body
	// into.
	bufferPool mem.BufferPool
}
178
179func (ht *serverHandlerTransport) Close(err error) {
180	ht.closeOnce.Do(func() {
181		if ht.logger.V(logLevel) {
182			ht.logger.Infof("Closing: %v", err)
183		}
184		close(ht.closedCh)
185	})
186}
187
188func (ht *serverHandlerTransport) Peer() *peer.Peer {
189	return &peer.Peer{
190		Addr:      ht.peer.Addr,
191		LocalAddr: ht.peer.LocalAddr,
192		AuthInfo:  ht.peer.AuthInfo,
193	}
194}
195
// strAddr implements net.Addr on top of the textual address that net/http
// exposes as Request.RemoteAddr. The empty string means "unknown".
type strAddr string

// Network reports "tcp" for a non-empty address and "" for an unknown one.
//
// net/http documents Request.RemoteAddr, when set, as the peer's IP:port
// (https://golang.org/pkg/net/http/#Request), which implies TCP. Supporting
// Unix sockets later would need a grpc-specific convention, either a
// different RemoteAddr format or, preferably, an address attached to the
// request context and read by serverHandlerTransport.
func (a strAddr) Network() string {
	if a == "" {
		return ""
	}
	return "tcp"
}

// String returns the address text unchanged.
func (a strAddr) String() string {
	return string(a)
}
217
218// do runs fn in the ServeHTTP goroutine.
219func (ht *serverHandlerTransport) do(fn func()) error {
220	select {
221	case <-ht.closedCh:
222		return ErrConnClosing
223	case ht.writes <- fn:
224		return nil
225	}
226}
227
228func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status) error {
229	ht.writeStatusMu.Lock()
230	defer ht.writeStatusMu.Unlock()
231
232	headersWritten := s.updateHeaderSent()
233	err := ht.do(func() {
234		if !headersWritten {
235			ht.writePendingHeaders(s)
236		}
237
238		// And flush, in case no header or body has been sent yet.
239		// This forces a separation of headers and trailers if this is the
240		// first call (for example, in end2end tests's TestNoService).
241		ht.rw.(http.Flusher).Flush()
242
243		h := ht.rw.Header()
244		h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code()))
245		if m := st.Message(); m != "" {
246			h.Set("Grpc-Message", encodeGrpcMessage(m))
247		}
248
249		s.hdrMu.Lock()
250		defer s.hdrMu.Unlock()
251		if p := st.Proto(); p != nil && len(p.Details) > 0 {
252			delete(s.trailer, grpcStatusDetailsBinHeader)
253			stBytes, err := proto.Marshal(p)
254			if err != nil {
255				// TODO: return error instead, when callers are able to handle it.
256				panic(err)
257			}
258
259			h.Set(grpcStatusDetailsBinHeader, encodeBinHeader(stBytes))
260		}
261
262		if len(s.trailer) > 0 {
263			for k, vv := range s.trailer {
264				// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
265				if isReservedHeader(k) {
266					continue
267				}
268				for _, v := range vv {
269					// http2 ResponseWriter mechanism to send undeclared Trailers after
270					// the headers have possibly been written.
271					h.Add(http2.TrailerPrefix+k, encodeMetadataHeader(k, v))
272				}
273			}
274		}
275	})
276
277	if err == nil { // transport has not been closed
278		// Note: The trailer fields are compressed with hpack after this call returns.
279		// No WireLength field is set here.
280		for _, sh := range ht.stats {
281			sh.HandleRPC(s.Context(), &stats.OutTrailer{
282				Trailer: s.trailer.Copy(),
283			})
284		}
285	}
286	ht.Close(errors.New("finished writing status"))
287	return err
288}
289
290// writePendingHeaders sets common and custom headers on the first
291// write call (Write, WriteHeader, or WriteStatus)
292func (ht *serverHandlerTransport) writePendingHeaders(s *ServerStream) {
293	ht.writeCommonHeaders(s)
294	ht.writeCustomHeaders(s)
295}
296
297// writeCommonHeaders sets common headers on the first write
298// call (Write, WriteHeader, or WriteStatus).
299func (ht *serverHandlerTransport) writeCommonHeaders(s *ServerStream) {
300	h := ht.rw.Header()
301	h["Date"] = nil // suppress Date to make tests happy; TODO: restore
302	h.Set("Content-Type", ht.contentType)
303
304	// Predeclare trailers we'll set later in WriteStatus (after the body).
305	// This is a SHOULD in the HTTP RFC, and the way you add (known)
306	// Trailers per the net/http.ResponseWriter contract.
307	// See https://golang.org/pkg/net/http/#ResponseWriter
308	// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
309	h.Add("Trailer", "Grpc-Status")
310	h.Add("Trailer", "Grpc-Message")
311	h.Add("Trailer", "Grpc-Status-Details-Bin")
312
313	if s.sendCompress != "" {
314		h.Set("Grpc-Encoding", s.sendCompress)
315	}
316}
317
318// writeCustomHeaders sets custom headers set on the stream via SetHeader
319// on the first write call (Write, WriteHeader, or WriteStatus)
320func (ht *serverHandlerTransport) writeCustomHeaders(s *ServerStream) {
321	h := ht.rw.Header()
322
323	s.hdrMu.Lock()
324	for k, vv := range s.header {
325		if isReservedHeader(k) {
326			continue
327		}
328		for _, v := range vv {
329			h.Add(k, encodeMetadataHeader(k, v))
330		}
331	}
332
333	s.hdrMu.Unlock()
334}
335
336func (ht *serverHandlerTransport) write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *WriteOptions) error {
337	// Always take a reference because otherwise there is no guarantee the data will
338	// be available after this function returns. This is what callers to Write
339	// expect.
340	data.Ref()
341	headersWritten := s.updateHeaderSent()
342	err := ht.do(func() {
343		defer data.Free()
344		if !headersWritten {
345			ht.writePendingHeaders(s)
346		}
347		ht.rw.Write(hdr)
348		for _, b := range data {
349			_, _ = ht.rw.Write(b.ReadOnlyData())
350		}
351		ht.rw.(http.Flusher).Flush()
352	})
353	if err != nil {
354		data.Free()
355		return err
356	}
357	return nil
358}
359
360func (ht *serverHandlerTransport) writeHeader(s *ServerStream, md metadata.MD) error {
361	if err := s.SetHeader(md); err != nil {
362		return err
363	}
364
365	headersWritten := s.updateHeaderSent()
366	err := ht.do(func() {
367		if !headersWritten {
368			ht.writePendingHeaders(s)
369		}
370
371		ht.rw.WriteHeader(200)
372		ht.rw.(http.Flusher).Flush()
373	})
374
375	if err == nil {
376		for _, sh := range ht.stats {
377			// Note: The header fields are compressed with hpack after this call returns.
378			// No WireLength field is set here.
379			sh.HandleRPC(s.Context(), &stats.OutHeader{
380				Header:      md.Copy(),
381				Compression: s.sendCompress,
382			})
383		}
384	}
385	return err
386}
387
// HandleStreams serves the single gRPC stream carried by this HTTP request.
//
// It derives the stream context from ctx (with a deadline when the request
// carried grpc-timeout) and starts a goroutine that copies the request body
// into the stream's receive buffer. It then passes the stream to startStream
// and executes queued writes on the current goroutine (runStream) until the
// transport is closed. It returns only after the request body has been
// closed and the body-reading goroutine has exited.
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
	// With this transport type there will be exactly 1 stream: this HTTP request.
	var cancel context.CancelFunc
	if ht.timeoutSet {
		ctx, cancel = context.WithTimeout(ctx, ht.timeout)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}

	// requestOver is closed once runStream returns, i.e. after the transport
	// has been closed (for example by writeStatus).
	requestOver := make(chan struct{})
	// Watcher: the first of runStream finishing, the transport closing, or
	// the HTTP request's own context ending cancels the stream context and
	// closes the transport.
	go func() {
		select {
		case <-requestOver:
		case <-ht.closedCh:
		case <-ht.req.Context().Done():
		}
		cancel()
		ht.Close(errors.New("request is done processing"))
	}()

	ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
	req := ht.req
	s := &ServerStream{
		Stream: &Stream{
			id:             0, // irrelevant
			ctx:            ctx,
			requestRead:    func(int) {},
			buf:            newRecvBuffer(),
			method:         req.URL.Path,
			recvCompress:   req.Header.Get("grpc-encoding"),
			contentSubtype: ht.contentSubtype,
		},
		cancel:           cancel,
		st:               ht,
		headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
	}
	// requestRead and windowHandler are no-ops for this transport.
	s.trReader = &transportReader{
		reader:        &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
		windowHandler: func(int) {},
	}

	// readerDone is closed when the Body.Read-ing goroutine exits.
	readerDone := make(chan struct{})
	// Reader: repeatedly reads the request body into pooled buffers
	// (requested with size http2MaxFrameLen) and queues non-empty chunks on
	// s.buf. Buffers from empty reads go back to the pool. The first read
	// error is mapped with mapRecvMsgError, queued, and ends the goroutine.
	go func() {
		defer close(readerDone)

		for {
			buf := ht.bufferPool.Get(http2MaxFrameLen)
			n, err := req.Body.Read(*buf)
			if n > 0 {
				*buf = (*buf)[:n]
				s.buf.put(recvMsg{buffer: mem.NewBuffer(buf, ht.bufferPool)})
			} else {
				ht.bufferPool.Put(buf)
			}
			if err != nil {
				s.buf.put(recvMsg{err: mapRecvMsgError(err)})
				return
			}
		}
	}()

	// startStream is provided by the *grpc.Server's serveStreams.
	// It starts a goroutine serving s and exits immediately.
	// The goroutine that is started is the one that then calls
	// into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
	startStream(s)

	// Execute queued writes on this goroutine until the transport is closed.
	ht.runStream()
	close(requestOver)

	// Wait for reading goroutine to finish.
	// Closing the body presumably makes a pending Body.Read fail (compare the
	// "body closed by handler" case in mapRecvMsgError); confirm against
	// net/http.
	req.Body.Close()
	<-readerDone
}
464
465func (ht *serverHandlerTransport) runStream() {
466	for {
467		select {
468		case fn := <-ht.writes:
469			fn()
470		case <-ht.closedCh:
471			return
472		}
473	}
474}
475
476func (ht *serverHandlerTransport) incrMsgRecv() {}
477
478func (ht *serverHandlerTransport) Drain(string) {
479	panic("Drain() is not implemented")
480}
481
482// mapRecvMsgError returns the non-nil err into the appropriate
483// error value as expected by callers of *grpc.parser.recvMsg.
484// In particular, in can only be:
485//   - io.EOF
486//   - io.ErrUnexpectedEOF
487//   - of type transport.ConnectionError
488//   - an error from the status package
489func mapRecvMsgError(err error) error {
490	if err == io.EOF || err == io.ErrUnexpectedEOF {
491		return err
492	}
493	if se, ok := err.(http2.StreamError); ok {
494		if code, ok := http2ErrConvTab[se.Code]; ok {
495			return status.Error(code, se.Error())
496		}
497	}
498	if strings.Contains(err.Error(), "body closed by handler") {
499		return status.Error(codes.Canceled, err.Error())
500	}
501	return connectionErrorf(true, err, "%s", err.Error())
502}