From 0c8e111af5d8bccf77d544da967c28f7df56d36f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 27 Aug 2025 14:30:08 -0400 Subject: [PATCH] fix: openai provider tool calls --- internal/llm/provider/openai.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index e28b0444df023245e235f4a9cffa47adb9a46286..ffed9325e0a70fb86ffe2fecd5b7f00e63e3e215 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "log/slog" - "slices" "strings" "time" @@ -342,18 +341,15 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t acc := openai.ChatCompletionAccumulator{} currentContent := "" toolCalls := make([]message.ToolCall, 0) - var msgToolCalls []openai.ChatCompletionMessageToolCall + msgToolCalls := make(map[int64]openai.ChatCompletionMessageToolCall) for openaiStream.Next() { chunk := openaiStream.Current() - if len(chunk.Choices) == 0 { - continue - } // Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index - if len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 { + if len(chunk.Choices) != 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 { chunk.Choices[0].Delta.ToolCalls[0].Index = 0 } acc.AddChunk(chunk) - for i, choice := range chunk.Choices { + for _, choice := range chunk.Choices { reasoning, ok := choice.Delta.JSON.ExtraFields["reasoning"] if ok && reasoning.Raw() != "" { reasoningStr := "" @@ -374,14 +370,14 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t } else if len(choice.Delta.ToolCalls) > 0 { toolCall := choice.Delta.ToolCalls[0] newToolCall := false - if len(msgToolCalls)-1 >= int(toolCall.Index) { // tool call exists - existingToolCall := msgToolCalls[toolCall.Index] + if existingToolCall, ok := msgToolCalls[toolCall.Index]; ok { // tool call exists if toolCall.ID != "" && toolCall.ID != existingToolCall.ID { found := false // try to find the tool based on the ID - for i, tool := range msgToolCalls { + for _, tool := range msgToolCalls { if tool.ID == toolCall.ID { - msgToolCalls[i].Function.Arguments += toolCall.Function.Arguments + existingToolCall.Function.Arguments += toolCall.Function.Arguments + msgToolCalls[toolCall.Index] = existingToolCall found = true } } @@ -389,7 +385,8 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t newToolCall = true } } else { - msgToolCalls[toolCall.Index].Function.Arguments += toolCall.Function.Arguments + existingToolCall.Function.Arguments += toolCall.Function.Arguments + msgToolCalls[toolCall.Index] = existingToolCall } } else { newToolCall = true @@ -406,17 +403,16 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t Finished: false, }, } - msgToolCalls = append(msgToolCalls, openai.ChatCompletionMessageToolCall{ + msgToolCalls[toolCall.Index] = openai.ChatCompletionMessageToolCall{ ID: toolCall.ID, Type: "function", Function: openai.ChatCompletionMessageToolCallFunction{ Name: toolCall.Function.Name, Arguments: toolCall.Function.Arguments, }, - }) + } } } - acc.Choices[i].Message.ToolCalls = slices.Clone(msgToolCalls) } } @@ -541,6 +537,10 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { for _, call := range completion.Choices[0].Message.ToolCalls { + // accumulator for some reason does this. + if call.Function.Name == "" { + continue + } toolCall := message.ToolCall{ ID: call.ID, Name: call.Function.Name,