diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go
index 2907008de13b0ff05150a18aadf01b54d99213ed..06610dedb9ff303f676925785e459c274e3f05e9 100644
--- a/providers/openai/language_model.go
+++ b/providers/openai/language_model.go
@@ -517,7 +517,10 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 			}
 
 			// Handle tool calls that finish with empty arguments (e.g., Copilot).
-			// Normalize empty args to "{}" and emit the tool call if valid.
+			// Normalize empty args to "{}" and emit the tool call.
+			// If the arguments are invalid JSON, we still yield the tool call
+			// so the consumer (agent) can handle the error rather than
+			// silently dropping it.
 			for idx, tc := range toolCalls {
 				if tc.hasFinished {
 					continue
@@ -526,16 +529,14 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 					tc.arguments = "{}"
 					toolCalls[idx] = tc
 				}
-				if xjson.IsValid(tc.arguments) {
-					if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) {
-						return
-					}
-					if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) {
-						return
-					}
-					tc.hasFinished = true
-					toolCalls[idx] = tc
+				if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) {
+					return
+				}
+				if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) {
+					return
 				}
+				tc.hasFinished = true
+				toolCalls[idx] = tc
 			}
 
 			if len(acc.Choices) > 0 {
diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go
index 2230deae43e26c63435aa0f1235933a661e4b5db..058c8fc3d04828b24a0a2517ab5d62bedb4beb1c 100644
--- a/providers/openai/openai_test.go
+++ b/providers/openai/openai_test.go
@@ -2296,6 +2296,23 @@ func (sms *streamingMockServer) prepareToolStreamResponseWithEmptyArgs() {
 	sms.chunks = chunks
 }
 
+func (sms *streamingMockServer) prepareToolStreamResponseWithInvalidJSON() {
+	chunks := []string{
+		// Tool call start
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_invalid_json","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		// Arguments delta containing \x00 which is not a valid JSON escape
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"old_string\":\"hello\\x00"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		// Remaining arguments — combined is {"old_string":"hello\x00world"} which is invalid JSON
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"world\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		// Finish with tool_calls
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
+		// Usage
+		`data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
+		"data: [DONE]\n\n",
+	}
+	sms.chunks = chunks
+}
+
 func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
 	var parts []fantasy.StreamPart
 	for part := range stream {
@@ -3022,6 +3039,90 @@ func TestDoStream(t *testing.T) {
 		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
 		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
 	})
+
+	t.Run("should yield tool calls with invalid JSON arguments", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareToolStreamResponseWithInvalidJSON()
+
+		provider, err := New(
+			WithAPIKey("test-api-key"),
+			WithBaseURL(server.server.URL),
+		)
+		require.NoError(t, err)
+		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), fantasy.Call{
+			Prompt: testPrompt,
+			Tools: []fantasy.Tool{
+				fantasy.FunctionTool{
+					Name: "test-tool",
+					InputSchema: map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"old_string": map[string]any{
+								"type": "string",
+							},
+							"new_string": map[string]any{
+								"type": "string",
+							},
+						},
+						"required":             []string{"old_string", "new_string"},
+						"additionalProperties": false,
+						"$schema":              "http://json-schema.org/draft-07/schema#",
+					},
+				},
+			},
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find tool-related parts
+		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
+		var toolDeltas []string
+		var finishPart *fantasy.StreamPart
+
+		for i, part := range parts {
+			switch part.Type {
+			case fantasy.StreamPartTypeToolInputStart:
+				toolInputStart = i
+				require.Equal(t, "call_invalid_json", part.ID)
+				require.Equal(t, "test-tool", part.ToolCallName)
+			case fantasy.StreamPartTypeToolInputDelta:
+				toolDeltas = append(toolDeltas, part.Delta)
+			case fantasy.StreamPartTypeToolInputEnd:
+				toolInputEnd = i
+				require.Equal(t, "call_invalid_json", part.ID)
+			case fantasy.StreamPartTypeToolCall:
+				toolCall = i
+				require.Equal(t, "call_invalid_json", part.ID)
+				require.Equal(t, "test-tool", part.ToolCallName)
+			case fantasy.StreamPartTypeFinish:
+				finishPart = &part
+			}
+		}
+
+		require.NotEqual(t, -1, toolInputStart, "expected ToolInputStart part")
+		require.NotEqual(t, -1, toolInputEnd, "expected ToolInputEnd part")
+		require.NotEqual(t, -1, toolCall, "expected ToolCall part")
+
+		// Verify tool deltas combine to the complete input with \x00
+		var fullInput strings.Builder
+		for _, delta := range toolDeltas {
+			fullInput.WriteString(delta)
+		}
+		require.Equal(t, `{"old_string":"hello\x00world"}`, fullInput.String())
+
+		// Finish reason is ToolCalls since the tool call was yielded
+		require.NotNil(t, finishPart)
+		require.Equal(t, fantasy.FinishReasonToolCalls, finishPart.FinishReason)
+	})
 }
 
 func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {