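// Package llmhttp provides HTTP utilities for LLM requests: a transport that
// adds Shelley-specific headers and can record each request and response
// through a caller-supplied callback (for example, to a database).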
3package llmhttp
4
5import (
6 "bytes"
7 "context"
8 "io"
9 "net/http"
10 "time"
11
12 "shelley.exe.dev/version"
13)
14
// contextKey is an unexported key type for context values set by this
// package, so they cannot collide with keys defined in other packages.
type contextKey int

// Keys under which this package stores request metadata in a context.
const (
	conversationIDKey contextKey = 0
	modelIDKey        contextKey = 1
	providerKey       contextKey = 2
)
23
24// WithConversationID returns a context with the conversation ID attached.
25func WithConversationID(ctx context.Context, conversationID string) context.Context {
26 return context.WithValue(ctx, conversationIDKey, conversationID)
27}
28
29// ConversationIDFromContext returns the conversation ID from the context, if any.
30func ConversationIDFromContext(ctx context.Context) string {
31 if v := ctx.Value(conversationIDKey); v != nil {
32 return v.(string)
33 }
34 return ""
35}
36
37// WithModelID returns a context with the model ID attached.
38func WithModelID(ctx context.Context, modelID string) context.Context {
39 return context.WithValue(ctx, modelIDKey, modelID)
40}
41
42// ModelIDFromContext returns the model ID from the context, if any.
43func ModelIDFromContext(ctx context.Context) string {
44 if v := ctx.Value(modelIDKey); v != nil {
45 return v.(string)
46 }
47 return ""
48}
49
50// WithProvider returns a context with the provider name attached.
51func WithProvider(ctx context.Context, provider string) context.Context {
52 return context.WithValue(ctx, providerKey, provider)
53}
54
55// ProviderFromContext returns the provider name from the context, if any.
56func ProviderFromContext(ctx context.Context) string {
57 if v := ctx.Value(providerKey); v != nil {
58 return v.(string)
59 }
60 return ""
61}
62
// Recorder is called by Transport.RoundTrip after each LLM HTTP request with
// the request/response details:
//   - ctx is the request's context.
//   - url is the request URL.
//   - requestBody is the buffered request body (nil if the request had no body).
//   - responseBody and statusCode describe the response; they are nil and 0
//     when no response was received.
//   - err is the error, if any, reported for the request.
//   - duration is the time elapsed since RoundTrip began.
//
// It is invoked synchronously, before the response is returned to the
// caller, so slow recorders add latency to every request.
type Recorder func(ctx context.Context, url string, requestBody, responseBody []byte, statusCode int, err error, duration time.Duration)
65
// Transport wraps an http.RoundTripper to add Shelley-specific headers
// (User-Agent, plus conversation headers when the request context carries a
// conversation ID) and optionally record requests, e.g. to a database.
type Transport struct {
	// Base performs the actual requests; http.DefaultTransport is used when nil.
	Base http.RoundTripper

	// Recorder, if non-nil, receives the details of every request. Setting it
	// causes request and response bodies to be read fully into memory.
	Recorder Recorder
}
72
73// RoundTrip implements http.RoundTripper.
74func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
75 start := time.Now()
76
77 // Clone the request to avoid modifying the original
78 req = req.Clone(req.Context())
79
80 // Add User-Agent with Shelley version
81 info := version.GetInfo()
82 userAgent := "Shelley"
83 if info.Commit != "" {
84 userAgent += "/" + info.Commit[:min(8, len(info.Commit))]
85 }
86 req.Header.Set("User-Agent", userAgent)
87
88 // Add conversation ID header if present
89 if conversationID := ConversationIDFromContext(req.Context()); conversationID != "" {
90 req.Header.Set("Shelley-Conversation-Id", conversationID)
91
92 // Add x-session-affinity header for Fireworks to enable prompt caching
93 if ProviderFromContext(req.Context()) == "fireworks" {
94 req.Header.Set("x-session-affinity", conversationID)
95 }
96 }
97
98 // Read and store the request body for recording
99 var requestBody []byte
100 if t.Recorder != nil && req.Body != nil {
101 var err error
102 requestBody, err = io.ReadAll(req.Body)
103 if err != nil {
104 return nil, err
105 }
106 req.Body = io.NopCloser(bytes.NewReader(requestBody))
107 }
108
109 // Perform the actual request
110 base := t.Base
111 if base == nil {
112 base = http.DefaultTransport
113 }
114
115 resp, err := base.RoundTrip(req)
116
117 // Record the request if we have a recorder
118 if t.Recorder != nil {
119 var responseBody []byte
120 var statusCode int
121
122 if resp != nil {
123 statusCode = resp.StatusCode
124 // Read and restore the response body
125 responseBody, _ = io.ReadAll(resp.Body)
126 resp.Body.Close()
127 resp.Body = io.NopCloser(bytes.NewReader(responseBody))
128 }
129
130 t.Recorder(req.Context(), req.URL.String(), requestBody, responseBody, statusCode, err, time.Since(start))
131 }
132
133 return resp, err
134}
135
136// NewClient creates an http.Client with Shelley headers and optional recording.
137func NewClient(base *http.Client, recorder Recorder) *http.Client {
138 if base == nil {
139 base = http.DefaultClient
140 }
141
142 transport := base.Transport
143 if transport == nil {
144 transport = http.DefaultTransport
145 }
146
147 return &http.Client{
148 Transport: &Transport{
149 Base: transport,
150 Recorder: recorder,
151 },
152 Timeout: base.Timeout,
153 }
154}