llmhttp_test.go

  1package llmhttp
  2
  3import (
  4	"context"
  5	"io"
  6	"net/http"
  7	"net/http/httptest"
  8	"strings"
  9	"testing"
 10	"time"
 11)
 12
 13func TestContextFunctions(t *testing.T) {
 14	ctx := context.Background()
 15
 16	// Test ConversationID
 17	ctx = WithConversationID(ctx, "conv-123")
 18	if got := ConversationIDFromContext(ctx); got != "conv-123" {
 19		t.Errorf("ConversationIDFromContext() = %q, want %q", got, "conv-123")
 20	}
 21
 22	// Test ModelID
 23	ctx = WithModelID(ctx, "model-456")
 24	if got := ModelIDFromContext(ctx); got != "model-456" {
 25		t.Errorf("ModelIDFromContext() = %q, want %q", got, "model-456")
 26	}
 27
 28	// Test Provider
 29	ctx = WithProvider(ctx, "anthropic")
 30	if got := ProviderFromContext(ctx); got != "anthropic" {
 31		t.Errorf("ProviderFromContext() = %q, want %q", got, "anthropic")
 32	}
 33
 34	// Test empty context
 35	emptyCtx := context.Background()
 36	if got := ConversationIDFromContext(emptyCtx); got != "" {
 37		t.Errorf("ConversationIDFromContext(empty) = %q, want empty", got)
 38	}
 39	if got := ModelIDFromContext(emptyCtx); got != "" {
 40		t.Errorf("ModelIDFromContext(empty) = %q, want empty", got)
 41	}
 42	if got := ProviderFromContext(emptyCtx); got != "" {
 43		t.Errorf("ProviderFromContext(empty) = %q, want empty", got)
 44	}
 45}
 46
 47func TestTransportAddsHeaders(t *testing.T) {
 48	// Create a test server that echoes request headers
 49	var receivedHeaders http.Header
 50	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 51		receivedHeaders = r.Header.Clone()
 52		w.WriteHeader(http.StatusOK)
 53		w.Write([]byte("ok"))
 54	}))
 55	defer server.Close()
 56
 57	// Create client with our transport
 58	client := NewClient(nil, nil)
 59
 60	// Make a request with conversation ID in context
 61	ctx := WithConversationID(context.Background(), "test-conv-id")
 62	req, _ := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
 63
 64	resp, err := client.Do(req)
 65	if err != nil {
 66		t.Fatalf("Request failed: %v", err)
 67	}
 68	resp.Body.Close()
 69
 70	// Verify User-Agent header was added
 71	if !strings.HasPrefix(receivedHeaders.Get("User-Agent"), "Shelley") {
 72		t.Errorf("User-Agent = %q, want prefix 'Shelley'", receivedHeaders.Get("User-Agent"))
 73	}
 74
 75	// Verify Shelley-Conversation-Id header was added
 76	if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
 77		t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
 78	}
 79
 80	// Verify x-session-affinity is NOT added for non-fireworks providers
 81	if got := receivedHeaders.Get("x-session-affinity"); got != "" {
 82		t.Errorf("x-session-affinity = %q, want empty for non-fireworks", got)
 83	}
 84}
 85
 86func TestTransportAddsSessionAffinityForFireworks(t *testing.T) {
 87	// Create a test server that echoes request headers
 88	var receivedHeaders http.Header
 89	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 90		receivedHeaders = r.Header.Clone()
 91		w.WriteHeader(http.StatusOK)
 92		w.Write([]byte("ok"))
 93	}))
 94	defer server.Close()
 95
 96	// Create client with our transport
 97	client := NewClient(nil, nil)
 98
 99	// Make a request with conversation ID and provider=fireworks in context
100	ctx := context.Background()
101	ctx = WithConversationID(ctx, "test-conv-id")
102	ctx = WithProvider(ctx, "fireworks")
103	req, _ := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
104
105	resp, err := client.Do(req)
106	if err != nil {
107		t.Fatalf("Request failed: %v", err)
108	}
109	resp.Body.Close()
110
111	// Verify x-session-affinity header was added for fireworks
112	if got := receivedHeaders.Get("x-session-affinity"); got != "test-conv-id" {
113		t.Errorf("x-session-affinity = %q, want %q", got, "test-conv-id")
114	}
115
116	// Verify Shelley-Conversation-Id header was also added
117	if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
118		t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
119	}
120}
121
122func TestTransportRecordsRequest(t *testing.T) {
123	// Create a test server
124	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125		body, _ := io.ReadAll(r.Body)
126		w.WriteHeader(http.StatusOK)
127		w.Write([]byte("response body: " + string(body)))
128	}))
129	defer server.Close()
130
131	// Track recorded values
132	var (
133		recordedURL         string
134		recordedRequestBody []byte
135		recordedRespBody    []byte
136		recordedStatusCode  int
137		recordedDuration    time.Duration
138		recorderCalled      bool
139	)
140
141	recorder := func(ctx context.Context, url string, requestBody, responseBody []byte, statusCode int, err error, duration time.Duration) {
142		recorderCalled = true
143		recordedURL = url
144		recordedRequestBody = requestBody
145		recordedRespBody = responseBody
146		recordedStatusCode = statusCode
147		recordedDuration = duration
148	}
149
150	// Create client with recorder
151	client := NewClient(nil, recorder)
152
153	// Make a request with body
154	req, _ := http.NewRequest("POST", server.URL, strings.NewReader("test body"))
155	resp, err := client.Do(req)
156	if err != nil {
157		t.Fatalf("Request failed: %v", err)
158	}
159
160	// Read response body to ensure it's still accessible
161	respBody, _ := io.ReadAll(resp.Body)
162	resp.Body.Close()
163
164	if string(respBody) != "response body: test body" {
165		t.Errorf("Response body = %q, want %q", string(respBody), "response body: test body")
166	}
167
168	// Verify recorder was called with correct values
169	if !recorderCalled {
170		t.Fatal("Recorder was not called")
171	}
172
173	if recordedURL != server.URL {
174		t.Errorf("Recorded URL = %q, want %q", recordedURL, server.URL)
175	}
176
177	if string(recordedRequestBody) != "test body" {
178		t.Errorf("Recorded request body = %q, want %q", string(recordedRequestBody), "test body")
179	}
180
181	if string(recordedRespBody) != "response body: test body" {
182		t.Errorf("Recorded response body = %q, want %q", string(recordedRespBody), "response body: test body")
183	}
184
185	if recordedStatusCode != http.StatusOK {
186		t.Errorf("Recorded status code = %d, want %d", recordedStatusCode, http.StatusOK)
187	}
188
189	if recordedDuration <= 0 {
190		t.Error("Recorded duration should be positive")
191	}
192}
193
194func TestTransportWithoutRecorder(t *testing.T) {
195	// Create a test server
196	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
197		w.WriteHeader(http.StatusOK)
198		w.Write([]byte("ok"))
199	}))
200	defer server.Close()
201
202	// Create client without recorder
203	client := NewClient(nil, nil)
204
205	// Make a request
206	req, _ := http.NewRequest("GET", server.URL, nil)
207	resp, err := client.Do(req)
208	if err != nil {
209		t.Fatalf("Request failed: %v", err)
210	}
211	resp.Body.Close()
212
213	if resp.StatusCode != http.StatusOK {
214		t.Errorf("Status code = %d, want %d", resp.StatusCode, http.StatusOK)
215	}
216}