fix: openai provider tool calls

Kujtim Hoxha created

Change summary

internal/llm/provider/openai.go | 30 +++++++++++++++---------------
1 file changed, 15 insertions(+), 15 deletions(-)

Detailed changes

internal/llm/provider/openai.go 🔗

@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"slices"
 	"strings"
 	"time"
 
@@ -342,18 +341,15 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 			acc := openai.ChatCompletionAccumulator{}
 			currentContent := ""
 			toolCalls := make([]message.ToolCall, 0)
-			var msgToolCalls []openai.ChatCompletionMessageToolCall
+			msgToolCalls := make(map[int64]openai.ChatCompletionMessageToolCall)
 			for openaiStream.Next() {
 				chunk := openaiStream.Current()
-				if len(chunk.Choices) == 0 {
-					continue
-				}
 				// Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index
-				if len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
+				if len(chunk.Choices) != 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
 					chunk.Choices[0].Delta.ToolCalls[0].Index = 0
 				}
 				acc.AddChunk(chunk)
-				for i, choice := range chunk.Choices {
+				for _, choice := range chunk.Choices {
 					reasoning, ok := choice.Delta.JSON.ExtraFields["reasoning"]
 					if ok && reasoning.Raw() != "" {
 						reasoningStr := ""
@@ -374,14 +370,14 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 					} else if len(choice.Delta.ToolCalls) > 0 {
 						toolCall := choice.Delta.ToolCalls[0]
 						newToolCall := false
-						if len(msgToolCalls)-1 >= int(toolCall.Index) { // tool call exists
-							existingToolCall := msgToolCalls[toolCall.Index]
+						if existingToolCall, ok := msgToolCalls[toolCall.Index]; ok { // tool call exists
 							if toolCall.ID != "" && toolCall.ID != existingToolCall.ID {
 								found := false
 								// try to find the tool based on the ID
-								for i, tool := range msgToolCalls {
+								for _, tool := range msgToolCalls {
 									if tool.ID == toolCall.ID {
-										msgToolCalls[i].Function.Arguments += toolCall.Function.Arguments
+										existingToolCall.Function.Arguments += toolCall.Function.Arguments
+										msgToolCalls[toolCall.Index] = existingToolCall
 										found = true
 									}
 								}
@@ -389,7 +385,8 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 									newToolCall = true
 								}
 							} else {
-								msgToolCalls[toolCall.Index].Function.Arguments += toolCall.Function.Arguments
+								existingToolCall.Function.Arguments += toolCall.Function.Arguments
+								msgToolCalls[toolCall.Index] = existingToolCall
 							}
 						} else {
 							newToolCall = true
@@ -406,17 +403,16 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 									Finished: false,
 								},
 							}
-							msgToolCalls = append(msgToolCalls, openai.ChatCompletionMessageToolCall{
+							msgToolCalls[toolCall.Index] = openai.ChatCompletionMessageToolCall{
 								ID:   toolCall.ID,
 								Type: "function",
 								Function: openai.ChatCompletionMessageToolCallFunction{
 									Name:      toolCall.Function.Name,
 									Arguments: toolCall.Function.Arguments,
 								},
-							})
+							}
 						}
 					}
-					acc.Choices[i].Message.ToolCalls = slices.Clone(msgToolCalls)
 				}
 			}
 
@@ -541,6 +537,10 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too
 
 	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
 		for _, call := range completion.Choices[0].Message.ToolCalls {
+			// accumulator for some reason does this.
+			if call.Function.Name == "" {
+				continue
+			}
 			toolCall := message.ToolCall{
 				ID:       call.ID,
 				Name:     call.Function.Name,