@@ -88,6 +88,11 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// Add conversation ID header if present
if conversationID := ConversationIDFromContext(req.Context()); conversationID != "" {
req.Header.Set("Shelley-Conversation-Id", conversationID)
+
+ // Add x-session-affinity header for Fireworks to enable prompt caching
+ if ProviderFromContext(req.Context()) == "fireworks" {
+ req.Header.Set("x-session-affinity", conversationID)
+ }
}
// Read and store the request body for recording
@@ -76,6 +76,47 @@ func TestTransportAddsHeaders(t *testing.T) {
if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
}
+
+ // Verify x-session-affinity is NOT added for non-fireworks providers
+ if got := receivedHeaders.Get("x-session-affinity"); got != "" {
+ t.Errorf("x-session-affinity = %q, want empty for non-fireworks", got)
+ }
+}
+
+func TestTransportAddsSessionAffinityForFireworks(t *testing.T) {
+ // Create a test server that echoes request headers
+ var receivedHeaders http.Header
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ receivedHeaders = r.Header.Clone()
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("ok"))
+ }))
+ defer server.Close()
+
+ // Create client with our transport
+ client := NewClient(nil, nil)
+
+ // Make a request with conversation ID and provider=fireworks in context
+ ctx := context.Background()
+ ctx = WithConversationID(ctx, "test-conv-id")
+ ctx = WithProvider(ctx, "fireworks")
+ req, _ := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
+
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatalf("Request failed: %v", err)
+ }
+ resp.Body.Close()
+
+ // Verify x-session-affinity header was added for fireworks
+ if got := receivedHeaders.Get("x-session-affinity"); got != "test-conv-id" {
+ t.Errorf("x-session-affinity = %q, want %q", got, "test-conv-id")
+ }
+
+ // Verify Shelley-Conversation-Id header was also added
+ if got := receivedHeaders.Get("Shelley-Conversation-Id"); got != "test-conv-id" {
+ t.Errorf("Shelley-Conversation-Id = %q, want %q", got, "test-conv-id")
+ }
}
func TestTransportRecordsRequest(t *testing.T) {