@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"log/slog"
+ "slices"
"strings"
"time"
@@ -14,6 +15,7 @@ import (
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/message"
+ "github.com/google/uuid"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
@@ -338,21 +340,16 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
acc := openai.ChatCompletionAccumulator{}
currentContent := ""
toolCalls := make([]message.ToolCall, 0)
-
- var currentToolCallID string
- var currentToolCall openai.ChatCompletionMessageToolCall
var msgToolCalls []openai.ChatCompletionMessageToolCall
- currentToolIndex := 0
for openaiStream.Next() {
chunk := openaiStream.Current()
// Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index
if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
- chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
- currentToolIndex++
+ chunk.Choices[0].Delta.ToolCalls[0].Index = 0
}
acc.AddChunk(chunk)
// This fixes multiple tool calls for some providers
- for _, choice := range chunk.Choices {
+ for i, choice := range chunk.Choices {
if choice.Delta.Content != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
@@ -361,63 +358,50 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
currentContent += choice.Delta.Content
} else if len(choice.Delta.ToolCalls) > 0 {
toolCall := choice.Delta.ToolCalls[0]
- // Detect tool use start
- if currentToolCallID == "" {
- if toolCall.ID != "" {
- currentToolCallID = toolCall.ID
- eventChan <- ProviderEvent{
- Type: EventToolUseStart,
- ToolCall: &message.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Function.Name,
- Finished: false,
- },
+ newToolCall := false
+ if len(msgToolCalls)-1 >= int(toolCall.Index) { // tool call exists
+ existingToolCall := msgToolCalls[toolCall.Index]
+ if toolCall.ID != "" && toolCall.ID != existingToolCall.ID {
+ found := false
+ // try to find the tool based on the ID
+ for i, tool := range msgToolCalls {
+ if tool.ID == toolCall.ID {
+ msgToolCalls[i].Function.Arguments += toolCall.Function.Arguments
+ found = true
+ }
}
- currentToolCall = openai.ChatCompletionMessageToolCall{
- ID: toolCall.ID,
- Type: "function",
- Function: openai.ChatCompletionMessageToolCallFunction{
- Name: toolCall.Function.Name,
- Arguments: toolCall.Function.Arguments,
- },
+ if !found {
+ newToolCall = true
}
- }
- } else {
- // Delta tool use
- if toolCall.ID == "" || toolCall.ID == currentToolCallID {
- currentToolCall.Function.Arguments += toolCall.Function.Arguments
} else {
- // Detect new tool use
- if toolCall.ID != currentToolCallID {
- msgToolCalls = append(msgToolCalls, currentToolCall)
- currentToolCallID = toolCall.ID
- eventChan <- ProviderEvent{
- Type: EventToolUseStart,
- ToolCall: &message.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Function.Name,
- Finished: false,
- },
- }
- currentToolCall = openai.ChatCompletionMessageToolCall{
- ID: toolCall.ID,
- Type: "function",
- Function: openai.ChatCompletionMessageToolCallFunction{
- Name: toolCall.Function.Name,
- Arguments: toolCall.Function.Arguments,
- },
- }
- }
+ msgToolCalls[toolCall.Index].Function.Arguments += toolCall.Function.Arguments
}
+ } else {
+ newToolCall = true
}
- }
- // Kujtim: some models send finish stop even for tool calls
- if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
- msgToolCalls = append(msgToolCalls, currentToolCall)
- if len(acc.Choices) > 0 {
- acc.Choices[0].Message.ToolCalls = msgToolCalls
+ if newToolCall { // new tool call
+ if toolCall.ID == "" {
+ toolCall.ID = uuid.NewString()
+ }
+ eventChan <- ProviderEvent{
+ Type: EventToolUseStart,
+ ToolCall: &message.ToolCall{
+ ID: toolCall.ID,
+ Name: toolCall.Function.Name,
+ Finished: false,
+ },
+ }
+ msgToolCalls = append(msgToolCalls, openai.ChatCompletionMessageToolCall{
+ ID: toolCall.ID,
+ Type: "function",
+ Function: openai.ChatCompletionMessageToolCallFunction{
+ Name: toolCall.Function.Name,
+ Arguments: toolCall.Function.Arguments,
+ },
+ })
}
}
+ acc.Choices[i].Message.ToolCalls = slices.Clone(msgToolCalls)
}
}