chore: small fix

Kujtim Hoxha created

Change summary

internal/llm/agent/agent.go     |  6 ++++
internal/llm/provider/openai.go | 48 ++++++++++++++++++++++------------
2 files changed, 37 insertions(+), 17 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -420,6 +420,12 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 			msgHistory = append(msgHistory, agentMessage, *toolResults)
 			continue
 		}
+		if agentMessage.FinishReason() == "" {
+			// Kujtim: could not track down where this is happening but this means its cancelled
+			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
+			_ = a.messages.Update(context.Background(), agentMessage)
+			return a.err(ErrRequestCancelled)
+		}
 		return AgentEvent{
 			Type:    AgentEventTypeResponse,
 			Message: agentMessage,

internal/llm/provider/openai.go 🔗

@@ -319,6 +319,10 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 	go func() {
 		for {
 			attempts++
+			// Kujtim: fixes an issue with anthropig models on openrouter
+			if len(params.Tools) == 0 {
+				params.Tools = nil
+			}
 			openaiStream := o.client.Chat.Completions.NewStreaming(
 				ctx,
 				params,
@@ -484,31 +488,41 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 }
 
 func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
-	var apiErr *openai.Error
-	if !errors.As(err, &apiErr) {
-		return false, 0, err
-	}
-
 	if attempts > maxRetries {
 		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
 	}
+	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+		return false, 0, err
+	}
+	var apiErr *openai.Error
+	retryMs := 0
+	retryAfterValues := []string{}
+	if errors.As(err, &apiErr) {
+		// Check for token expiration (401 Unauthorized)
+		if apiErr.StatusCode == 401 {
+			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
+			if err != nil {
+				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+			}
+			o.client = createOpenAIClient(o.providerOptions)
+			return true, 0, nil
+		}
 
-	// Check for token expiration (401 Unauthorized)
-	if apiErr.StatusCode == 401 {
-		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
-		if err != nil {
-			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
+			return false, 0, err
 		}
-		o.client = createOpenAIClient(o.providerOptions)
-		return true, 0, nil
-	}
 
-	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
-		return false, 0, err
+		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
 	}
 
-	retryMs := 0
-	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
+	if apiErr != nil {
+		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
+		if len(retryAfterValues) > 0 {
+			slog.Warn("Retry-After header", "values", retryAfterValues)
+		}
+	} else {
+		slog.Warn("OpenAI API error", "error", err.Error())
+	}
 
 	backoffMs := 2000 * (1 << (attempts - 1))
 	jitterMs := int(float64(backoffMs) * 0.2)