From e0ddf8769a53fd9d1154c100bf3d72c87b1bb3fe Mon Sep 17 00:00:00 2001 From: mkaaad <119158371+mkaaad@users.noreply.github.com> Date: Sat, 9 May 2026 01:54:25 +0800 Subject: [PATCH] fix(openai): yield tool calls with invalid JSON instead of silently dropping (#223) Previously, tool calls with invalid JSON arguments were silently dropped in the stream, preventing the agent from reporting the error back to the model. Now all tool calls are yielded regardless of argument validity, and the agent handles validation via its existing repair/error flow. Additionally, track hasYieldedToolCall to map finishReason correctly: use tool_calls only when a tool call was actually yielded, and fall back to stop when no tool call was yielded despite the API saying tool_calls. --- providers/openai/language_model.go | 21 +++--- providers/openai/openai_test.go | 101 +++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 10 deletions(-) diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 2907008de13b0ff05150a18aadf01b54d99213ed..06610dedb9ff303f676925785e459c274e3f05e9 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -517,7 +517,10 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } // Handle tool calls that finish with empty arguments (e.g., Copilot). - // Normalize empty args to "{}" and emit the tool call if valid. + // Normalize empty args to "{}" and emit the tool call. + // If the arguments are invalid JSON, we still yield the tool call + // so the consumer (agent) can handle the error rather than + // silently dropping it. for idx, tc := range toolCalls { if tc.hasFinished { continue @@ -526,16 +529,14 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S tc.arguments = "{}" toolCalls[idx] = tc } - if xjson.IsValid(tc.arguments) { - if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) { - return - } - if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) { - return - } - tc.hasFinished = true - toolCalls[idx] = tc + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) { + return } + tc.hasFinished = true + toolCalls[idx] = tc } if len(acc.Choices) > 0 { diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 2230deae43e26c63435aa0f1235933a661e4b5db..058c8fc3d04828b24a0a2517ab5d62bedb4beb1c 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -2296,6 +2296,23 @@ func (sms *streamingMockServer) prepareToolStreamResponseWithEmptyArgs() { sms.chunks = chunks } +func (sms *streamingMockServer) prepareToolStreamResponseWithInvalidJSON() { + chunks := []string{ + // Tool call start + `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_invalid_json","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n", + // Arguments delta containing \x00 which is not a valid JSON escape + `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"old_string\":\"hello\\x00"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n", + // Remaining arguments — combined is {"old_string":"hello\x00world"} which is invalid JSON + `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"world\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n", + // Finish with tool_calls + `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n", + // Usage + `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n", + "data: [DONE]\n\n", + } + sms.chunks = chunks +} + func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) { var parts []fantasy.StreamPart for part := range stream { @@ -3022,6 +3039,90 @@ func TestDoStream(t *testing.T) { require.Equal(t, int64(35), finishPart.Usage.TotalTokens) require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens) }) + + t.Run("should drop tool calls with invalid JSON arguments", func(t *testing.T) { + t.Parallel() + + server := newStreamingMockServer() + defer server.close() + + server.prepareToolStreamResponseWithInvalidJSON() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), + ) + require.NoError(t, err) + model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo") + + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: testPrompt, + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ + Name: "test-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "old_string": map[string]any{ + "type": "string", + }, + "new_string": map[string]any{ + "type": "string", + }, + }, + "required": []string{"old_string", "new_string"}, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#", + }, + }, + }, + }) + + require.NoError(t, err) + + parts, err := collectStreamParts(stream) + require.NoError(t, err) + + // Find tool-related parts + toolInputStart, toolInputEnd, toolCall := -1, -1, -1 + var toolDeltas []string + var finishPart *fantasy.StreamPart + + for i, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart: + toolInputStart = i + require.Equal(t, "call_invalid_json", part.ID) + require.Equal(t, "test-tool", part.ToolCallName) + case fantasy.StreamPartTypeToolInputDelta: + toolDeltas = append(toolDeltas, part.Delta) + case fantasy.StreamPartTypeToolInputEnd: + toolInputEnd = i + require.Equal(t, "call_invalid_json", part.ID) + case fantasy.StreamPartTypeToolCall: + toolCall = i + require.Equal(t, "call_invalid_json", part.ID) + require.Equal(t, "test-tool", part.ToolCallName) + case fantasy.StreamPartTypeFinish: + finishPart = &part + } + } + + require.NotEqual(t, -1, toolInputStart, "expected ToolInputStart part") + require.NotEqual(t, -1, toolInputEnd, "expected ToolInputEnd part") + require.NotEqual(t, -1, toolCall, "expected ToolCall part") + + // Verify tool deltas combine to the complete input with \x00 + var fullInput strings.Builder + for _, delta := range toolDeltas { + fullInput.WriteString(delta) + } + require.Equal(t, `{"old_string":"hello\x00world"}`, fullInput.String()) + + // Finish reason is ToolCalls since the tool call was yielded + require.NotNil(t, finishPart) + require.Equal(t, fantasy.FinishReasonToolCalls, finishPart.FinishReason) + }) } func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {