/*
 *
 * Copyright 2014 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package grpc

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"net/http"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/encoding"
	"google.golang.org/grpc/encoding/proto"
	estats "google.golang.org/grpc/experimental/stats"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal"
	"google.golang.org/grpc/internal/binarylog"
	"google.golang.org/grpc/internal/channelz"
	"google.golang.org/grpc/internal/grpcsync"
	"google.golang.org/grpc/internal/grpcutil"
	istats "google.golang.org/grpc/internal/stats"
	"google.golang.org/grpc/internal/transport"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/mem"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/stats"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/tap"
)

const (
	defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4
	defaultServerMaxSendMessageSize    = math.MaxInt32

	// Server transports are tracked in a map which is keyed on listener
	// address. For regular gRPC traffic, connections are accepted in Serve()
	// through a call to Accept(), and we use the actual listener address as key
	// when we add it to the map. But for connections received through
	// ServeHTTP(), we do not have a listener and hence use this dummy value.
	listenerAddressForServeHTTP = "listenerAddressForServeHTTP"
)

func init() {
	internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
		return srv.opts.creds
	}
	internal.IsRegisteredMethod = func(srv *Server, method string) bool {
		return srv.isRegisteredMethod(method)
	}
	internal.ServerFromContext = serverFromContext
	internal.AddGlobalServerOptions = func(opt ...ServerOption) {
		globalServerOptions = append(globalServerOptions, opt...)
	}
	internal.ClearGlobalServerOptions = func() {
		globalServerOptions = nil
	}
	internal.BinaryLogger = binaryLogger
	internal.JoinServerOptions = newJoinServerOption
	internal.BufferPool = bufferPool
	internal.MetricsRecorderForServer = func(srv *Server) estats.MetricsRecorder {
		return istats.NewMetricsRecorderList(srv.opts.statsHandlers)
	}
}

var statusOK = status.New(codes.OK, "")
var logger = grpclog.Component("core")

// MethodHandler is a function type that processes a unary RPC method call.
type MethodHandler func(srv any, ctx context.Context, dec func(any) error, interceptor UnaryServerInterceptor) (any, error)

// MethodDesc represents an RPC service's method specification.
type MethodDesc struct {
	MethodName string
	Handler    MethodHandler
}

// ServiceDesc represents an RPC service's specification.
type ServiceDesc struct {
	ServiceName string
	// The pointer to the service interface. Used to check whether the user
	// provided implementation satisfies the interface requirements.
	HandlerType any
	Methods     []MethodDesc
	Streams     []StreamDesc
	Metadata    any
}

// serviceInfo wraps information about a service. It is very similar to
// ServiceDesc and is constructed from it for internal purposes.
type serviceInfo struct {
	// Contains the implementation for the methods in this service.
	serviceImpl any
	methods     map[string]*MethodDesc
	streams     map[string]*StreamDesc
	mdata       any
}

// Server is a gRPC server to serve RPC requests.
type Server struct {
	opts serverOptions

	mu  sync.Mutex // guards following
	lis map[net.Listener]bool
	// conns contains all active server transports. It is a map keyed on a
	// listener address with the value being the set of active transports
	// belonging to that listener.
	conns    map[string]map[transport.ServerTransport]bool
	serve    bool
	drain    bool
	cv       *sync.Cond              // signaled when connections close for GracefulStop
	services map[string]*serviceInfo // service name -> service info
	events   traceEventLog

	quit               *grpcsync.Event
	done               *grpcsync.Event
	channelzRemoveOnce sync.Once
	serveWG            sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop
	handlersWG         sync.WaitGroup // counts active method handler goroutines

	channelz *channelz.Server

	serverWorkerChannel      chan func()
	serverWorkerChannelClose func()
}

type serverOptions struct {
	creds                 credentials.TransportCredentials
	codec                 baseCodec
	cp                    Compressor
	dc                    Decompressor
	unaryInt              UnaryServerInterceptor
	streamInt             StreamServerInterceptor
	chainUnaryInts        []UnaryServerInterceptor
	chainStreamInts       []StreamServerInterceptor
	binaryLogger          binarylog.Logger
	inTapHandle           tap.ServerInHandle
	statsHandlers         []stats.Handler
	maxConcurrentStreams  uint32
	maxReceiveMessageSize int
	maxSendMessageSize    int
	unknownStreamDesc     *StreamDesc
	keepaliveParams       keepalive.ServerParameters
	keepalivePolicy       keepalive.EnforcementPolicy
	initialWindowSize     int32
	initialConnWindowSize int32
	writeBufferSize       int
	readBufferSize        int
	sharedWriteBuffer     bool
	connectionTimeout     time.Duration
	maxHeaderListSize     *uint32
	headerTableSize       *uint32
	numServerWorkers      uint32
	bufferPool            mem.BufferPool
	waitForHandlers       bool
}

var defaultServerOptions = serverOptions{
	maxConcurrentStreams:  math.MaxUint32,
	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
	maxSendMessageSize:    defaultServerMaxSendMessageSize,
	connectionTimeout:     120 * time.Second,
	writeBufferSize:       defaultWriteBufSize,
	readBufferSize:        defaultReadBufSize,
	bufferPool:            mem.DefaultBufferPool(),
}
var globalServerOptions []ServerOption

// A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
type ServerOption interface {
	apply(*serverOptions)
}

// EmptyServerOption does not alter the server configuration. It can be embedded
// in another structure to build custom server options.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type EmptyServerOption struct{}

func (EmptyServerOption) apply(*serverOptions) {}

// funcServerOption wraps a function that modifies serverOptions into an
// implementation of the ServerOption interface.
type funcServerOption struct {
	f func(*serverOptions)
}

func (fdo *funcServerOption) apply(do *serverOptions) {
	fdo.f(do)
}

func newFuncServerOption(f func(*serverOptions)) *funcServerOption {
	return &funcServerOption{
		f: f,
	}
}

// joinServerOption provides a way to combine an arbitrary number of server
// options into one.
type joinServerOption struct {
	opts []ServerOption
}

func (mdo *joinServerOption) apply(do *serverOptions) {
	for _, opt := range mdo.opts {
		opt.apply(do)
	}
}

func newJoinServerOption(opts ...ServerOption) ServerOption {
	return &joinServerOption{opts: opts}
}

// SharedWriteBuffer allows reusing the per-connection transport write buffer.
// If this option is set to true, every connection will release the buffer after
// flushing the data on the wire.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func SharedWriteBuffer(val bool) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.sharedWriteBuffer = val
	})
}

// WriteBufferSize determines how much data can be batched before doing a write
// on the wire. The default value for this buffer is 32KB. Zero or negative
// values will disable the write buffer such that each write goes directly to
// the underlying connection. Note: A Send call may not directly translate to a
// write.
func WriteBufferSize(s int) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.writeBufferSize = s
	})
}

// ReadBufferSize lets you set the size of the read buffer; this determines how
// much data can be read at most in one read syscall. The default value for
// this buffer is 32KB. Zero or negative values will disable the read buffer
// for a connection so the data framer can access the underlying conn directly.
func ReadBufferSize(s int) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.readBufferSize = s
	})
}

// InitialWindowSize returns a ServerOption that sets the initial window size
// for a stream. The lower bound for a window size is 64KB and any value
// smaller than that will be ignored.
func InitialWindowSize(s int32) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.initialWindowSize = s
	})
}

// InitialConnWindowSize returns a ServerOption that sets the initial window
// size for a connection. The lower bound for a window size is 64KB and any
// value smaller than that will be ignored.
func InitialConnWindowSize(s int32) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.initialConnWindowSize = s
	})
}

// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
	if kp.Time > 0 && kp.Time < internal.KeepaliveMinServerPingTime {
		logger.Warning("Adjusting keepalive ping interval to minimum period of 1s")
		kp.Time = internal.KeepaliveMinServerPingTime
	}

	return newFuncServerOption(func(o *serverOptions) {
		o.keepaliveParams = kp
	})
}
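
// A minimal sketch of configuring server-side keepalive together with an
// enforcement policy (declared below); the durations are illustrative only,
// not recommendations:
//
//	srv := grpc.NewServer(
//		grpc.KeepaliveParams(keepalive.ServerParameters{
//			Time:    2 * time.Hour,    // ping an idle client after this long
//			Timeout: 20 * time.Second, // wait this long for the ping ack
//		}),
//		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
//			MinTime:             5 * time.Minute, // minimum allowed client ping interval
//			PermitWithoutStream: false,           // disallow pings with no active RPCs
//		}),
//	)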

// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server.
func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.keepalivePolicy = kep
	})
}

// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
//
// This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
//
// Deprecated: register codecs using encoding.RegisterCodec. The server will
// automatically use registered codecs based on the incoming requests' headers.
// See also
// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec.
// Will be supported throughout 1.x.
func CustomCodec(codec Codec) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.codec = newCodecV0Bridge(codec)
	})
}

// ForceServerCodec returns a ServerOption that sets a codec for message
// marshaling and unmarshaling.
//
// This will override any lookups by content-subtype for Codecs registered
// with RegisterCodec.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between encoding.Codec
// and content-subtype.
//
// This function is provided for advanced users; prefer to register codecs
// using encoding.RegisterCodec.
// The server will automatically use registered codecs based on the incoming
// requests' headers. See also
// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec.
// Will be supported throughout 1.x.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceServerCodec(codec encoding.Codec) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.codec = newCodecV1Bridge(codec)
	})
}

// ForceServerCodecV2 is the equivalent of ForceServerCodec, but for the new
// CodecV2 interface.
//
// Will be supported throughout 1.x.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceServerCodecV2(codecV2 encoding.CodecV2) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.codec = codecV2
	})
}

// RPCCompressor returns a ServerOption that sets a compressor for outbound
// messages. For backward compatibility, all outbound messages will be sent
// using this compressor, regardless of incoming message compression. By
// default, server messages will be sent using the same compressor with which
// request messages were sent.
//
// Deprecated: use encoding.RegisterCompressor instead. Will be supported
// throughout 1.x.
func RPCCompressor(cp Compressor) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.cp = cp
	})
}

// RPCDecompressor returns a ServerOption that sets a decompressor for inbound
// messages. It has higher priority than decompressors registered via
// encoding.RegisterCompressor.
//
// Deprecated: use encoding.RegisterCompressor instead. Will be supported
// throughout 1.x.
func RPCDecompressor(dc Decompressor) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.dc = dc
	})
}

// MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
// If this is not set, gRPC uses the default limit.
//
// Deprecated: use MaxRecvMsgSize instead. Will be supported throughout 1.x.
func MaxMsgSize(m int) ServerOption {
	return MaxRecvMsgSize(m)
}

// MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
// If this is not set, gRPC uses the default 4MB.
func MaxRecvMsgSize(m int) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.maxReceiveMessageSize = m
	})
}

// MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send.
// If this is not set, gRPC uses the default `math.MaxInt32`.
func MaxSendMsgSize(m int) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.maxSendMessageSize = m
	})
}

// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
	if n == 0 {
		n = math.MaxUint32
	}
	return newFuncServerOption(func(o *serverOptions) {
		o.maxConcurrentStreams = n
	})
}

// Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.TransportCredentials) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.creds = c
	})
}

// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
// server. Only one unary interceptor can be installed. The construction of multiple
// interceptors (e.g., chaining) can be implemented at the caller.
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		if o.unaryInt != nil {
			panic("The unary server interceptor was already set and may not be reset.")
		}
		o.unaryInt = i
	})
}

// ChainUnaryInterceptor returns a ServerOption that specifies the chained interceptor
// for unary RPCs. The first interceptor will be the outermost, while the last
// interceptor will be the innermost wrapper around the real call. All unary
// interceptors added by this method will be chained.
func ChainUnaryInterceptor(interceptors ...UnaryServerInterceptor) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
	})
}
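
// A minimal sketch of chaining; the logging and auth interceptors are
// hypothetical, and with this configuration a unary call flows
// logging -> auth -> method handler:
//
//	logging := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
//		log.Printf("call: %s", info.FullMethod)
//		return handler(ctx, req)
//	}
//	srv := grpc.NewServer(grpc.ChainUnaryInterceptor(logging, auth))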

// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
// server. Only one stream interceptor can be installed.
func StreamInterceptor(i StreamServerInterceptor) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		if o.streamInt != nil {
			panic("The stream server interceptor was already set and may not be reset.")
		}
		o.streamInt = i
	})
}

// ChainStreamInterceptor returns a ServerOption that specifies the chained interceptor
// for streaming RPCs. The first interceptor will be the outermost, while the last
// interceptor will be the innermost wrapper around the real call. All stream
// interceptors added by this method will be chained.
func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.chainStreamInts = append(o.chainStreamInts, interceptors...)
	})
}

// InTapHandle returns a ServerOption that sets the tap handle for all the server
// transports to be created. Only one can be installed.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func InTapHandle(h tap.ServerInHandle) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		if o.inTapHandle != nil {
			panic("The tap handle was already set and may not be reset.")
		}
		o.inTapHandle = h
	})
}

// StatsHandler returns a ServerOption that sets the stats handler for the server.
func StatsHandler(h stats.Handler) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		if h == nil {
			logger.Error("ignoring nil parameter in grpc.StatsHandler ServerOption")
			// Do not allow a nil stats handler, which would otherwise cause
			// panics.
			return
		}
		o.statsHandlers = append(o.statsHandlers, h)
	})
}

// binaryLogger returns a ServerOption that can set the binary logger for the
// server.
func binaryLogger(bl binarylog.Logger) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.binaryLogger = bl
	})
}

// UnknownServiceHandler returns a ServerOption that allows for adding a custom
// unknown service handler. The provided method is a bidi-streaming RPC service
// handler that will be invoked instead of returning the "unimplemented" gRPC
// error whenever a request is received for an unregistered service or method.
// The handling function and stream interceptor (if set) have full access to
// the ServerStream, including its Context.
func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.unknownStreamDesc = &StreamDesc{
			StreamName: "unknown_service_handler",
			Handler:    streamHandler,
			// We need to assume that the users of the streamHandler will want to use both.
			ClientStreams: true,
			ServerStreams: true,
		}
	})
}
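
// A minimal sketch of a fallback handler, e.g. for a generic proxy; the
// routing logic is elided and the names are hypothetical:
//
//	fallback := func(srv any, stream grpc.ServerStream) error {
//		method, _ := grpc.MethodFromServerStream(stream)
//		return status.Errorf(codes.Unimplemented, "no route for %q", method)
//	}
//	srv := grpc.NewServer(grpc.UnknownServiceHandler(fallback))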

// ConnectionTimeout returns a ServerOption that sets the timeout for
// connection establishment (up to and including HTTP/2 handshaking) for all
// new connections. If this is not set, the default is 120 seconds. A zero or
// negative value will result in an immediate timeout.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ConnectionTimeout(d time.Duration) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.connectionTimeout = d
	})
}

// MaxHeaderListSizeServerOption is a ServerOption that sets the max
// (uncompressed) size of header list that the server is prepared to accept.
type MaxHeaderListSizeServerOption struct {
	MaxHeaderListSize uint32
}

func (o MaxHeaderListSizeServerOption) apply(so *serverOptions) {
	so.maxHeaderListSize = &o.MaxHeaderListSize
}

// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size
// of header list that the server is prepared to accept.
func MaxHeaderListSize(s uint32) ServerOption {
	return MaxHeaderListSizeServerOption{
		MaxHeaderListSize: s,
	}
}

// HeaderTableSize returns a ServerOption that sets the size of the dynamic
// header table for streams.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func HeaderTableSize(s uint32) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.headerTableSize = &s
	})
}

// NumStreamWorkers returns a ServerOption that sets the number of worker
// goroutines that should be used to process incoming streams. Setting this to
// zero (default) will disable workers and spawn a new goroutine for each
// stream.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func NumStreamWorkers(numServerWorkers uint32) ServerOption {
	// TODO: If/when this API gets stabilized (i.e. stream workers become the
	// only way streams are processed), change the behavior of the zero value to
	// a sane default. Preliminary experiments suggest that a value equal to the
	// number of CPUs available is most performant; requires thorough testing.
	return newFuncServerOption(func(o *serverOptions) {
		o.numServerWorkers = numServerWorkers
	})
}
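
// A minimal sketch, assuming one worker per CPU as a starting point (per the
// TODO above, this should be validated with benchmarks for your workload):
//
//	srv := grpc.NewServer(grpc.NumStreamWorkers(uint32(runtime.NumCPU())))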

// WaitForHandlers causes Stop to wait until all outstanding method handlers have
// exited before returning. If false, Stop will return as soon as all
// connections have closed, but method handlers may still be running. By
// default, Stop does not wait for method handlers to return.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WaitForHandlers(w bool) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.waitForHandlers = w
	})
}

func bufferPool(bufferPool mem.BufferPool) ServerOption {
	return newFuncServerOption(func(o *serverOptions) {
		o.bufferPool = bufferPool
	})
}

// serverWorkerResetThreshold defines how often the worker stack must be reset.
// Every N requests a worker resets its stack by spawning a new goroutine in
// its place, so that large stacks don't live in memory forever. 2^16 should
// allow each goroutine stack to live for at least a few seconds in a typical
// workload (assuming a QPS of a few thousand requests/sec).
const serverWorkerResetThreshold = 1 << 16

// serverWorker blocks on a *transport.ServerStream channel forever and waits
// for data to be fed by serveStreams. This allows multiple requests to be
// processed by the same goroutine, removing the need for expensive stack
// re-allocations (see the runtime.morestack problem [1]).
//
// [1] https://github.com/golang/go/issues/18138
func (s *Server) serverWorker() {
	for completed := 0; completed < serverWorkerResetThreshold; completed++ {
		f, ok := <-s.serverWorkerChannel
		if !ok {
			return
		}
		f()
	}
	go s.serverWorker()
}

// initServerWorkers creates worker goroutines and a channel to process
// incoming streams, reducing the time spent overall on runtime.morestack.
func (s *Server) initServerWorkers() {
	s.serverWorkerChannel = make(chan func())
	s.serverWorkerChannelClose = sync.OnceFunc(func() {
		close(s.serverWorkerChannel)
	})
	for i := uint32(0); i < s.opts.numServerWorkers; i++ {
		go s.serverWorker()
	}
}

// NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
	opts := defaultServerOptions
	for _, o := range globalServerOptions {
		o.apply(&opts)
	}
	for _, o := range opt {
		o.apply(&opts)
	}
	s := &Server{
		lis:      make(map[net.Listener]bool),
		opts:     opts,
		conns:    make(map[string]map[transport.ServerTransport]bool),
		services: make(map[string]*serviceInfo),
		quit:     grpcsync.NewEvent(),
		done:     grpcsync.NewEvent(),
		channelz: channelz.RegisterServer(""),
	}
	chainUnaryServerInterceptors(s)
	chainStreamServerInterceptors(s)
	s.cv = sync.NewCond(&s.mu)
	if EnableTracing {
		_, file, line, _ := runtime.Caller(1)
		s.events = newTraceEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
	}

	if s.opts.numServerWorkers > 0 {
		s.initServerWorkers()
	}

	channelz.Info(logger, s.channelz, "Server created")
	return s
}
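
// A minimal sketch of the typical lifecycle; pb.RegisterFooServer stands in
// for a registration function from generated code, and fooService is a
// hypothetical implementation:
//
//	srv := grpc.NewServer()
//	pb.RegisterFooServer(srv, &fooService{})
//	lis, err := net.Listen("tcp", ":50051")
//	if err != nil {
//		log.Fatal(err)
//	}
//	if err := srv.Serve(lis); err != nil {
//		log.Fatal(err)
//	}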

// printf records an event in s's event log, unless s has been stopped.
// REQUIRES s.mu is held.
func (s *Server) printf(format string, a ...any) {
	if s.events != nil {
		s.events.Printf(format, a...)
	}
}

// errorf records an error in s's event log, unless s has been stopped.
// REQUIRES s.mu is held.
func (s *Server) errorf(format string, a ...any) {
	if s.events != nil {
		s.events.Errorf(format, a...)
	}
}

// ServiceRegistrar wraps a single method that supports service registration. It
// enables users to pass concrete types other than grpc.Server to the service
// registration methods exported by the IDL generated code.
type ServiceRegistrar interface {
	// RegisterService registers a service and its implementation to the
	// concrete type implementing this interface. It may not be called
	// once the server has started serving.
	// desc describes the service and its methods and handlers. impl is the
	// service implementation which is passed to the method handlers.
	RegisterService(desc *ServiceDesc, impl any)
}

// RegisterService registers a service and its implementation to the gRPC
// server. It is called from the IDL generated code. This must be called before
// invoking Serve. If ss is non-nil (for legacy code), its type is checked to
// ensure it implements sd.HandlerType.
func (s *Server) RegisterService(sd *ServiceDesc, ss any) {
	if ss != nil {
		ht := reflect.TypeOf(sd.HandlerType).Elem()
		st := reflect.TypeOf(ss)
		if !st.Implements(ht) {
			logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
		}
	}
	s.register(sd, ss)
}
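
// A minimal sketch of manual registration without generated code; the
// descriptor fields mirror what protoc-gen-go-grpc emits, and FooServer,
// barHandler, and fooImpl are hypothetical:
//
//	var fooServiceDesc = grpc.ServiceDesc{
//		ServiceName: "example.Foo",
//		HandlerType: (*FooServer)(nil),
//		Methods:     []grpc.MethodDesc{{MethodName: "Bar", Handler: barHandler}},
//	}
//	srv.RegisterService(&fooServiceDesc, &fooImpl{})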

func (s *Server) register(sd *ServiceDesc, ss any) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.printf("RegisterService(%q)", sd.ServiceName)
	if s.serve {
		logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, ok := s.services[sd.ServiceName]; ok {
		logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}
	info := &serviceInfo{
		serviceImpl: ss,
		methods:     make(map[string]*MethodDesc),
		streams:     make(map[string]*StreamDesc),
		mdata:       sd.Metadata,
	}
	for i := range sd.Methods {
		d := &sd.Methods[i]
		info.methods[d.MethodName] = d
	}
	for i := range sd.Streams {
		d := &sd.Streams[i]
		info.streams[d.StreamName] = d
	}
	s.services[sd.ServiceName] = info
}

// MethodInfo contains the information of an RPC including its method name and type.
type MethodInfo struct {
	// Name is the method name only, without the service name or package name.
	Name string
	// IsClientStream indicates whether the RPC is a client streaming RPC.
	IsClientStream bool
	// IsServerStream indicates whether the RPC is a server streaming RPC.
	IsServerStream bool
}

// ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service.
type ServiceInfo struct {
	Methods []MethodInfo
	// Metadata is the metadata specified in ServiceDesc when registering service.
	Metadata any
}

// GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]ServiceInfo {
	ret := make(map[string]ServiceInfo)
	for n, srv := range s.services {
		methods := make([]MethodInfo, 0, len(srv.methods)+len(srv.streams))
		for m := range srv.methods {
			methods = append(methods, MethodInfo{
				Name:           m,
				IsClientStream: false,
				IsServerStream: false,
			})
		}
		for m, d := range srv.streams {
			methods = append(methods, MethodInfo{
				Name:           m,
				IsClientStream: d.ClientStreams,
				IsServerStream: d.ServerStreams,
			})
		}

		ret[n] = ServiceInfo{
			Methods:  methods,
			Metadata: srv.mdata,
		}
	}
	return ret
}
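
// A minimal sketch of enumerating the registered methods, e.g. for debugging
// or for exposing a simple registry endpoint:
//
//	for svc, info := range srv.GetServiceInfo() {
//		for _, m := range info.Methods {
//			log.Printf("/%s/%s (client-stream=%v, server-stream=%v)",
//				svc, m.Name, m.IsClientStream, m.IsServerStream)
//		}
//	}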

// ErrServerStopped indicates that the operation is now illegal because of
// the server being stopped.
var ErrServerStopped = errors.New("grpc: the server has been stopped")

type listenSocket struct {
	net.Listener
	channelz *channelz.Socket
}

func (l *listenSocket) Close() error {
	err := l.Listener.Close()
	channelz.RemoveEntry(l.channelz.ID)
	channelz.Info(logger, l.channelz, "ListenSocket deleted")
	return err
}

// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines
// read gRPC requests and then call the registered handlers to reply to them.
// Serve returns when lis.Accept fails with fatal errors. lis will be closed when
// this method returns.
// Serve will return a non-nil error unless Stop or GracefulStop is called.
//
// Note: All supported releases of Go (as of December 2023) override the OS
// defaults for TCP keepalive time and interval to 15s. To enable TCP keepalive
// with OS defaults for keepalive time and interval, callers need to do the
// following two things (a sketch follows the function body below):
//   - pass a net.Listener created by calling the Listen method on a
//     net.ListenConfig with the `KeepAlive` field set to a negative value. This
//     will result in the Go standard library not overriding OS defaults for TCP
//     keepalive interval and time. But this will also result in the Go standard
//     library not enabling TCP keepalives by default.
//   - override the Accept method on the passed in net.Listener and set the
//     SO_KEEPALIVE socket option to enable TCP keepalives, with OS defaults.
func (s *Server) Serve(lis net.Listener) error {
	s.mu.Lock()
	s.printf("serving")
	s.serve = true
	if s.lis == nil {
		// Serve called after Stop or GracefulStop.
		s.mu.Unlock()
		lis.Close()
		return ErrServerStopped
	}

	s.serveWG.Add(1)
	defer func() {
		s.serveWG.Done()
		if s.quit.HasFired() {
			// Stop or GracefulStop called; block until done and return nil.
			<-s.done.Done()
		}
	}()

	ls := &listenSocket{
		Listener: lis,
		channelz: channelz.RegisterSocket(&channelz.Socket{
			SocketType:    channelz.SocketTypeListen,
			Parent:        s.channelz,
			RefName:       lis.Addr().String(),
			LocalAddr:     lis.Addr(),
			SocketOptions: channelz.GetSocketOption(lis)},
		),
	}
	s.lis[ls] = true

	defer func() {
		s.mu.Lock()
		if s.lis != nil && s.lis[ls] {
			ls.Close()
			delete(s.lis, ls)
		}
		s.mu.Unlock()
	}()

	s.mu.Unlock()
	channelz.Info(logger, ls.channelz, "ListenSocket created")

	var tempDelay time.Duration // how long to sleep on accept failure
	for {
		rawConn, err := lis.Accept()
		if err != nil {
			if ne, ok := err.(interface {
				Temporary() bool
			}); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				s.mu.Lock()
				s.printf("Accept error: %v; retrying in %v", err, tempDelay)
				s.mu.Unlock()
				timer := time.NewTimer(tempDelay)
				select {
				case <-timer.C:
				case <-s.quit.Done():
					timer.Stop()
					return nil
				}
				continue
			}
			s.mu.Lock()
			s.printf("done serving; Accept = %v", err)
			s.mu.Unlock()

			if s.quit.HasFired() {
				return nil
			}
			return err
		}
		tempDelay = 0
		// Start a new goroutine to deal with rawConn so we don't stall this Accept
		// loop goroutine.
		//
		// Make sure we account for the goroutine so GracefulStop doesn't nil out
		// s.conns before this conn can be added.
		s.serveWG.Add(1)
		go func() {
			s.handleRawConn(lis.Addr().String(), rawConn)
			s.serveWG.Done()
		}()
	}
}
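
// A minimal sketch of the TCP-keepalive listener setup described in Serve's
// doc comment: KeepAlive < 0 stops the standard library from overriding the
// OS defaults, and the keepAliveListener wrapper (hypothetical; its Accept
// would set SO_KEEPALIVE on each accepted conn) re-enables keepalives:
//
//	lc := net.ListenConfig{KeepAlive: -1}
//	lis, err := lc.Listen(context.Background(), "tcp", ":50051")
//	if err != nil {
//		log.Fatal(err)
//	}
//	err = grpcServer.Serve(keepAliveListener{Listener: lis})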

// handleRawConn forks a goroutine to handle a just-accepted connection that
// has not had any I/O performed on it yet.
func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
	if s.quit.HasFired() {
		rawConn.Close()
		return
	}
	rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))

	// Finish handshaking (HTTP2)
	st := s.newHTTP2Transport(rawConn)
	rawConn.SetDeadline(time.Time{})
	if st == nil {
		return
	}

	if cc, ok := rawConn.(interface {
		PassServerTransport(transport.ServerTransport)
	}); ok {
		cc.PassServerTransport(st)
	}

	if !s.addConn(lisAddr, st) {
		return
	}
	go func() {
		s.serveStreams(context.Background(), st, rawConn)
		s.removeConn(lisAddr, st)
	}()
}

// newHTTP2Transport sets up an HTTP/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go).
func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
	config := &transport.ServerConfig{
		MaxStreams:            s.opts.maxConcurrentStreams,
		ConnectionTimeout:     s.opts.connectionTimeout,
		Credentials:           s.opts.creds,
		InTapHandle:           s.opts.inTapHandle,
		StatsHandlers:         s.opts.statsHandlers,
		KeepaliveParams:       s.opts.keepaliveParams,
		KeepalivePolicy:       s.opts.keepalivePolicy,
		InitialWindowSize:     s.opts.initialWindowSize,
		InitialConnWindowSize: s.opts.initialConnWindowSize,
		WriteBufferSize:       s.opts.writeBufferSize,
		ReadBufferSize:        s.opts.readBufferSize,
		SharedWriteBuffer:     s.opts.sharedWriteBuffer,
		ChannelzParent:        s.channelz,
		MaxHeaderListSize:     s.opts.maxHeaderListSize,
		HeaderTableSize:       s.opts.headerTableSize,
		BufferPool:            s.opts.bufferPool,
	}
	st, err := transport.NewServerTransport(c, config)
	if err != nil {
		s.mu.Lock()
		s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
		s.mu.Unlock()
		// ErrConnDispatched means that the connection was dispatched away from
		// gRPC; those connections should be left open.
		if err != credentials.ErrConnDispatched {
			// Don't log on ErrConnDispatched and io.EOF to prevent log spam.
			if err != io.EOF {
				channelz.Info(logger, s.channelz, "grpc: Server.Serve failed to create ServerTransport: ", err)
			}
			c.Close()
		}
		return nil
	}

	return st
}

func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
	ctx = transport.SetConnection(ctx, rawConn)
	ctx = peer.NewContext(ctx, st.Peer())
	for _, sh := range s.opts.statsHandlers {
		ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
			RemoteAddr: st.Peer().Addr,
			LocalAddr:  st.Peer().LocalAddr,
		})
		sh.HandleConn(ctx, &stats.ConnBegin{})
	}

	defer func() {
		st.Close(errors.New("finished serving streams for the server transport"))
		for _, sh := range s.opts.statsHandlers {
			sh.HandleConn(ctx, &stats.ConnEnd{})
		}
	}()

	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
	st.HandleStreams(ctx, func(stream *transport.ServerStream) {
		s.handlersWG.Add(1)
		streamQuota.acquire()
		f := func() {
			defer streamQuota.release()
			defer s.handlersWG.Done()
			s.handleStream(st, stream)
		}

		if s.opts.numServerWorkers > 0 {
			select {
			case s.serverWorkerChannel <- f:
				return
			default:
				// If all stream workers are busy, fall back to the default code path.
			}
		}
		go f()
	})
}

var _ http.Handler = (*Server)(nil)

// ServeHTTP implements the Go standard library's http.Handler
// interface by responding to the gRPC request r, looking up
// the requested gRPC method in the gRPC server s.
//
// The provided HTTP request must have arrived on an HTTP/2
// connection. When using the Go standard library's server,
// practically this means that the Request must also have arrived
// over TLS.
//
// To share one port (such as 443 for https) between gRPC and an
// existing http.Handler, use a root http.Handler such as:
//
//	if r.ProtoMajor == 2 && strings.HasPrefix(
//		r.Header.Get("Content-Type"), "application/grpc") {
//		grpcServer.ServeHTTP(w, r)
//	} else {
//		yourMux.ServeHTTP(w, r)
//	}
//
// Note that ServeHTTP uses Go's HTTP/2 server implementation which is totally
// separate from grpc-go's HTTP/2 server. Performance and features may vary
// between the two paths. ServeHTTP does not support some gRPC features
// available through grpc-go's HTTP/2 server.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool)
	if err != nil {
		// Errors returned from transport.NewServerHandlerTransport have
		// already been written to w.
		return
	}
	if !s.addConn(listenerAddressForServeHTTP, st) {
		return
	}
	defer s.removeConn(listenerAddressForServeHTTP, st)
	s.serveStreams(r.Context(), st, nil)
}
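
// A minimal sketch expanding the mux pattern from the doc comment above into
// a complete TLS server; healthHandler and the certificate paths are
// hypothetical:
//
//	mux := http.NewServeMux()
//	mux.HandleFunc("/healthz", healthHandler)
//	root := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//		if r.ProtoMajor == 2 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") {
//			grpcServer.ServeHTTP(w, r)
//			return
//		}
//		mux.ServeHTTP(w, r)
//	})
//	log.Fatal(http.ListenAndServeTLS(":443", "cert.pem", "key.pem", root))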

func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.conns == nil {
		st.Close(errors.New("Server.addConn called when server has already been stopped"))
		return false
	}
	if s.drain {
		// Transport added after we drained our existing conns: drain it
		// immediately.
		st.Drain("")
	}

	if s.conns[addr] == nil {
		// Create a map entry if this is the first connection on this listener.
		s.conns[addr] = make(map[transport.ServerTransport]bool)
	}
	s.conns[addr][st] = true
	return true
}

func (s *Server) removeConn(addr string, st transport.ServerTransport) {
	s.mu.Lock()
	defer s.mu.Unlock()

	conns := s.conns[addr]
	if conns != nil {
		delete(conns, st)
		if len(conns) == 0 {
			// If the last connection for this address is being removed, also
			// remove the map entry corresponding to the address. This is used
			// in GracefulStop() when waiting for all connections to be closed.
			delete(s.conns, addr)
		}
		s.cv.Broadcast()
	}
}

func (s *Server) incrCallsStarted() {
	s.channelz.ServerMetrics.CallsStarted.Add(1)
	s.channelz.ServerMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
}

func (s *Server) incrCallsSucceeded() {
	s.channelz.ServerMetrics.CallsSucceeded.Add(1)
}

func (s *Server) incrCallsFailed() {
	s.channelz.ServerMetrics.CallsFailed.Add(1)
}

func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor) error {
	data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
	if err != nil {
		channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
		return err
	}

	compData, pf, err := compress(data, cp, comp, s.opts.bufferPool)
	if err != nil {
		data.Free()
		channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
		return err
	}

	hdr, payload := msgHeader(data, compData, pf)

	defer func() {
		compData.Free()
		data.Free()
		// payload does not need to be freed here, it is either data or compData, both of
		// which are already freed.
	}()

	dataLen := data.Len()
	payloadLen := payload.Len()
	// TODO(dfawley): should we be checking len(data) instead?
	if payloadLen > s.opts.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize)
	}
	err = stream.Write(hdr, payload, opts)
	if err == nil {
		if len(s.opts.statsHandlers) != 0 {
			for _, sh := range s.opts.statsHandlers {
				sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
			}
		}
	}
	return err
}

// chainUnaryServerInterceptors chains all unary server interceptors into one.
func chainUnaryServerInterceptors(s *Server) {
	// Prepend opts.unaryInt to the chaining interceptors if it exists, since unaryInt will
	// be executed before any other chained interceptors.
	interceptors := s.opts.chainUnaryInts
	if s.opts.unaryInt != nil {
		interceptors = append([]UnaryServerInterceptor{s.opts.unaryInt}, s.opts.chainUnaryInts...)
	}

	var chainedInt UnaryServerInterceptor
	if len(interceptors) == 0 {
		chainedInt = nil
	} else if len(interceptors) == 1 {
		chainedInt = interceptors[0]
	} else {
		chainedInt = chainUnaryInterceptors(interceptors)
	}

	s.opts.unaryInt = chainedInt
}

func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *UnaryServerInfo, handler UnaryHandler) (any, error) {
		return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
	}
}

func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
	if curr == len(interceptors)-1 {
		return finalHandler
	}
	return func(ctx context.Context, req any) (any, error) {
		return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
	}
}
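
// For illustration: with interceptors [a, b, c], the chain built above calls
// a with a handler that invokes b, which in turn receives a handler that
// invokes c, whose handler is the method handler itself, so the effective
// order is
//
//	a -> b -> c -> method handler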

func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
	shs := s.opts.statsHandlers
	if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
		if channelz.IsOn() {
			s.incrCallsStarted()
		}
		var statsBegin *stats.Begin
		for _, sh := range shs {
			beginTime := time.Now()
			statsBegin = &stats.Begin{
				BeginTime:      beginTime,
				IsClientStream: false,
				IsServerStream: false,
			}
			sh.HandleRPC(ctx, statsBegin)
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(&trInfo.firstLine, false)
		}
		// The deferred error handling for tracing, stats handler and channelz are
		// combined into one function to reduce stack usage -- a defer takes ~56-64
		// bytes on the stack, so overflowing the stack will require a stack
		// re-allocation, which is expensive.
		//
		// To maintain behavior similar to separate deferred statements, statements
		// should be executed in the reverse order. That is, tracing first, stats
		// handler second, and channelz last. Note that panics *within* defers will
		// lead to different behavior, but that's an acceptable compromise; that
		// would be undefined behavior territory anyway.
		defer func() {
			if trInfo != nil {
				if err != nil && err != io.EOF {
					trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
					trInfo.tr.SetError()
				}
				trInfo.tr.Finish()
			}

			for _, sh := range shs {
				end := &stats.End{
					BeginTime: statsBegin.BeginTime,
					EndTime:   time.Now(),
				}
				if err != nil && err != io.EOF {
					end.Error = toRPCErr(err)
				}
				sh.HandleRPC(ctx, end)
			}

			if channelz.IsOn() {
				if err != nil && err != io.EOF {
					s.incrCallsFailed()
				} else {
					s.incrCallsSucceeded()
				}
			}
		}()
	}
	var binlogs []binarylog.MethodLogger
	if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil {
		binlogs = append(binlogs, ml)
	}
	if s.opts.binaryLogger != nil {
		if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil {
			binlogs = append(binlogs, ml)
		}
	}
	if len(binlogs) != 0 {
		md, _ := metadata.FromIncomingContext(ctx)
		logEntry := &binarylog.ClientHeader{
			Header:     md,
			MethodName: stream.Method(),
			PeerAddr:   nil,
		}
		if deadline, ok := ctx.Deadline(); ok {
			logEntry.Timeout = time.Until(deadline)
			if logEntry.Timeout < 0 {
				logEntry.Timeout = 0
			}
		}
		if a := md[":authority"]; len(a) > 0 {
			logEntry.Authority = a[0]
		}
		if peer, ok := peer.FromContext(ctx); ok {
			logEntry.PeerAddr = peer.Addr
		}
		for _, binlog := range binlogs {
			binlog.Log(ctx, logEntry)
		}
	}

	// comp and cp are used for compression. decomp and dc are used for
	// decompression. If comp and decomp are both set, they are the same;
	// however they are kept separate to ensure that at most one of the
	// compressor/decompressor variable pairs are set for use later.
	var comp, decomp encoding.Compressor
	var cp Compressor
	var dc Decompressor
	var sendCompressorName string

	// If dc is set and matches the stream's compression, use it. Otherwise, try
	// to find a matching registered compressor for decomp.
	if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
		dc = s.opts.dc
	} else if rc != "" && rc != encoding.Identity {
		decomp = encoding.GetCompressor(rc)
		if decomp == nil {
			st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
			stream.WriteStatus(st)
			return st.Err()
		}
	}

	// If cp is set, use it. Otherwise, attempt to compress the response using
	// the incoming message compression method.
	//
	// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
	if s.opts.cp != nil {
		cp = s.opts.cp
		sendCompressorName = cp.Type()
	} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
		// Legacy compressor not specified; attempt to respond with same encoding.
		comp = encoding.GetCompressor(rc)
		if comp != nil {
			sendCompressorName = comp.Name()
		}
	}

	if sendCompressorName != "" {
		if err := stream.SetSendCompress(sendCompressorName); err != nil {
			return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
		}
	}

	var payInfo *payloadInfo
	if len(shs) != 0 || len(binlogs) != 0 {
		payInfo = &payloadInfo{}
		defer payInfo.free()
	}

	d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
	if err != nil {
		if e := stream.WriteStatus(status.Convert(err)); e != nil {
			channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
		}
		return err
	}
	freed := false
	dataFree := func() {
		if !freed {
			d.Free()
			freed = true
		}
	}
	defer dataFree()
	df := func(v any) error {
		defer dataFree()
		if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
			return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
		}

		for _, sh := range shs {
			sh.HandleRPC(ctx, &stats.InPayload{
				RecvTime:         time.Now(),
				Payload:          v,
				Length:           d.Len(),
				WireLength:       payInfo.compressedLength + headerLen,
				CompressedLength: payInfo.compressedLength,
			})
		}
		if len(binlogs) != 0 {
			cm := &binarylog.ClientMessage{
				Message: d.Materialize(),
			}
			for _, binlog := range binlogs {
				binlog.Log(ctx, cm)
			}
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
		}
		return nil
	}
	ctx = NewContextWithServerTransportStream(ctx, stream)
	reply, appErr := md.Handler(info.serviceImpl, ctx, df, s.opts.unaryInt)
	if appErr != nil {
		appStatus, ok := status.FromError(appErr)
		if !ok {
			// Convert non-status application error to a status error with code
			// Unknown, but handle context errors specifically.
			appStatus = status.FromContextError(appErr)
			appErr = appStatus.Err()
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
			trInfo.tr.SetError()
		}
		if e := stream.WriteStatus(appStatus); e != nil {
			channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
		}
		if len(binlogs) != 0 {
			if h, _ := stream.Header(); h.Len() > 0 {
				// Only log serverHeader if there was a header. Otherwise it can
				// be trailer only.
				sh := &binarylog.ServerHeader{
					Header: h,
				}
				for _, binlog := range binlogs {
					binlog.Log(ctx, sh)
				}
			}
			st := &binarylog.ServerTrailer{
				Trailer: stream.Trailer(),
				Err:     appErr,
			}
			for _, binlog := range binlogs {
				binlog.Log(ctx, st)
			}
		}
		return appErr
	}
1442 if trInfo != nil {
1443 trInfo.tr.LazyLog(stringer("OK"), false)
1444 }
1445 opts := &transport.WriteOptions{Last: true}
1446
1447 // Server handler could have set new compressor by calling SetSendCompressor.
1448 // In case it is set, we need to use it for compressing outbound message.
1449 if stream.SendCompress() != sendCompressorName {
1450 comp = encoding.GetCompressor(stream.SendCompress())
1451 }
1452 if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil {
1453 if err == io.EOF {
1454 // The entire stream is done (for unary RPC only).
1455 return err
1456 }
1457 if sts, ok := status.FromError(err); ok {
1458 if e := stream.WriteStatus(sts); e != nil {
1459 channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
1460 }
1461 } else {
1462 switch st := err.(type) {
1463 case transport.ConnectionError:
1464 // Nothing to do here.
1465 default:
1466 panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
1467 }
1468 }
1469 if len(binlogs) != 0 {
1470 h, _ := stream.Header()
1471 sh := &binarylog.ServerHeader{
1472 Header: h,
1473 }
1474 st := &binarylog.ServerTrailer{
1475 Trailer: stream.Trailer(),
1476 Err: appErr,
1477 }
1478 for _, binlog := range binlogs {
1479 binlog.Log(ctx, sh)
1480 binlog.Log(ctx, st)
1481 }
1482 }
1483 return err
1484 }
1485 if len(binlogs) != 0 {
1486 h, _ := stream.Header()
1487 sh := &binarylog.ServerHeader{
1488 Header: h,
1489 }
1490 sm := &binarylog.ServerMessage{
1491 Message: reply,
1492 }
1493 for _, binlog := range binlogs {
1494 binlog.Log(ctx, sh)
1495 binlog.Log(ctx, sm)
1496 }
1497 }
1498 if trInfo != nil {
1499 trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
1500 }
1501 // TODO: Should we be logging if writing status failed here, like above?
1502 // Should the logging be in WriteStatus? Should we ignore the WriteStatus
1503 // error or allow the stats handler to see it?
1504 if len(binlogs) != 0 {
1505 st := &binarylog.ServerTrailer{
1506 Trailer: stream.Trailer(),
1507 Err: appErr,
1508 }
1509 for _, binlog := range binlogs {
1510 binlog.Log(ctx, st)
1511 }
1512 }
1513 return stream.WriteStatus(statusOK)
1514}
1515
1516// chainStreamServerInterceptors chains all stream server interceptors into one.
1517func chainStreamServerInterceptors(s *Server) {
1518 // Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will
1519 // be executed before any other chained interceptors.
1520 interceptors := s.opts.chainStreamInts
1521 if s.opts.streamInt != nil {
1522 interceptors = append([]StreamServerInterceptor{s.opts.streamInt}, s.opts.chainStreamInts...)
1523 }
1524
1525 var chainedInt StreamServerInterceptor
1526 if len(interceptors) == 0 {
1527 chainedInt = nil
1528 } else if len(interceptors) == 1 {
1529 chainedInt = interceptors[0]
1530 } else {
1531 chainedInt = chainStreamInterceptors(interceptors)
1532 }
1533
1534 s.opts.streamInt = chainedInt
1535}
1536
1537func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
1538 return func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
1539 return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
1540 }
1541}
1542
1543func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
1544 if curr == len(interceptors)-1 {
1545 return finalHandler
1546 }
1547 return func(srv any, stream ServerStream) error {
1548 return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
1549 }
1550}
1551
1552func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) {
1553 if channelz.IsOn() {
1554 s.incrCallsStarted()
1555 }
1556 shs := s.opts.statsHandlers
1557 var statsBegin *stats.Begin
1558 if len(shs) != 0 {
1559 beginTime := time.Now()
1560 statsBegin = &stats.Begin{
1561 BeginTime: beginTime,
1562 IsClientStream: sd.ClientStreams,
1563 IsServerStream: sd.ServerStreams,
1564 }
1565 for _, sh := range shs {
1566 sh.HandleRPC(ctx, statsBegin)
1567 }
1568 }
1569 ctx = NewContextWithServerTransportStream(ctx, stream)
1570 ss := &serverStream{
1571 ctx: ctx,
1572 s: stream,
1573 p: &parser{r: stream, bufferPool: s.opts.bufferPool},
1574 codec: s.getCodec(stream.ContentSubtype()),
1575 maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
1576 maxSendMessageSize: s.opts.maxSendMessageSize,
1577 trInfo: trInfo,
1578 statsHandler: shs,
1579 }
1580
1581 if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
1582 // See comment in processUnaryRPC on defers.
1583 defer func() {
1584 if trInfo != nil {
1585 ss.mu.Lock()
1586 if err != nil && err != io.EOF {
1587 ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
1588 ss.trInfo.tr.SetError()
1589 }
1590 ss.trInfo.tr.Finish()
1591 ss.trInfo.tr = nil
1592 ss.mu.Unlock()
1593 }
1594
1595 if len(shs) != 0 {
1596 end := &stats.End{
1597 BeginTime: statsBegin.BeginTime,
1598 EndTime: time.Now(),
1599 }
1600 if err != nil && err != io.EOF {
1601 end.Error = toRPCErr(err)
1602 }
1603 for _, sh := range shs {
1604 sh.HandleRPC(ctx, end)
1605 }
1606 }
1607
1608 if channelz.IsOn() {
1609 if err != nil && err != io.EOF {
1610 s.incrCallsFailed()
1611 } else {
1612 s.incrCallsSucceeded()
1613 }
1614 }
1615 }()
1616 }
1617
1618 if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil {
1619 ss.binlogs = append(ss.binlogs, ml)
1620 }
1621 if s.opts.binaryLogger != nil {
1622 if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil {
1623 ss.binlogs = append(ss.binlogs, ml)
1624 }
1625 }
1626 if len(ss.binlogs) != 0 {
1627 md, _ := metadata.FromIncomingContext(ctx)
1628 logEntry := &binarylog.ClientHeader{
1629 Header: md,
1630 MethodName: stream.Method(),
1631 PeerAddr: nil,
1632 }
1633 if deadline, ok := ctx.Deadline(); ok {
1634 logEntry.Timeout = time.Until(deadline)
1635 if logEntry.Timeout < 0 {
1636 logEntry.Timeout = 0
1637 }
1638 }
1639 if a := md[":authority"]; len(a) > 0 {
1640 logEntry.Authority = a[0]
1641 }
1642 if peer, ok := peer.FromContext(ss.Context()); ok {
1643 logEntry.PeerAddr = peer.Addr
1644 }
1645 for _, binlog := range ss.binlogs {
1646 binlog.Log(ctx, logEntry)
1647 }
1648 }
1649
1650 // If dc is set and matches the stream's compression, use it. Otherwise, try
1651 // to find a matching registered compressor for decomp.
	if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
		ss.decompressorV0 = s.opts.dc
	} else if rc != "" && rc != encoding.Identity {
		ss.decompressorV1 = encoding.GetCompressor(rc)
		if ss.decompressorV1 == nil {
			st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
			ss.s.WriteStatus(st)
			return st.Err()
		}
	}

	// If cp is set, use it. Otherwise, attempt to compress the response using
	// the incoming message compression method.
	//
	// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
	if s.opts.cp != nil {
		ss.compressorV0 = s.opts.cp
		ss.sendCompressorName = s.opts.cp.Type()
	} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
		// Legacy compressor not specified; attempt to respond with same encoding.
		ss.compressorV1 = encoding.GetCompressor(rc)
		if ss.compressorV1 != nil {
			ss.sendCompressorName = rc
		}
	}

	if ss.sendCompressorName != "" {
		if err := stream.SetSendCompress(ss.sendCompressorName); err != nil {
			return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
		}
	}

	ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.compressorV0, ss.compressorV1)

	if trInfo != nil {
		trInfo.tr.LazyLog(&trInfo.firstLine, false)
	}
	var appErr error
	var server any
	if info != nil {
		server = info.serviceImpl
	}
	if s.opts.streamInt == nil {
		appErr = sd.Handler(server, ss)
	} else {
		info := &StreamServerInfo{
			FullMethod:     stream.Method(),
			IsClientStream: sd.ClientStreams,
			IsServerStream: sd.ServerStreams,
		}
		appErr = s.opts.streamInt(server, ss, info, sd.Handler)
	}
	if appErr != nil {
		appStatus, ok := status.FromError(appErr)
		if !ok {
			// Convert non-status application error to a status error with code
			// Unknown, but handle context errors specifically.
			appStatus = status.FromContextError(appErr)
			appErr = appStatus.Err()
		}
		if trInfo != nil {
			ss.mu.Lock()
			ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
			ss.trInfo.tr.SetError()
			ss.mu.Unlock()
		}
		if len(ss.binlogs) != 0 {
			st := &binarylog.ServerTrailer{
				Trailer: ss.s.Trailer(),
				Err:     appErr,
			}
			for _, binlog := range ss.binlogs {
				binlog.Log(ctx, st)
			}
		}
		ss.s.WriteStatus(appStatus)
		// TODO: Should we log an error from WriteStatus here and below?
		return appErr
	}
	if trInfo != nil {
		ss.mu.Lock()
		ss.trInfo.tr.LazyLog(stringer("OK"), false)
		ss.mu.Unlock()
	}
	if len(ss.binlogs) != 0 {
		st := &binarylog.ServerTrailer{
			Trailer: ss.s.Trailer(),
			Err:     appErr,
		}
		for _, binlog := range ss.binlogs {
			binlog.Log(ctx, st)
		}
	}
	return ss.s.WriteStatus(statusOK)
}
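
// Given the negotiation above, a server replies using the client's request
// encoding by default whenever a matching compressor is registered. A minimal
// sketch (the call site is hypothetical):
//
//	// Server side: a blank import registers the gzip compressor, which the
//	// code above then finds via encoding.GetCompressor.
//	import _ "google.golang.org/grpc/encoding/gzip"
//
//	// Client side: opting in on a call makes the server respond in kind.
//	resp, err := client.SomeRPC(ctx, req, grpc.UseCompressor(gzip.Name))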

func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) {
	ctx := stream.Context()
	ctx = contextWithServer(ctx, s)
	var ti *traceInfo
	if EnableTracing {
		tr := newTrace("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
		ctx = newTraceContext(ctx, tr)
		ti = &traceInfo{
			tr: tr,
			firstLine: firstLine{
				client:     false,
				remoteAddr: t.Peer().Addr,
			},
		}
		if dl, ok := ctx.Deadline(); ok {
			ti.firstLine.deadline = time.Until(dl)
		}
	}

	sm := stream.Method()
	if sm != "" && sm[0] == '/' {
		sm = sm[1:]
	}
	pos := strings.LastIndex(sm, "/")
	if pos == -1 {
		if ti != nil {
			ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true)
			ti.tr.SetError()
		}
		errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
		if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
			if ti != nil {
				ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
				ti.tr.SetError()
			}
			channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
		}
		if ti != nil {
			ti.tr.Finish()
		}
		return
	}
	service := sm[:pos]
	method := sm[pos+1:]

	// FromIncomingContext is expensive: skip if there are no statsHandlers
	if len(s.opts.statsHandlers) > 0 {
		md, _ := metadata.FromIncomingContext(ctx)
		for _, sh := range s.opts.statsHandlers {
			ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
			sh.HandleRPC(ctx, &stats.InHeader{
				FullMethod:  stream.Method(),
				RemoteAddr:  t.Peer().Addr,
				LocalAddr:   t.Peer().LocalAddr,
				Compression: stream.RecvCompress(),
				WireLength:  stream.HeaderWireLength(),
				Header:      md,
			})
		}
	}
	// Update the stream's context so that calls made from stats handler
	// callouts observe it. To be deleted once all stats handler calls
	// originate from the gRPC layer.
	stream.SetContext(ctx)

	srv, knownService := s.services[service]
	if knownService {
		if md, ok := srv.methods[method]; ok {
			s.processUnaryRPC(ctx, stream, srv, md, ti)
			return
		}
		if sd, ok := srv.streams[method]; ok {
			s.processStreamingRPC(ctx, stream, srv, sd, ti)
			return
		}
	}
	// Unknown service, or known service with an unknown method.
	if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
		s.processStreamingRPC(ctx, stream, nil, unknownDesc, ti)
		return
	}
	var errDesc string
	if !knownService {
		errDesc = fmt.Sprintf("unknown service %v", service)
	} else {
		errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
	}
	if ti != nil {
		ti.tr.LazyPrintf("%s", errDesc)
		ti.tr.SetError()
	}
	if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
		if ti != nil {
			ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
			ti.tr.SetError()
		}
		channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
	}
	if ti != nil {
		ti.tr.Finish()
	}
}

// The key to save ServerTransportStream in the context.
type streamKey struct{}

// NewContextWithServerTransportStream creates a new context from ctx and
// attaches stream to it.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context {
	return context.WithValue(ctx, streamKey{}, stream)
}

// ServerTransportStream is a minimal interface that a transport stream must
// implement. This can be used to mock an actual transport stream for tests of
// handler code that uses, for example, grpc.SetHeader (which requires some
// stream to be in context).
//
// See also NewContextWithServerTransportStream.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type ServerTransportStream interface {
	Method() string
	SetHeader(md metadata.MD) error
	SendHeader(md metadata.MD) error
	SetTrailer(md metadata.MD) error
}

// ServerTransportStreamFromContext returns the ServerTransportStream saved in
// ctx. Returns nil if the given context has no stream associated with it
// (which implies it is not an RPC invocation context).
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream {
	s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
	return s
}
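
// A minimal sketch of exercising handler code in a test via the APIs above
// (mockSTS is hypothetical, for illustration only):
//
//	type mockSTS struct {
//		header, trailer metadata.MD
//	}
//
//	func (m *mockSTS) Method() string                  { return "/pkg.Service/Method" }
//	func (m *mockSTS) SetHeader(md metadata.MD) error  { m.header = metadata.Join(m.header, md); return nil }
//	func (m *mockSTS) SendHeader(md metadata.MD) error { return m.SetHeader(md) }
//	func (m *mockSTS) SetTrailer(md metadata.MD) error { m.trailer = metadata.Join(m.trailer, md); return nil }
//
//	sts := &mockSTS{}
//	ctx := grpc.NewContextWithServerTransportStream(context.Background(), sts)
//	// Handler code calling grpc.SetHeader(ctx, ...) now writes to sts.header.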

// Stop stops the gRPC server. It immediately closes all open connections and
// listeners. It cancels all active RPCs on the server side; the corresponding
// pending RPCs on the client side are notified by connection errors.
func (s *Server) Stop() {
	s.stop(false)
}

// GracefulStop stops the gRPC server gracefully. It stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
func (s *Server) GracefulStop() {
	s.stop(true)
}
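
// A typical shutdown sequence (a sketch only; the signal handling and the
// fallback timeout are application choices, not prescribed by this package):
//
//	go func() {
//		sig := make(chan os.Signal, 1)
//		signal.Notify(sig, os.Interrupt)
//		<-sig
//		// Fall back to a hard Stop if draining takes too long.
//		t := time.AfterFunc(10*time.Second, srv.Stop)
//		defer t.Stop()
//		srv.GracefulStop()
//	}()
//	if err := srv.Serve(lis); err != nil {
//		log.Fatalf("serve: %v", err)
//	}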

func (s *Server) stop(graceful bool) {
	s.quit.Fire()
	defer s.done.Fire()

	s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelz.ID) })
	s.mu.Lock()
	s.closeListenersLocked()
	// Wait for serving threads to be ready to exit. Only then can we be sure no
	// new conns will be created.
	s.mu.Unlock()
	s.serveWG.Wait()

	s.mu.Lock()
	defer s.mu.Unlock()

	if graceful {
		s.drainAllServerTransportsLocked()
	} else {
		s.closeServerTransportsLocked()
	}

	for len(s.conns) != 0 {
		s.cv.Wait()
	}
	s.conns = nil

	if s.opts.numServerWorkers > 0 {
		// Closing the channel (only once, via sync.OnceFunc) after all the
		// connections have been closed above ensures that there are no
		// goroutines executing the callback passed to st.HandleStreams (where
		// the channel is written to).
		s.serverWorkerChannelClose()
	}

	if graceful || s.opts.waitForHandlers {
		s.handlersWG.Wait()
	}

	if s.events != nil {
		s.events.Finish()
		s.events = nil
	}
}

// s.mu must be held by the caller.
func (s *Server) closeServerTransportsLocked() {
	for _, conns := range s.conns {
		for st := range conns {
			st.Close(errors.New("Server.Stop called"))
		}
	}
}

// s.mu must be held by the caller.
func (s *Server) drainAllServerTransportsLocked() {
	if !s.drain {
		for _, conns := range s.conns {
			for st := range conns {
				st.Drain("graceful_stop")
			}
		}
		s.drain = true
	}
}

// s.mu must be held by the caller.
func (s *Server) closeListenersLocked() {
	for lis := range s.lis {
		lis.Close()
	}
	s.lis = nil
}

// getCodec returns the codec to use for requests with the given
// content-subtype. The contentSubtype must be lowercase; the returned codec
// is never nil.
func (s *Server) getCodec(contentSubtype string) baseCodec {
	if s.opts.codec != nil {
		return s.opts.codec
	}
	if contentSubtype == "" {
		return getCodec(proto.Name)
	}
	codec := getCodec(contentSubtype)
	if codec == nil {
		logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
		return getCodec(proto.Name)
	}
	return codec
}
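
// The lookup above consults codecs registered with the encoding package. A
// minimal sketch of registering a custom codec (jsonCodec is hypothetical,
// and the JSON marshaling is illustrative only; real proto messages would
// need protojson):
//
//	type jsonCodec struct{}
//
//	func (jsonCodec) Marshal(v any) ([]byte, error)      { return json.Marshal(v) }
//	func (jsonCodec) Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) }
//	func (jsonCodec) Name() string                       { return "json" }
//
//	func init() { encoding.RegisterCodec(jsonCodec{}) }
//
// A request whose content-type is "application/grpc+json" then resolves to
// this codec here.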

type serverKey struct{}

// serverFromContext gets the Server from the context.
func serverFromContext(ctx context.Context) *Server {
	s, _ := ctx.Value(serverKey{}).(*Server)
	return s
}

// contextWithServer sets the Server in the context.
func contextWithServer(ctx context.Context, server *Server) context.Context {
	return context.WithValue(ctx, serverKey{}, server)
}

// isRegisteredMethod returns whether the passed in method is registered as a
// method on the server. /service/method and service/method will match if the
// service and method are registered on the server.
func (s *Server) isRegisteredMethod(serviceMethod string) bool {
	if serviceMethod != "" && serviceMethod[0] == '/' {
		serviceMethod = serviceMethod[1:]
	}
	pos := strings.LastIndex(serviceMethod, "/")
	if pos == -1 { // Invalid method name syntax.
		return false
	}
	service := serviceMethod[:pos]
	method := serviceMethod[pos+1:]
	srv, knownService := s.services[service]
	if knownService {
		if _, ok := srv.methods[method]; ok {
			return true
		}
		if _, ok := srv.streams[method]; ok {
			return true
		}
	}
	return false
}

// SetHeader sets the header metadata to be sent from the server to the client.
// The context provided must be the context passed to the server's handler.
//
// Streaming RPCs should prefer the SetHeader method of the ServerStream.
//
// When called multiple times, all the provided metadata will be merged. All
// the metadata will be sent out when one of the following happens:
//
// - grpc.SendHeader is called, or for streaming handlers, stream.SendHeader.
// - The first response message is sent. For unary handlers, this occurs when
//   the handler returns; for streaming handlers, this can happen when the
//   stream's SendMsg method is called.
// - An RPC status is sent out (error or success). This occurs when the
//   handler returns.
//
// SetHeader will fail if called after any of the events above.
//
// The error returned is compatible with the status package. However, the
// status code will often not match the RPC status as seen by the client
// application, and therefore, should not be relied upon for this purpose.
func SetHeader(ctx context.Context, md metadata.MD) error {
	if md.Len() == 0 {
		return nil
	}
	stream := ServerTransportStreamFromContext(ctx)
	if stream == nil {
		return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
	}
	return stream.SetHeader(md)
}
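
// A minimal sketch of SetHeader in a unary handler (the handler, message
// types, and helpers are hypothetical):
//
//	func (s *userServer) GetUser(ctx context.Context, req *pb.GetUserRequest) (*pb.User, error) {
//		if err := grpc.SetHeader(ctx, metadata.Pairs("x-request-id", newID())); err != nil {
//			return nil, err
//		}
//		// The merged headers are flushed when the first response or the
//		// status is sent, i.e. when this handler returns.
//		return s.lookup(ctx, req.GetId())
//	}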

// SendHeader sends header metadata. It may be called at most once, and may not
// be called after any event that causes headers to be sent (see SetHeader for
// a complete list). The provided md and headers set by SetHeader() will be
// sent.
//
// The error returned is compatible with the status package. However, the
// status code will often not match the RPC status as seen by the client
// application, and therefore, should not be relied upon for this purpose.
func SendHeader(ctx context.Context, md metadata.MD) error {
	stream := ServerTransportStreamFromContext(ctx)
	if stream == nil {
		return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
	}
	if err := stream.SendHeader(md); err != nil {
		return toRPCErr(err)
	}
	return nil
}
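
// SendHeader is useful when headers should reach the client before the first
// (possibly slow) response message. A minimal sketch inside a hypothetical
// handler:
//
//	if err := grpc.SendHeader(ctx, metadata.Pairs("x-cache", "miss")); err != nil {
//		return err
//	}
//	// ...expensive work; the client has already received the headers...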

// SetSendCompressor sets a compressor for outbound messages from the server.
// It must not be called after any event that causes headers to be sent
// (see ServerStream.SetHeader for the complete list). The provided compressor
// is used when the following conditions are met:
//
// - The compressor is registered via encoding.RegisterCompressor.
// - The compressor name exists in the client-advertised compressor names sent
//   in the grpc-accept-encoding header. Use ClientSupportedCompressors to get
//   the client-supported compressor names.
//
// The context provided must be the context passed to the server's handler.
// Note that the compressor name encoding.Identity disables outbound
// compression.
// By default, server messages will be sent using the same compressor with
// which request messages were sent.
//
// It is not safe to call SetSendCompressor concurrently with SendHeader and
// SendMsg.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
	stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
	if !ok || stream == nil {
		return fmt.Errorf("failed to fetch the stream from the given context")
	}

	if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
		return fmt.Errorf("unable to set send compressor: %w", err)
	}

	return stream.SetSendCompress(name)
}

// ClientSupportedCompressors returns compressor names advertised by the client
// via the grpc-accept-encoding header.
//
// The context provided must be the context passed to the server's handler.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
	stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
	if !ok || stream == nil {
		return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
	}

	return stream.ClientAdvertisedCompressors(), nil
}
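
// A minimal sketch combining the two helpers above: opt into gzip only when
// the client advertised it (assumes the gzip compressor is registered, e.g.
// by importing google.golang.org/grpc/encoding/gzip, and Go 1.21+ for the
// slices package):
//
//	names, err := grpc.ClientSupportedCompressors(ctx)
//	if err != nil {
//		return err
//	}
//	if slices.Contains(names, gzip.Name) {
//		if err := grpc.SetSendCompressor(ctx, gzip.Name); err != nil {
//			return err
//		}
//	}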

// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
//
// The error returned is compatible with the status package. However, the
// status code will often not match the RPC status as seen by the client
// application, and therefore, should not be relied upon for this purpose.
func SetTrailer(ctx context.Context, md metadata.MD) error {
	if md.Len() == 0 {
		return nil
	}
	stream := ServerTransportStreamFromContext(ctx)
	if stream == nil {
		return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
	}
	return stream.SetTrailer(md)
}
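
// A minimal sketch of SetTrailer at the top of a handler; the trailer travels
// with the RPC status once the handler returns:
//
//	start := time.Now()
//	defer func() {
//		_ = grpc.SetTrailer(ctx, metadata.Pairs("server-elapsed", time.Since(start).String()))
//	}()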

// Method returns the method string for the server context. The returned
// string is in the format of "/service/method".
func Method(ctx context.Context) (string, bool) {
	s := ServerTransportStreamFromContext(ctx)
	if s == nil {
		return "", false
	}
	return s.Method(), true
}

// validateSendCompressor returns an error when the given compressor name
// cannot be handled by the server or the client, based on the advertised
// compressors.
func validateSendCompressor(name string, clientCompressors []string) error {
	if name == encoding.Identity {
		return nil
	}

	if !grpcutil.IsCompressorNameRegistered(name) {
		return fmt.Errorf("compressor not registered %q", name)
	}

	for _, c := range clientCompressors {
		if c == name {
			return nil // found match
		}
	}
	return fmt.Errorf("client does not support compressor %q", name)
}

// atomicSemaphore implements a blocking, counting semaphore. acquire should be
// called synchronously; release may be called asynchronously.
type atomicSemaphore struct {
	n    atomic.Int64
	wait chan struct{}
}

func (q *atomicSemaphore) acquire() {
	if q.n.Add(-1) < 0 {
		// We ran out of quota. Block until a release happens.
		<-q.wait
	}
}

func (q *atomicSemaphore) release() {
	// N.B. the "<= 0" check below should allow for this to work with multiple
	// concurrent calls to acquire, but also note that with synchronous calls to
	// acquire, as our system does, n will never be less than -1. There are
	// fairness issues (queuing) to consider if this were to be generalized.
	if q.n.Add(1) <= 0 {
		// An acquire was waiting on us. Unblock it.
		q.wait <- struct{}{}
	}
}

func newHandlerQuota(n uint32) *atomicSemaphore {
	a := &atomicSemaphore{wait: make(chan struct{}, 1)}
	a.n.Store(int64(n))
	return a
}
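
// A minimal usage sketch of the semaphore above, mirroring how the server
// gates handler goroutines when a stream-concurrency limit is configured
// (handle is a hypothetical function):
//
//	quota := newHandlerQuota(2) // at most two handlers at once
//	for {
//		quota.acquire() // called synchronously, per the contract above
//		go func() {
//			defer quota.release() // may run asynchronously; wakes one waiter
//			handle()
//		}()
//	}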