fix: some openai providers

Kujtim Hoxha and Peter Steinberger created

Co-authored-by: Peter Steinberger <steipete@gmail.com>

Change summary

internal/llm/provider/openai.go | 98 ++++++++++++++--------------------
1 file changed, 41 insertions(+), 57 deletions(-)

Detailed changes

internal/llm/provider/openai.go 🔗

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"slices"
 	"strings"
 	"time"
 
@@ -14,6 +15,7 @@ import (
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/message"
+	"github.com/google/uuid"
 	"github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
 	"github.com/openai/openai-go/packages/param"
@@ -338,21 +340,16 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 			acc := openai.ChatCompletionAccumulator{}
 			currentContent := ""
 			toolCalls := make([]message.ToolCall, 0)
-
-			var currentToolCallID string
-			var currentToolCall openai.ChatCompletionMessageToolCall
 			var msgToolCalls []openai.ChatCompletionMessageToolCall
-			currentToolIndex := 0
 			for openaiStream.Next() {
 				chunk := openaiStream.Current()
 				// Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index
 				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
-					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
-					currentToolIndex++
+					chunk.Choices[0].Delta.ToolCalls[0].Index = 0
 				}
 				acc.AddChunk(chunk)
 				// This fixes multiple tool calls for some providers
-				for _, choice := range chunk.Choices {
+				for i, choice := range chunk.Choices {
 					if choice.Delta.Content != "" {
 						eventChan <- ProviderEvent{
 							Type:    EventContentDelta,
@@ -361,63 +358,50 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 						currentContent += choice.Delta.Content
 					} else if len(choice.Delta.ToolCalls) > 0 {
 						toolCall := choice.Delta.ToolCalls[0]
-						// Detect tool use start
-						if currentToolCallID == "" {
-							if toolCall.ID != "" {
-								currentToolCallID = toolCall.ID
-								eventChan <- ProviderEvent{
-									Type: EventToolUseStart,
-									ToolCall: &message.ToolCall{
-										ID:       toolCall.ID,
-										Name:     toolCall.Function.Name,
-										Finished: false,
-									},
+						newToolCall := false
+						if len(msgToolCalls)-1 >= int(toolCall.Index) { // tool call exists
+							existingToolCall := msgToolCalls[toolCall.Index]
+							if toolCall.ID != "" && toolCall.ID != existingToolCall.ID {
+								found := false
+								// try to find the tool based on the ID
+								for i, tool := range msgToolCalls {
+									if tool.ID == toolCall.ID {
+										msgToolCalls[i].Function.Arguments += toolCall.Function.Arguments
+										found = true
+									}
 								}
-								currentToolCall = openai.ChatCompletionMessageToolCall{
-									ID:   toolCall.ID,
-									Type: "function",
-									Function: openai.ChatCompletionMessageToolCallFunction{
-										Name:      toolCall.Function.Name,
-										Arguments: toolCall.Function.Arguments,
-									},
+								if !found {
+									newToolCall = true
 								}
-							}
-						} else {
-							// Delta tool use
-							if toolCall.ID == "" || toolCall.ID == currentToolCallID {
-								currentToolCall.Function.Arguments += toolCall.Function.Arguments
 							} else {
-								// Detect new tool use
-								if toolCall.ID != currentToolCallID {
-									msgToolCalls = append(msgToolCalls, currentToolCall)
-									currentToolCallID = toolCall.ID
-									eventChan <- ProviderEvent{
-										Type: EventToolUseStart,
-										ToolCall: &message.ToolCall{
-											ID:       toolCall.ID,
-											Name:     toolCall.Function.Name,
-											Finished: false,
-										},
-									}
-									currentToolCall = openai.ChatCompletionMessageToolCall{
-										ID:   toolCall.ID,
-										Type: "function",
-										Function: openai.ChatCompletionMessageToolCallFunction{
-											Name:      toolCall.Function.Name,
-											Arguments: toolCall.Function.Arguments,
-										},
-									}
-								}
+								msgToolCalls[toolCall.Index].Function.Arguments += toolCall.Function.Arguments
 							}
+						} else {
+							newToolCall = true
 						}
-					}
-					// Kujtim: some models send finish stop even for tool calls
-					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
-						msgToolCalls = append(msgToolCalls, currentToolCall)
-						if len(acc.Choices) > 0 {
-							acc.Choices[0].Message.ToolCalls = msgToolCalls
+						if newToolCall { // new tool call
+							if toolCall.ID == "" {
+								toolCall.ID = uuid.NewString()
+							}
+							eventChan <- ProviderEvent{
+								Type: EventToolUseStart,
+								ToolCall: &message.ToolCall{
+									ID:       toolCall.ID,
+									Name:     toolCall.Function.Name,
+									Finished: false,
+								},
+							}
+							msgToolCalls = append(msgToolCalls, openai.ChatCompletionMessageToolCall{
+								ID:   toolCall.ID,
+								Type: "function",
+								Function: openai.ChatCompletionMessageToolCallFunction{
+									Name:      toolCall.Function.Name,
+									Arguments: toolCall.Function.Arguments,
+								},
+							})
 						}
 					}
+					acc.Choices[i].Message.ToolCalls = slices.Clone(msgToolCalls)
 				}
 			}