llmhttp.go

  1// Package llmhttp provides HTTP utilities for LLM requests including
  2// custom headers and database recording.
  3package llmhttp
  4
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"time"

	"shelley.exe.dev/version"
)
 14
 15// contextKey is the type for context keys in this package.
 16type contextKey int
 17
 18const (
 19	conversationIDKey contextKey = iota
 20	modelIDKey
 21	providerKey
 22)
 23
 24// WithConversationID returns a context with the conversation ID attached.
 25func WithConversationID(ctx context.Context, conversationID string) context.Context {
 26	return context.WithValue(ctx, conversationIDKey, conversationID)
 27}
 28
 29// ConversationIDFromContext returns the conversation ID from the context, if any.
 30func ConversationIDFromContext(ctx context.Context) string {
 31	if v := ctx.Value(conversationIDKey); v != nil {
 32		return v.(string)
 33	}
 34	return ""
 35}
 36
 37// WithModelID returns a context with the model ID attached.
 38func WithModelID(ctx context.Context, modelID string) context.Context {
 39	return context.WithValue(ctx, modelIDKey, modelID)
 40}
 41
 42// ModelIDFromContext returns the model ID from the context, if any.
 43func ModelIDFromContext(ctx context.Context) string {
 44	if v := ctx.Value(modelIDKey); v != nil {
 45		return v.(string)
 46	}
 47	return ""
 48}
 49
 50// WithProvider returns a context with the provider name attached.
 51func WithProvider(ctx context.Context, provider string) context.Context {
 52	return context.WithValue(ctx, providerKey, provider)
 53}
 54
 55// ProviderFromContext returns the provider name from the context, if any.
 56func ProviderFromContext(ctx context.Context) string {
 57	if v := ctx.Value(providerKey); v != nil {
 58		return v.(string)
 59	}
 60	return ""
 61}
 62
 63// Recorder is called after each LLM HTTP request with the request/response details.
 64type Recorder func(ctx context.Context, url string, requestBody, responseBody []byte, statusCode int, err error, duration time.Duration)
 65
 66// Transport wraps an http.RoundTripper to add Shelley-specific headers
 67// and optionally record requests to a database.
 68type Transport struct {
 69	Base     http.RoundTripper
 70	Recorder Recorder
 71}
 72
 73// RoundTrip implements http.RoundTripper.
 74func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 75	start := time.Now()
 76
 77	// Clone the request to avoid modifying the original
 78	req = req.Clone(req.Context())
 79
 80	// Add User-Agent with Shelley version
 81	info := version.GetInfo()
 82	userAgent := "Shelley"
 83	if info.Commit != "" {
 84		userAgent += "/" + info.Commit[:min(8, len(info.Commit))]
 85	}
 86	req.Header.Set("User-Agent", userAgent)
 87
 88	// Add conversation ID header if present
 89	if conversationID := ConversationIDFromContext(req.Context()); conversationID != "" {
 90		req.Header.Set("Shelley-Conversation-Id", conversationID)
 91
 92		// Add x-session-affinity header for Fireworks to enable prompt caching
 93		if ProviderFromContext(req.Context()) == "fireworks" {
 94			req.Header.Set("x-session-affinity", conversationID)
 95		}
 96	}
 97
 98	// Read and store the request body for recording
 99	var requestBody []byte
100	if t.Recorder != nil && req.Body != nil {
101		var err error
102		requestBody, err = io.ReadAll(req.Body)
103		if err != nil {
104			return nil, err
105		}
106		req.Body = io.NopCloser(bytes.NewReader(requestBody))
107	}
108
109	// Perform the actual request
110	base := t.Base
111	if base == nil {
112		base = http.DefaultTransport
113	}
114
115	resp, err := base.RoundTrip(req)
116
117	// Record the request if we have a recorder
118	if t.Recorder != nil {
119		var responseBody []byte
120		var statusCode int
121
122		if resp != nil {
123			statusCode = resp.StatusCode
124			// Read and restore the response body
125			responseBody, _ = io.ReadAll(resp.Body)
126			resp.Body.Close()
127			resp.Body = io.NopCloser(bytes.NewReader(responseBody))
128		}
129
130		t.Recorder(req.Context(), req.URL.String(), requestBody, responseBody, statusCode, err, time.Since(start))
131	}
132
133	return resp, err
134}
135
136// NewClient creates an http.Client with Shelley headers and optional recording.
137func NewClient(base *http.Client, recorder Recorder) *http.Client {
138	if base == nil {
139		base = http.DefaultClient
140	}
141
142	transport := base.Transport
143	if transport == nil {
144		transport = http.DefaultTransport
145	}
146
147	return &http.Client{
148		Transport: &Transport{
149			Base:     transport,
150			Recorder: recorder,
151		},
152		Timeout: base.Timeout,
153	}
154}