fix(openai): yield tool calls with invalid JSON instead of silently dropping (#223)

mkaaad created

Previously, tool calls with invalid JSON arguments were silently dropped
in the stream, preventing the agent from reporting the error back to the
model. Now all tool calls are yielded regardless of argument validity,
and the agent handles validation via its existing repair/error flow.

Additionally, track hasYieldedToolCall to map finishReason correctly:
use tool_calls only when a tool call was actually yielded, and fall back
to stop when no tool call was yielded despite the API saying tool_calls.

Change summary

providers/openai/language_model.go |  21 +++---
providers/openai/openai_test.go    | 101 ++++++++++++++++++++++++++++++++
2 files changed, 112 insertions(+), 10 deletions(-)

Detailed changes

providers/openai/language_model.go 🔗

@@ -517,7 +517,10 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 			}
 
 			// Handle tool calls that finish with empty arguments (e.g., Copilot).
-			// Normalize empty args to "{}" and emit the tool call if valid.
+			// Normalize empty args to "{}" and emit the tool call.
+			// If the arguments are invalid JSON, we still yield the tool call
+			// so the consumer (agent) can handle the error rather than
+			// silently dropping it.
 			for idx, tc := range toolCalls {
 				if tc.hasFinished {
 					continue
@@ -526,16 +529,14 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 					tc.arguments = "{}"
 					toolCalls[idx] = tc
 				}
-				if xjson.IsValid(tc.arguments) {
-					if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) {
-						return
-					}
-					if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) {
-						return
-					}
-					tc.hasFinished = true
-					toolCalls[idx] = tc
+				if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) {
+					return
+				}
+				if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) {
+					return
 				}
+				tc.hasFinished = true
+				toolCalls[idx] = tc
 			}
 
 			if len(acc.Choices) > 0 {

providers/openai/openai_test.go 🔗

@@ -2296,6 +2296,23 @@ func (sms *streamingMockServer) prepareToolStreamResponseWithEmptyArgs() {
 	sms.chunks = chunks
 }
 
+func (sms *streamingMockServer) prepareToolStreamResponseWithInvalidJSON() {
+	chunks := []string{
+		// Tool call start
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_invalid_json","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		// Arguments delta containing \x00 which is not a valid JSON escape
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"old_string\":\"hello\\x00"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		// Remaining arguments — combined is {"old_string":"hello\x00world"} which is invalid JSON
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"world\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		// Finish with tool_calls
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
+		// Usage
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
+		"data: [DONE]\n\n",
+	}
+	sms.chunks = chunks
+}
+
 func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
 	var parts []fantasy.StreamPart
 	for part := range stream {
@@ -3022,6 +3039,90 @@ func TestDoStream(t *testing.T) {
 		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
 		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
 	})
+
+	t.Run("should drop tool calls with invalid JSON arguments", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareToolStreamResponseWithInvalidJSON()
+
+		provider, err := New(
+			WithAPIKey("test-api-key"),
+			WithBaseURL(server.server.URL),
+		)
+		require.NoError(t, err)
+		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), fantasy.Call{
+			Prompt: testPrompt,
+			Tools: []fantasy.Tool{
+				fantasy.FunctionTool{
+					Name: "test-tool",
+					InputSchema: map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"old_string": map[string]any{
+								"type": "string",
+							},
+							"new_string": map[string]any{
+								"type": "string",
+							},
+						},
+						"required":             []string{"old_string", "new_string"},
+						"additionalProperties": false,
+						"$schema":              "http://json-schema.org/draft-07/schema#",
+					},
+				},
+			},
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find tool-related parts
+		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
+		var toolDeltas []string
+		var finishPart *fantasy.StreamPart
+
+		for i, part := range parts {
+			switch part.Type {
+			case fantasy.StreamPartTypeToolInputStart:
+				toolInputStart = i
+				require.Equal(t, "call_invalid_json", part.ID)
+				require.Equal(t, "test-tool", part.ToolCallName)
+			case fantasy.StreamPartTypeToolInputDelta:
+				toolDeltas = append(toolDeltas, part.Delta)
+			case fantasy.StreamPartTypeToolInputEnd:
+				toolInputEnd = i
+				require.Equal(t, "call_invalid_json", part.ID)
+			case fantasy.StreamPartTypeToolCall:
+				toolCall = i
+				require.Equal(t, "call_invalid_json", part.ID)
+				require.Equal(t, "test-tool", part.ToolCallName)
+			case fantasy.StreamPartTypeFinish:
+				finishPart = &part
+			}
+		}
+
+		require.NotEqual(t, -1, toolInputStart, "expected ToolInputStart part")
+		require.NotEqual(t, -1, toolInputEnd, "expected ToolInputEnd part")
+		require.NotEqual(t, -1, toolCall, "expected ToolCall part")
+
+		// Verify tool deltas combine to the complete input with \x00
+		var fullInput strings.Builder
+		for _, delta := range toolDeltas {
+			fullInput.WriteString(delta)
+		}
+		require.Equal(t, `{"old_string":"hello\x00world"}`, fullInput.String())
+
+		// Finish reason is ToolCalls since the tool call was yielded
+		require.NotNil(t, finishPart)
+		require.Equal(t, fantasy.FinishReasonToolCalls, finishPart.FinishReason)
+	})
 }
 
 func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {