diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 49b344c90bc0a2860cffdb530f688bfe1faad665..0093dd24ad77d6962780abdce84f64ba733a2df1 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -272,10 +272,13 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t currentContent := "" toolCalls := make([]message.ToolCall, 0) + var currentToolCallID string + var currentToolCall openai.ChatCompletionMessageToolCall + var msgToolCalls []openai.ChatCompletionMessageToolCall for openaiStream.Next() { chunk := openaiStream.Current() acc.AddChunk(chunk) - + // This fixes multiple tool calls for some providers for _, choice := range chunk.Choices { if choice.Delta.Content != "" { eventChan <- ProviderEvent{ @@ -283,6 +286,45 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t Content: choice.Delta.Content, } currentContent += choice.Delta.Content + } else if len(choice.Delta.ToolCalls) > 0 { + toolCall := choice.Delta.ToolCalls[0] + // Detect tool use start + if currentToolCallID == "" { + if toolCall.ID != "" { + currentToolCallID = toolCall.ID + currentToolCall = openai.ChatCompletionMessageToolCall{ + ID: toolCall.ID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunction{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + } + } else { + // Delta tool use + if toolCall.ID == "" { + currentToolCall.Function.Arguments += toolCall.Function.Arguments + } else { + // Detect new tool use + if toolCall.ID != currentToolCallID { + msgToolCalls = append(msgToolCalls, currentToolCall) + currentToolCallID = toolCall.ID + currentToolCall = openai.ChatCompletionMessageToolCall{ + ID: toolCall.ID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunction{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + } + } + } + } + if choice.FinishReason == "tool_calls" { + msgToolCalls = append(msgToolCalls, currentToolCall) + acc.Choices[0].Message.ToolCalls = msgToolCalls } } } @@ -293,6 +335,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t jsonData, _ := json.Marshal(acc.ChatCompletion) slog.Debug("Response", "messages", string(jsonData)) } + resultFinishReason := acc.ChatCompletion.Choices[0].FinishReason if resultFinishReason == "" { // If the finish reason is empty, we assume it was a successful completion