@@ -378,7 +378,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
return a.messages.Update(ctx, *currentAssistant)
},
OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
- // TODO: implement
+ slog.Warn("Provider request failed, retrying", providerRetryLogFields(err, delay)...)
},
OnToolCall: func(tc fantasy.ToolCallContent) error {
toolCall := message.ToolCall{
@@ -1318,3 +1318,20 @@ func buildSummaryPrompt(todos []session.Todo) string {
}
return sb.String()
}
+
+func providerRetryLogFields(err *fantasy.ProviderError, delay time.Duration) []any {
+ fields := []any{
+ "retry_delay", delay.String(),
+ }
+ if err == nil {
+ return fields
+ }
+ fields = append(fields, "status_code", err.StatusCode)
+ if err.Title != "" {
+ fields = append(fields, "title", err.Title)
+ }
+ if err.Message != "" {
+ fields = append(fields, "message", err.Message)
+ }
+ return fields
+}
@@ -8,6 +8,7 @@ import (
"runtime"
"strings"
"testing"
+ "time"
"charm.land/fantasy"
"charm.land/x/vcr"
@@ -793,3 +794,34 @@ func TestPreparePrompt_OrphanedToolUseMixed(t *testing.T) {
}
require.Equal(t, 1, syntheticCount, "expected exactly one synthetic result for the orphaned call")
}
+
+func TestProviderRetryLogFields(t *testing.T) {
+ t.Run("nil provider error", func(t *testing.T) {
+ fields := providerRetryLogFields(nil, 2*time.Second)
+ require.Equal(t, []any{"retry_delay", "2s"}, fields)
+ })
+
+ t.Run("provider error with title and message", func(t *testing.T) {
+ fields := providerRetryLogFields(&fantasy.ProviderError{
+ StatusCode: 429,
+ Title: "rate limit",
+ Message: "too many requests",
+ }, 1500*time.Millisecond)
+ require.Equal(t, []any{
+ "retry_delay", "1.5s",
+ "status_code", 429,
+ "title", "rate limit",
+ "message", "too many requests",
+ }, fields)
+ })
+
+ t.Run("provider error without optional strings", func(t *testing.T) {
+ fields := providerRetryLogFields(&fantasy.ProviderError{
+ StatusCode: 503,
+ }, time.Second)
+ require.Equal(t, []any{
+ "retry_delay", "1s",
+ "status_code", 503,
+ }, fields)
+ })
+}