1package llmhttp
2
3import (
4 "context"
5 "io"
6 "net/http"
7 "net/http/httptest"
8 "strings"
9 "testing"
10 "time"
11)
12
13func TestContextFunctions(t *testing.T) {
14 ctx := context.Background()
15
16 // Test ConversationID
17 ctx = WithConversationID(ctx, "conv-123")
18 if got := ConversationIDFromContext(ctx); got != "conv-123" {
19 t.Errorf("ConversationIDFromContext() = %q, want %q", got, "conv-123")
20 }
21
22 // Test ModelID
23 ctx = WithModelID(ctx, "model-456")
24 if got := ModelIDFromContext(ctx); got != "model-456" {
25 t.Errorf("ModelIDFromContext() = %q, want %q", got, "model-456")
26 }
27
28 // Test Provider
29 ctx = WithProvider(ctx, "anthropic")
30 if got := ProviderFromContext(ctx); got != "anthropic" {
31 t.Errorf("ProviderFromContext() = %q, want %q", got, "anthropic")
32 }
33
34 // Test empty context
35 emptyCtx := context.Background()
36 if got := ConversationIDFromContext(emptyCtx); got != "" {
37 t.Errorf("ConversationIDFromContext(empty) = %q, want empty", got)
38 }
39 if got := ModelIDFromContext(emptyCtx); got != "" {
40 t.Errorf("ModelIDFromContext(empty) = %q, want empty", got)
41 }
42 if got := ProviderFromContext(emptyCtx); got != "" {
43 t.Errorf("ProviderFromContext(empty) = %q, want empty", got)
44 }
45}
46
47func TestTransportAddsHeaders(t *testing.T) {
48 // Create a test server that echoes request headers
49 var receivedHeaders http.Header
50 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51 receivedHeaders = r.Header.Clone()
52 w.WriteHeader(http.StatusOK)
53 w.Write([]byte("ok"))
54 }))
55 defer server.Close()
56
57 // Create client with our transport
58 client := NewClient(nil, nil)
59
60 // Make a request with conversation ID in context
61 ctx := WithConversationID(context.Background(), "test-conv-id")
62 req, _ := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
63
64 resp, err := client.Do(req)
65 if err != nil {
66 t.Fatalf("Request failed: %v", err)
67 }
68 resp.Body.Close()
69
70 // Verify User-Agent header was added
71 if !strings.HasPrefix(receivedHeaders.Get("User-Agent"), "Shelley") {
72 t.Errorf("User-Agent = %q, want prefix 'Shelley'", receivedHeaders.Get("User-Agent"))
73 }
74
75 // Verify Shelley-Conversation-Id header was added
76 if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
77 t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
78 }
79
80 // Verify x-session-affinity is NOT added for non-fireworks providers
81 if got := receivedHeaders.Get("x-session-affinity"); got != "" {
82 t.Errorf("x-session-affinity = %q, want empty for non-fireworks", got)
83 }
84}
85
86func TestTransportAddsSessionAffinityForFireworks(t *testing.T) {
87 // Create a test server that echoes request headers
88 var receivedHeaders http.Header
89 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
90 receivedHeaders = r.Header.Clone()
91 w.WriteHeader(http.StatusOK)
92 w.Write([]byte("ok"))
93 }))
94 defer server.Close()
95
96 // Create client with our transport
97 client := NewClient(nil, nil)
98
99 // Make a request with conversation ID and provider=fireworks in context
100 ctx := context.Background()
101 ctx = WithConversationID(ctx, "test-conv-id")
102 ctx = WithProvider(ctx, "fireworks")
103 req, _ := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
104
105 resp, err := client.Do(req)
106 if err != nil {
107 t.Fatalf("Request failed: %v", err)
108 }
109 resp.Body.Close()
110
111 // Verify x-session-affinity header was added for fireworks
112 if got := receivedHeaders.Get("x-session-affinity"); got != "test-conv-id" {
113 t.Errorf("x-session-affinity = %q, want %q", got, "test-conv-id")
114 }
115
116 // Verify Shelley-Conversation-Id header was also added
117 if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
118 t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
119 }
120}
121
122func TestTransportRecordsRequest(t *testing.T) {
123 // Create a test server
124 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125 body, _ := io.ReadAll(r.Body)
126 w.WriteHeader(http.StatusOK)
127 w.Write([]byte("response body: " + string(body)))
128 }))
129 defer server.Close()
130
131 // Track recorded values
132 var (
133 recordedURL string
134 recordedRequestBody []byte
135 recordedRespBody []byte
136 recordedStatusCode int
137 recordedDuration time.Duration
138 recorderCalled bool
139 )
140
141 recorder := func(ctx context.Context, url string, requestBody, responseBody []byte, statusCode int, err error, duration time.Duration) {
142 recorderCalled = true
143 recordedURL = url
144 recordedRequestBody = requestBody
145 recordedRespBody = responseBody
146 recordedStatusCode = statusCode
147 recordedDuration = duration
148 }
149
150 // Create client with recorder
151 client := NewClient(nil, recorder)
152
153 // Make a request with body
154 req, _ := http.NewRequest("POST", server.URL, strings.NewReader("test body"))
155 resp, err := client.Do(req)
156 if err != nil {
157 t.Fatalf("Request failed: %v", err)
158 }
159
160 // Read response body to ensure it's still accessible
161 respBody, _ := io.ReadAll(resp.Body)
162 resp.Body.Close()
163
164 if string(respBody) != "response body: test body" {
165 t.Errorf("Response body = %q, want %q", string(respBody), "response body: test body")
166 }
167
168 // Verify recorder was called with correct values
169 if !recorderCalled {
170 t.Fatal("Recorder was not called")
171 }
172
173 if recordedURL != server.URL {
174 t.Errorf("Recorded URL = %q, want %q", recordedURL, server.URL)
175 }
176
177 if string(recordedRequestBody) != "test body" {
178 t.Errorf("Recorded request body = %q, want %q", string(recordedRequestBody), "test body")
179 }
180
181 if string(recordedRespBody) != "response body: test body" {
182 t.Errorf("Recorded response body = %q, want %q", string(recordedRespBody), "response body: test body")
183 }
184
185 if recordedStatusCode != http.StatusOK {
186 t.Errorf("Recorded status code = %d, want %d", recordedStatusCode, http.StatusOK)
187 }
188
189 if recordedDuration <= 0 {
190 t.Error("Recorded duration should be positive")
191 }
192}
193
194func TestTransportWithoutRecorder(t *testing.T) {
195 // Create a test server
196 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
197 w.WriteHeader(http.StatusOK)
198 w.Write([]byte("ok"))
199 }))
200 defer server.Close()
201
202 // Create client without recorder
203 client := NewClient(nil, nil)
204
205 // Make a request
206 req, _ := http.NewRequest("GET", server.URL, nil)
207 resp, err := client.Do(req)
208 if err != nil {
209 t.Fatalf("Request failed: %v", err)
210 }
211 resp.Body.Close()
212
213 if resp.StatusCode != http.StatusOK {
214 t.Errorf("Status code = %d, want %d", resp.StatusCode, http.StatusOK)
215 }
216}