diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 7075ddcc4dd8bceb14e8fa6837d2df391e9a1298..70bbe128663ce6163a93a2eb172e6d23f5873af3 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log/slog" + "slices" "strings" "time" @@ -14,6 +15,7 @@ import ( "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/message" + "github.com/google/uuid" "github.com/openai/openai-go" "github.com/openai/openai-go/option" "github.com/openai/openai-go/packages/param" @@ -338,21 +340,16 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t acc := openai.ChatCompletionAccumulator{} currentContent := "" toolCalls := make([]message.ToolCall, 0) - - var currentToolCallID string - var currentToolCall openai.ChatCompletionMessageToolCall var msgToolCalls []openai.ChatCompletionMessageToolCall - currentToolIndex := 0 for openaiStream.Next() { chunk := openaiStream.Current() // Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 { - chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex) - currentToolIndex++ + chunk.Choices[0].Delta.ToolCalls[0].Index = 0 } acc.AddChunk(chunk) // This fixes multiple tool calls for some providers - for _, choice := range chunk.Choices { + for i, choice := range chunk.Choices { if choice.Delta.Content != "" { eventChan <- ProviderEvent{ Type: EventContentDelta, @@ -361,63 +358,50 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t currentContent += choice.Delta.Content } else if len(choice.Delta.ToolCalls) > 0 { toolCall := choice.Delta.ToolCalls[0] - // Detect tool use start - if currentToolCallID == "" { - if toolCall.ID != "" { - currentToolCallID = toolCall.ID - eventChan <- ProviderEvent{ - Type: EventToolUseStart, - ToolCall: &message.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Function.Name, - Finished: false, - }, + newToolCall := false + if len(msgToolCalls)-1 >= int(toolCall.Index) { // tool call exists + existingToolCall := msgToolCalls[toolCall.Index] + if toolCall.ID != "" && toolCall.ID != existingToolCall.ID { + found := false + // try to find the tool based on the ID + for i, tool := range msgToolCalls { + if tool.ID == toolCall.ID { + msgToolCalls[i].Function.Arguments += toolCall.Function.Arguments + found = true + } } - currentToolCall = openai.ChatCompletionMessageToolCall{ - ID: toolCall.ID, - Type: "function", - Function: openai.ChatCompletionMessageToolCallFunction{ - Name: toolCall.Function.Name, - Arguments: toolCall.Function.Arguments, - }, + if !found { + newToolCall = true } - } - } else { - // Delta tool use - if toolCall.ID == "" || toolCall.ID == currentToolCallID { - currentToolCall.Function.Arguments += toolCall.Function.Arguments } else { - // Detect new tool use - if toolCall.ID != currentToolCallID { - msgToolCalls = append(msgToolCalls, currentToolCall) - currentToolCallID = toolCall.ID - eventChan <- ProviderEvent{ - Type: EventToolUseStart, - ToolCall: &message.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Function.Name, - Finished: false, - }, - } - currentToolCall = openai.ChatCompletionMessageToolCall{ - ID: toolCall.ID, - Type: "function", - Function: openai.ChatCompletionMessageToolCallFunction{ - Name: toolCall.Function.Name, - Arguments: toolCall.Function.Arguments, - }, - } - } + msgToolCalls[toolCall.Index].Function.Arguments += toolCall.Function.Arguments } + } else { + newToolCall = true } - } - // Kujtim: some models send finish stop even for tool calls - if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") { - msgToolCalls = append(msgToolCalls, currentToolCall) - if len(acc.Choices) > 0 { - acc.Choices[0].Message.ToolCalls = msgToolCalls + if newToolCall { // new tool call + if toolCall.ID == "" { + toolCall.ID = uuid.NewString() + } + eventChan <- ProviderEvent{ + Type: EventToolUseStart, + ToolCall: &message.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Function.Name, + Finished: false, + }, + } + msgToolCalls = append(msgToolCalls, openai.ChatCompletionMessageToolCall{ + ID: toolCall.ID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunction{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + }) } } + acc.Choices[i].Message.ToolCalls = slices.Clone(msgToolCalls) } }