Merge pull request #357 from charmbracelet/fix-openrouter

Kujtim Hoxha created

Fix some openrouter models

Change summary

internal/llm/agent/agent.go     |  6 ++
internal/llm/provider/openai.go | 75 ++++++++++++++++++++++++++--------
2 files changed, 62 insertions(+), 19 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -420,6 +420,12 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 			msgHistory = append(msgHistory, agentMessage, *toolResults)
 			continue
 		}
+		if agentMessage.FinishReason() == "" {
+			// Kujtim: could not track down where this is happening but this means its cancelled
+			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
+			_ = a.messages.Update(context.Background(), agentMessage)
+			return a.err(ErrRequestCancelled)
+		}
 		return AgentEvent{
 			Type:    AgentEventTypeResponse,
 			Message: agentMessage,

internal/llm/provider/openai.go 🔗

@@ -319,6 +319,10 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 	go func() {
 		for {
 			attempts++
+			// Kujtim: fixes an issue with anthropig models on openrouter
+			if len(params.Tools) == 0 {
+				params.Tools = nil
+			}
 			openaiStream := o.client.Chat.Completions.NewStreaming(
 				ctx,
 				params,
@@ -331,8 +335,14 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 			var currentToolCallID string
 			var currentToolCall openai.ChatCompletionMessageToolCall
 			var msgToolCalls []openai.ChatCompletionMessageToolCall
+			currentToolIndex := 0
 			for openaiStream.Next() {
 				chunk := openaiStream.Current()
+				// Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index
+				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
+					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
+					currentToolIndex++
+				}
 				acc.AddChunk(chunk)
 				// This fixes multiple tool calls for some providers
 				for _, choice := range chunk.Choices {
@@ -348,6 +358,14 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 						if currentToolCallID == "" {
 							if toolCall.ID != "" {
 								currentToolCallID = toolCall.ID
+								eventChan <- ProviderEvent{
+									Type: EventToolUseStart,
+									ToolCall: &message.ToolCall{
+										ID:       toolCall.ID,
+										Name:     toolCall.Function.Name,
+										Finished: false,
+									},
+								}
 								currentToolCall = openai.ChatCompletionMessageToolCall{
 									ID:   toolCall.ID,
 									Type: "function",
@@ -359,13 +377,21 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 							}
 						} else {
 							// Delta tool use
-							if toolCall.ID == "" {
+							if toolCall.ID == "" || toolCall.ID == currentToolCallID {
 								currentToolCall.Function.Arguments += toolCall.Function.Arguments
 							} else {
 								// Detect new tool use
 								if toolCall.ID != currentToolCallID {
 									msgToolCalls = append(msgToolCalls, currentToolCall)
 									currentToolCallID = toolCall.ID
+									eventChan <- ProviderEvent{
+										Type: EventToolUseStart,
+										ToolCall: &message.ToolCall{
+											ID:       toolCall.ID,
+											Name:     toolCall.Function.Name,
+											Finished: false,
+										},
+									}
 									currentToolCall = openai.ChatCompletionMessageToolCall{
 										ID:   toolCall.ID,
 										Type: "function",
@@ -378,7 +404,8 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 							}
 						}
 					}
-					if choice.FinishReason == "tool_calls" {
+					// Kujtim: some models send finish stop even for tool calls
+					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
 						msgToolCalls = append(msgToolCalls, currentToolCall)
 						if len(acc.Choices) > 0 {
 							acc.Choices[0].Message.ToolCalls = msgToolCalls
@@ -461,31 +488,41 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 }
 
 func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
-	var apiErr *openai.Error
-	if !errors.As(err, &apiErr) {
-		return false, 0, err
-	}
-
 	if attempts > maxRetries {
 		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
 	}
+	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+		return false, 0, err
+	}
+	var apiErr *openai.Error
+	retryMs := 0
+	retryAfterValues := []string{}
+	if errors.As(err, &apiErr) {
+		// Check for token expiration (401 Unauthorized)
+		if apiErr.StatusCode == 401 {
+			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
+			if err != nil {
+				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+			}
+			o.client = createOpenAIClient(o.providerOptions)
+			return true, 0, nil
+		}
 
-	// Check for token expiration (401 Unauthorized)
-	if apiErr.StatusCode == 401 {
-		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
-		if err != nil {
-			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
+			return false, 0, err
 		}
-		o.client = createOpenAIClient(o.providerOptions)
-		return true, 0, nil
-	}
 
-	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
-		return false, 0, err
+		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
 	}
 
-	retryMs := 0
-	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
+	if apiErr != nil {
+		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
+		if len(retryAfterValues) > 0 {
+			slog.Warn("Retry-After header", "values", retryAfterValues)
+		}
+	} else {
+		slog.Warn("OpenAI API error", "error", err.Error())
+	}
 
 	backoffMs := 2000 * (1 << (attempts - 1))
 	jitterMs := int(float64(backoffMs) * 0.2)