wip: poc retries

Kujtim Hoxha created

Change summary

internal/llm/agent/agent.go                       | 16 +++
internal/llm/provider/openai.go                   | 13 +-
internal/llm/provider/provider.go                 |  3 
internal/message/content.go                       | 82 +++++++++++++++++
internal/message/message.go                       |  9 +
internal/tui/components/chat/messages/messages.go | 75 +++++++++++++++
6 files changed, 190 insertions(+), 8 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -634,7 +634,23 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
 		assistantMsg.FinishToolCall(event.ToolCall.ID)
 		return a.messages.Update(ctx, *assistantMsg)
 	case provider.EventError:
+		assistantMsg.SetRetrying(false)
+		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
+			return fmt.Errorf("failed to update message: %w", err)
+		}
 		return event.Error
+	case provider.EventRetry:
+		errMsg := ""
+		if event.Error != nil {
+			errMsg = event.Error.Error()
+		}
+		assistantMsg.SetRetrying(false)
+		assistantMsg.AddRetry(errMsg, event.Retry)
+		return a.messages.Update(ctx, *assistantMsg)
+	case provider.EventRetrying:
+		assistantMsg.SetRetrying(true)
+		return a.messages.Update(ctx, *assistantMsg)
+
 	case provider.EventComplete:
 		assistantMsg.FinishThinking()
 		assistantMsg.SetToolCalls(event.Response.ToolCalls)

internal/llm/provider/openai.go 🔗

@@ -461,13 +461,11 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 
 			// If there is an error we are going to see if we can retry the call
 			retry, after, retryErr := o.shouldRetry(attempts, err)
-			if retryErr != nil {
-				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
-				close(eventChan)
-				return
-			}
 			if retry {
-				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
+				if retryErr == nil {
+					slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
+				}
+				eventChan <- ProviderEvent{Type: EventRetry, Error: retryErr, Retry: after}
 				select {
 				case <-ctx.Done():
 					// context cancelled
@@ -477,6 +475,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 					close(eventChan)
 					return
 				case <-time.After(time.Duration(after) * time.Millisecond):
+					eventChan <- ProviderEvent{Type: EventRetry, Error: retryErr}
 					continue
 				}
 			}
@@ -534,7 +533,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
 			retryMs = retryMs * 1000
 		}
 	}
-	return true, int64(retryMs), nil
+	return true, int64(retryMs), err
 }
 
 func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {

internal/llm/provider/provider.go 🔗

@@ -26,6 +26,8 @@ const (
 	EventComplete       EventType = "complete"
 	EventError          EventType = "error"
 	EventWarning        EventType = "warning"
+	EventRetry          EventType = "retry"
+	EventRetrying       EventType = "retrying"
 )
 
 type TokenUsage struct {
@@ -50,6 +52,7 @@ type ProviderEvent struct {
 	Signature string
 	Response  *ProviderResponse
 	ToolCall  *message.ToolCall
+	Retry     int64
 	Error     error
 }
 type Provider interface {

internal/message/content.go 🔗

@@ -113,6 +113,19 @@ type Finish struct {
 
 func (Finish) isPart() {}
 
+type Retry struct {
+	Error      string `json:"error"`
+	RetryAfter int64  `json:"retry_after"`
+	Timestamp  int64  `json:"timestamp"`
+}
+
+type RetryContent struct {
+	Retries  []Retry `json:"retries"`
+	Retrying bool    `json:"retrying"`
+}
+
+func (RetryContent) isPart() {}
+
 type Message struct {
 	ID        string
 	Role      MessageRole
@@ -384,3 +397,72 @@ func (m *Message) AddImageURL(url, detail string) {
 func (m *Message) AddBinary(mimeType string, data []byte) {
 	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
 }
+
+func (m *Message) RetryContent() *RetryContent {
+	for _, part := range m.Parts {
+		if c, ok := part.(RetryContent); ok {
+			return &c
+		}
+	}
+	return nil
+}
+
+func (m *Message) AddRetry(error string, retryAfter int64) {
+	retry := Retry{
+		Error:      error,
+		RetryAfter: retryAfter,
+		Timestamp:  time.Now().Unix(),
+	}
+
+	found := false
+	for i, part := range m.Parts {
+		if c, ok := part.(RetryContent); ok {
+			m.Parts[i] = RetryContent{
+				Retries:  append(c.Retries, retry),
+				Retrying: c.Retrying,
+			}
+			found = true
+			break
+		}
+	}
+	if !found {
+		m.Parts = append(m.Parts, RetryContent{
+			Retries:  []Retry{retry},
+			Retrying: false,
+		})
+	}
+}
+
+func (m *Message) SetRetrying(retrying bool) {
+	found := false
+	for i, part := range m.Parts {
+		if c, ok := part.(RetryContent); ok {
+			m.Parts[i] = RetryContent{
+				Retries:  c.Retries,
+				Retrying: retrying,
+			}
+			found = true
+			break
+		}
+	}
+	if !found && retrying {
+		m.Parts = append(m.Parts, RetryContent{
+			Retries:  []Retry{},
+			Retrying: retrying,
+		})
+	}
+}
+
+func (m *Message) IsRetrying() bool {
+	if retryContent := m.RetryContent(); retryContent != nil {
+		return retryContent.Retrying
+	}
+	return false
+}
+
+func (m *Message) GetRetries() []Retry {
+	if retryContent := m.RetryContent(); retryContent != nil {
+		return retryContent.Retries
+	}
+	return []Retry{}
+}

internal/message/message.go 🔗

@@ -172,6 +172,7 @@ const (
 	toolCallType   partType = "tool_call"
 	toolResultType partType = "tool_result"
 	finishType     partType = "finish"
+	retryType      partType = "retry"
 )
 
 type partWrapper struct {
@@ -200,6 +201,8 @@ func marshallParts(parts []ContentPart) ([]byte, error) {
 			typ = toolResultType
 		case Finish:
 			typ = finishType
+		case RetryContent:
+			typ = retryType
 		default:
 			return nil, fmt.Errorf("unknown part type: %T", part)
 		}
@@ -273,6 +276,12 @@ func unmarshallParts(data []byte) ([]ContentPart, error) {
 				return nil, err
 			}
 			parts = append(parts, part)
+		case retryType:
+			part := RetryContent{}
+			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+				return nil, err
+			}
+			parts = append(parts, part)
 		default:
 			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
 		}

internal/tui/components/chat/messages/messages.go 🔗

@@ -175,6 +175,7 @@ func (m *messageCmp) renderAssistantMessage() string {
 	finished := m.message.IsFinished()
 	finishedData := m.message.FinishPart()
 	thinkingContent := ""
+	retryContent := m.renderRetryContent()
 
 	if thinking || m.message.ReasoningContent().Thinking != "" {
 		m.anim.SetLabel("Thinking")
@@ -192,12 +193,19 @@ func (m *messageCmp) renderAssistantMessage() string {
 		return m.style().Render(errorContent)
 	}
 
+	if retryContent != "" {
+		parts = append(parts, retryContent)
+	}
+
 	if thinkingContent != "" {
+		if retryContent != "" {
+			parts = append(parts, "")
+		}
 		parts = append(parts, thinkingContent)
 	}
 
 	if content != "" {
-		if thinkingContent != "" {
+		if thinkingContent != "" || retryContent != "" {
 			parts = append(parts, "")
 		}
 		parts = append(parts, m.toMarkdown(content))
@@ -290,8 +298,68 @@ func (m *messageCmp) renderThinkingContent() string {
 	return lineStyle.Width(m.textWidth()).Padding(0, 1).Render(m.thinkingViewport.View()) + "\n\n" + footer
 }
 
+func (m *messageCmp) renderRetryContent() string {
+	t := styles.CurrentTheme()
+	retryContent := m.message.RetryContent()
+	if retryContent == nil || len(retryContent.Retries) == 0 {
+		return ""
+	}
+
+	// Get the latest retry for the main display
+	latestRetry := retryContent.Retries[len(retryContent.Retries)-1]
+
+	var title string
+	var details string
+	retryDuration := time.Duration(latestRetry.RetryAfter) * time.Millisecond
+
+	if strings.Contains(latestRetry.Error, "426") || strings.Contains(strings.ToLower(latestRetry.Error), "rate limited") {
+		// Rate limit retry
+		warningTag := t.S().Base.Padding(0, 1).Background(t.Warning).Foreground(t.BgBase).Render("RATE LIMITED")
+		retryMsg := fmt.Sprintf("Retrying after %s", retryDuration.String())
+		title = fmt.Sprintf("%s %s", warningTag, t.S().Base.Foreground(t.FgHalfMuted).Render(retryMsg))
+	} else {
+		// Error retry
+		warningTag := t.S().Base.Padding(0, 1).Background(t.Warning).Foreground(t.BgBase).Render("RETRYING")
+		truncated := ansi.Truncate(latestRetry.Error, m.textWidth()-2-lipgloss.Width(warningTag), "...")
+		title = fmt.Sprintf("%s %s", warningTag, t.S().Base.Foreground(t.FgHalfMuted).Render(truncated))
+	}
+
+	// Show retry history as details
+	if len(retryContent.Retries) > 1 {
+		var retryHistory []string
+
+		for i, retry := range retryContent.Retries {
+			timestamp := time.Unix(retry.Timestamp, 0).Format("15:04:05")
+			retryDuration := time.Duration(retry.RetryAfter) * time.Millisecond
+			var retryMsg string
+			if retry.Error == "" {
+				retryMsg = fmt.Sprintf("Rate limited, retrying after %s", retryDuration.String())
+			} else {
+				retryMsg = fmt.Sprintf("Error: %s (retry after %s)", retry.Error, retryDuration.String())
+			}
+			retryHistory = append(retryHistory, fmt.Sprintf("Attempt %d (%s): %s", i+1, timestamp, retryMsg))
+		}
+		details = strings.Join(retryHistory, "\n")
+	} else {
+		// Single retry, show timestamp
+		timestamp := time.Unix(latestRetry.Timestamp, 0).Format("15:04:05")
+		details = fmt.Sprintf("First attempt at %s", timestamp)
+	}
+
+	// Add current status if actively retrying
+	if retryContent.Retrying {
+		details += "\n\nCurrently retrying..."
+	}
+
+	detailsFormatted := t.S().Base.Foreground(t.FgSubtle).Width(m.textWidth() - 2).Render(details)
+	retryDisplay := fmt.Sprintf("%s\n\n%s", title, detailsFormatted)
+
+	return retryDisplay
+}
+
 // shouldSpin determines whether the message should show a loading animation.
 // Only assistant messages without content that aren't finished should spin.
+// Also considers retry state - only spins when actively retrying.
 func (m *messageCmp) shouldSpin() bool {
 	if m.message.Role != message.Assistant {
 		return false
@@ -301,6 +369,11 @@ func (m *messageCmp) shouldSpin() bool {
 		return false
 	}
 
+	// Check retry state - only spin if actively retrying
+	if retryContent := m.message.RetryContent(); retryContent != nil {
+		return retryContent.Retrying
+	}
+
 	if m.message.Content().Text != "" {
 		return false
 	}