handler.go

  1// Copyright The OpenTelemetry Authors
  2// SPDX-License-Identifier: Apache-2.0
  3
  4package otelhttp // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
  5
  6import (
  7	"net/http"
  8	"time"
  9
 10	"github.com/felixge/httpsnoop"
 11
 12	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request"
 13	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv"
 14	"go.opentelemetry.io/otel"
 15	"go.opentelemetry.io/otel/propagation"
 16	"go.opentelemetry.io/otel/trace"
 17)
 18
// middleware is an http middleware which wraps the next handler in a span.
//
// NewMiddleware sets operation directly and fills every other field through
// configure; serveHTTP only reads the fields afterwards.
type middleware struct {
	// operation is the name handed to spanNameFormatter for each request;
	// server is the server name passed to the semconv request attributes and
	// metrics.
	operation string
	server    string

	// Settings copied from config by configure. A nil tracer is resolved per
	// request in serveHTTP. Any filter returning false makes a request bypass
	// instrumentation. readEvent/writeEvent enable "read"/"write" span events.
	// publicEndpoint (or publicEndpointFn returning true) starts the span as a
	// new root, linked to a valid remote parent if one was extracted.
	tracer            trace.Tracer
	propagators       propagation.TextMapPropagator
	spanStartOptions  []trace.SpanStartOption
	readEvent         bool
	writeEvent        bool
	filters           []Filter
	spanNameFormatter func(string, *http.Request) string
	publicEndpoint    bool
	publicEndpointFn  func(*http.Request) bool

	// semconv produces the semantic-convention span attributes, span status,
	// and server metrics; configure builds it from the configured Meter.
	semconv semconv.HTTPServer
}
 36
 37func defaultHandlerFormatter(operation string, _ *http.Request) string {
 38	return operation
 39}
 40
 41// NewHandler wraps the passed handler in a span named after the operation and
 42// enriches it with metrics.
 43func NewHandler(handler http.Handler, operation string, opts ...Option) http.Handler {
 44	return NewMiddleware(operation, opts...)(handler)
 45}
 46
 47// NewMiddleware returns a tracing and metrics instrumentation middleware.
 48// The handler returned by the middleware wraps a handler
 49// in a span named after the operation and enriches it with metrics.
 50func NewMiddleware(operation string, opts ...Option) func(http.Handler) http.Handler {
 51	h := middleware{
 52		operation: operation,
 53	}
 54
 55	defaultOpts := []Option{
 56		WithSpanOptions(trace.WithSpanKind(trace.SpanKindServer)),
 57		WithSpanNameFormatter(defaultHandlerFormatter),
 58	}
 59
 60	c := newConfig(append(defaultOpts, opts...)...)
 61	h.configure(c)
 62
 63	return func(next http.Handler) http.Handler {
 64		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 65			h.serveHTTP(w, r, next)
 66		})
 67	}
 68}
 69
 70func (h *middleware) configure(c *config) {
 71	h.tracer = c.Tracer
 72	h.propagators = c.Propagators
 73	h.spanStartOptions = c.SpanStartOptions
 74	h.readEvent = c.ReadEvent
 75	h.writeEvent = c.WriteEvent
 76	h.filters = c.Filters
 77	h.spanNameFormatter = c.SpanNameFormatter
 78	h.publicEndpoint = c.PublicEndpoint
 79	h.publicEndpointFn = c.PublicEndpointFn
 80	h.server = c.ServerName
 81	h.semconv = semconv.NewHTTPServer(c.Meter)
 82}
 83
 84func handleErr(err error) {
 85	if err != nil {
 86		otel.Handle(err)
 87	}
 88}
 89
 90// serveHTTP sets up tracing and calls the given next http.Handler with the span
 91// context injected into the request context.
 92func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http.Handler) {
 93	requestStartTime := time.Now()
 94	for _, f := range h.filters {
 95		if !f(r) {
 96			// Simply pass through to the handler if a filter rejects the request
 97			next.ServeHTTP(w, r)
 98			return
 99		}
100	}
101
102	ctx := h.propagators.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
103	opts := []trace.SpanStartOption{
104		trace.WithAttributes(h.semconv.RequestTraceAttrs(h.server, r)...),
105	}
106
107	opts = append(opts, h.spanStartOptions...)
108	if h.publicEndpoint || (h.publicEndpointFn != nil && h.publicEndpointFn(r.WithContext(ctx))) {
109		opts = append(opts, trace.WithNewRoot())
110		// Linking incoming span context if any for public endpoint.
111		if s := trace.SpanContextFromContext(ctx); s.IsValid() && s.IsRemote() {
112			opts = append(opts, trace.WithLinks(trace.Link{SpanContext: s}))
113		}
114	}
115
116	tracer := h.tracer
117
118	if tracer == nil {
119		if span := trace.SpanFromContext(r.Context()); span.SpanContext().IsValid() {
120			tracer = newTracer(span.TracerProvider())
121		} else {
122			tracer = newTracer(otel.GetTracerProvider())
123		}
124	}
125
126	ctx, span := tracer.Start(ctx, h.spanNameFormatter(h.operation, r), opts...)
127	defer span.End()
128
129	readRecordFunc := func(int64) {}
130	if h.readEvent {
131		readRecordFunc = func(n int64) {
132			span.AddEvent("read", trace.WithAttributes(ReadBytesKey.Int64(n)))
133		}
134	}
135
136	// if request body is nil or NoBody, we don't want to mutate the body as it
137	// will affect the identity of it in an unforeseeable way because we assert
138	// ReadCloser fulfills a certain interface and it is indeed nil or NoBody.
139	bw := request.NewBodyWrapper(r.Body, readRecordFunc)
140	if r.Body != nil && r.Body != http.NoBody {
141		r.Body = bw
142	}
143
144	writeRecordFunc := func(int64) {}
145	if h.writeEvent {
146		writeRecordFunc = func(n int64) {
147			span.AddEvent("write", trace.WithAttributes(WroteBytesKey.Int64(n)))
148		}
149	}
150
151	rww := request.NewRespWriterWrapper(w, writeRecordFunc)
152
153	// Wrap w to use our ResponseWriter methods while also exposing
154	// other interfaces that w may implement (http.CloseNotifier,
155	// http.Flusher, http.Hijacker, http.Pusher, io.ReaderFrom).
156
157	w = httpsnoop.Wrap(w, httpsnoop.Hooks{
158		Header: func(httpsnoop.HeaderFunc) httpsnoop.HeaderFunc {
159			return rww.Header
160		},
161		Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
162			return rww.Write
163		},
164		WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
165			return rww.WriteHeader
166		},
167		Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc {
168			return rww.Flush
169		},
170	})
171
172	labeler, found := LabelerFromContext(ctx)
173	if !found {
174		ctx = ContextWithLabeler(ctx, labeler)
175	}
176
177	next.ServeHTTP(w, r.WithContext(ctx))
178
179	statusCode := rww.StatusCode()
180	bytesWritten := rww.BytesWritten()
181	span.SetStatus(h.semconv.Status(statusCode))
182	span.SetAttributes(h.semconv.ResponseTraceAttrs(semconv.ResponseTelemetry{
183		StatusCode: statusCode,
184		ReadBytes:  bw.BytesRead(),
185		ReadError:  bw.Error(),
186		WriteBytes: bytesWritten,
187		WriteError: rww.Error(),
188	})...)
189
190	// Use floating point division here for higher precision (instead of Millisecond method).
191	elapsedTime := float64(time.Since(requestStartTime)) / float64(time.Millisecond)
192
193	h.semconv.RecordMetrics(ctx, semconv.MetricData{
194		ServerName:           h.server,
195		Req:                  r,
196		StatusCode:           statusCode,
197		AdditionalAttributes: labeler.Get(),
198		RequestSize:          bw.BytesRead(),
199		ResponseSize:         bytesWritten,
200		ElapsedTime:          elapsedTime,
201	})
202}
203
204// WithRouteTag annotates spans and metrics with the provided route name
205// with HTTP route attribute.
206func WithRouteTag(route string, h http.Handler) http.Handler {
207	attr := semconv.NewHTTPServer(nil).Route(route)
208	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
209		span := trace.SpanFromContext(r.Context())
210		span.SetAttributes(attr)
211
212		labeler, _ := LabelerFromContext(r.Context())
213		labeler.Add(attr)
214
215		h.ServeHTTP(w, r)
216	})
217}