diff --git a/llm/llmhttp/llmhttp.go b/llm/llmhttp/llmhttp.go
index 26d8fcf6c2f2cd2800403d80af3f2bd8b144d82b..50fa7f320a90669093916a5a8a7434397546998e 100644
--- a/llm/llmhttp/llmhttp.go
+++ b/llm/llmhttp/llmhttp.go
@@ -88,6 +88,11 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 	// Add conversation ID header if present
 	if conversationID := ConversationIDFromContext(req.Context()); conversationID != "" {
 		req.Header.Set("Shelley-Conversation-Id", conversationID)
+
+		// Add x-session-affinity header for Fireworks to enable prompt caching
+		if ProviderFromContext(req.Context()) == "fireworks" {
+			req.Header.Set("x-session-affinity", conversationID)
+		}
 	}
 
 	// Read and store the request body for recording
diff --git a/llm/llmhttp/llmhttp_test.go b/llm/llmhttp/llmhttp_test.go
index ccad6d3cba92f8d7e29af8b60aa3726a13f0b59d..0d031edd35a36211407bc62ca3426ad73770d9bf 100644
--- a/llm/llmhttp/llmhttp_test.go
+++ b/llm/llmhttp/llmhttp_test.go
@@ -76,6 +76,47 @@ func TestTransportAddsHeaders(t *testing.T) {
 	if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
 		t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
 	}
+
+	// Verify x-session-affinity is NOT added for non-fireworks providers
+	if got := receivedHeaders.Get("x-session-affinity"); got != "" {
+		t.Errorf("x-session-affinity = %q, want empty for non-fireworks", got)
+	}
+}
+
+func TestTransportAddsSessionAffinityForFireworks(t *testing.T) {
+	// Create a test server that echoes request headers
+	var receivedHeaders http.Header
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		receivedHeaders = r.Header.Clone()
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte("ok"))
+	}))
+	defer server.Close()
+
+	// Create client with our transport
+	client := NewClient(nil, nil)
+
+	// Make a request with conversation ID and provider=fireworks in context
+	ctx := context.Background()
+	ctx = WithConversationID(ctx, "test-conv-id")
+	ctx = WithProvider(ctx, "fireworks")
+	req, _ := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
+
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("Request failed: %v", err)
+	}
+	resp.Body.Close()
+
+	// Verify x-session-affinity header was added for fireworks
+	if got := receivedHeaders.Get("x-session-affinity"); got != "test-conv-id" {
+		t.Errorf("x-session-affinity = %q, want %q", got, "test-conv-id")
+	}
+
+	// Verify Shelley-Conversation-Id header was also added
+	if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
+		t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
+	}
 }
 
 func TestTransportRecordsRequest(t *testing.T) {