diff --git a/content_json.go b/content_json.go index bde7393ec85a92d4a7939e0e2bb1f22368fa6bb7..1f2d53ed0d1ca57125436a730802c3c45ae91e62 100644 --- a/content_json.go +++ b/content_json.go @@ -393,11 +393,8 @@ func (t *ToolResultContent) UnmarshalJSON(data []byte) error { // MarshalJSON implements json.Marshaler for ToolResultOutputContentText. func (t ToolResultOutputContentText) MarshalJSON() ([]byte, error) { - dataBytes, err := json.Marshal(struct { - Text string `json:"text"` - }{ - Text: t.Text, - }) + type alias ToolResultOutputContentText + dataBytes, err := json.Marshal(alias(t)) if err != nil { return nil, err } @@ -415,15 +412,14 @@ func (t *ToolResultOutputContentText) UnmarshalJSON(data []byte) error { return err } - var temp struct { - Text string `json:"text"` - } + type alias ToolResultOutputContentText + var temp alias if err := json.Unmarshal(tr.Data, &temp); err != nil { return err } - t.Text = temp.Text + *t = ToolResultOutputContentText(temp) return nil } @@ -470,13 +466,8 @@ func (t *ToolResultOutputContentError) UnmarshalJSON(data []byte) error { // MarshalJSON implements json.Marshaler for ToolResultOutputContentMedia. func (t ToolResultOutputContentMedia) MarshalJSON() ([]byte, error) { - dataBytes, err := json.Marshal(struct { - Data string `json:"data"` - MediaType string `json:"media_type"` - }{ - Data: t.Data, - MediaType: t.MediaType, - }) + type alias ToolResultOutputContentMedia + dataBytes, err := json.Marshal(alias(t)) if err != nil { return nil, err } @@ -494,17 +485,14 @@ func (t *ToolResultOutputContentMedia) UnmarshalJSON(data []byte) error { return err } - var temp struct { - Data string `json:"data"` - MediaType string `json:"media_type"` - } + type alias ToolResultOutputContentMedia + var temp alias if err := json.Unmarshal(tr.Data, &temp); err != nil { return err } - t.Data = temp.Data - t.MediaType = temp.MediaType + *t = ToolResultOutputContentMedia(temp) return nil } @@ -870,15 +858,8 @@ func (f *FunctionTool) UnmarshalJSON(data []byte) error { // MarshalJSON implements json.Marshaler for ProviderDefinedTool. func (p ProviderDefinedTool) MarshalJSON() ([]byte, error) { - dataBytes, err := json.Marshal(struct { - ID string `json:"id"` - Name string `json:"name"` - Args map[string]any `json:"args"` - }{ - ID: p.ID, - Name: p.Name, - Args: p.Args, - }) + type alias ProviderDefinedTool + dataBytes, err := json.Marshal(alias(p)) if err != nil { return nil, err } @@ -896,19 +877,14 @@ func (p *ProviderDefinedTool) UnmarshalJSON(data []byte) error { return err } - var aux struct { - ID string `json:"id"` - Name string `json:"name"` - Args map[string]any `json:"args"` - } + type alias ProviderDefinedTool + var aux alias if err := json.Unmarshal(tj.Data, &aux); err != nil { return err } - p.ID = aux.ID - p.Name = aux.Name - p.Args = aux.Args + *p = ProviderDefinedTool(aux) return nil } diff --git a/go.mod b/go.mod index 95af7849ecab4aa12f7ee093da282deb0b3782de..45a86e04266e7ee46e43f2ee4b19ced741b21058 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module charm.land/fantasy go 1.25 require ( - charm.land/x/vcr v0.1.0 + charm.land/x/vcr v0.1.1 cloud.google.com/go/auth v0.17.0 github.com/RealAlexandreAI/json-repair v0.0.14 github.com/aws/aws-sdk-go-v2 v1.39.6 diff --git a/go.sum b/go.sum index 291a585bae7b36a4911dce0421225565d48e0630..e6fecef060ca2160b1ed534c393b4df504c9f17b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -charm.land/x/vcr v0.1.0 h1:XhCUVij6Ss6+xJuAb2n4mNRGSS/SrnNoUmEwJziy+Dg= -charm.land/x/vcr v0.1.0/go.mod h1:eByq2gqzWvcct/8XE2XO5KznoWEBiXH56+y2gphbltM= +charm.land/x/vcr v0.1.1 h1:PXCFMUG0rPtyk35rhfzYCJEduOzWXCIbrXTFq4OF/9Q= +charm.land/x/vcr v0.1.1/go.mod h1:eByq2gqzWvcct/8XE2XO5KznoWEBiXH56+y2gphbltM= cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4= diff --git a/json_test.go b/json_test.go index da9f491407e284dbc988c59595d79aa631dfd524..b1088434a40b7a28006b1a2b36f2ffc4db9a4753 100644 --- a/json_test.go +++ b/json_test.go @@ -645,3 +645,108 @@ func TestPromptSerialization(t *testing.T) { } }) } + +func TestStreamPartErrorSerialization(t *testing.T) { + t.Run("stream part with ProviderError containing OpenAI API error", func(t *testing.T) { + // Create a mock OpenAI API error + openaiErr := errors.New("invalid_api_key: Incorrect API key provided") + + // Wrap in ProviderError + providerErr := &ProviderError{ + Title: "unauthorized", + Message: "Incorrect API key provided", + Cause: openaiErr, + URL: "https://api.openai.com/v1/chat/completions", + StatusCode: 401, + RequestBody: []byte(`{"model":"gpt-4","messages":[]}`), + ResponseHeaders: map[string]string{ + "content-type": "application/json", + }, + ResponseBody: []byte(`{"error":{"message":"Incorrect API key provided","type":"invalid_request_error"}}`), + } + + // Create StreamPart with error + streamPart := StreamPart{ + Type: StreamPartTypeError, + Error: providerErr, + } + + // Marshal the stream part + data, err := json.Marshal(streamPart) + if err != nil { + t.Fatalf("failed to marshal stream part: %v", err) + } + + // Unmarshal back + var decoded StreamPart + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("failed to unmarshal stream part: %v", err) + } + + // Verify the stream part type + if decoded.Type != StreamPartTypeError { + t.Errorf("type mismatch: got %v, want %v", decoded.Type, StreamPartTypeError) + } + + // Verify error exists + if decoded.Error == nil { + t.Fatal("expected error to be present, got nil") + } + + // Verify error message + expectedMsg := "unauthorized: Incorrect API key provided" + if decoded.Error.Error() != expectedMsg { + t.Errorf("error message mismatch: got %q, want %q", decoded.Error.Error(), expectedMsg) + } + }) + + t.Run("unmarshal stream part with error from JSON", func(t *testing.T) { + // JSON representing a StreamPart with an error + jsonData := `{ + "type": "error", + "error": "unauthorized: Incorrect API key provided", + "id": "", + "tool_call_name": "", + "tool_call_input": "", + "delta": "", + "provider_executed": false, + "usage": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "reasoning_tokens": 0, + "cache_creation_tokens": 0, + "cache_read_tokens": 0 + }, + "finish_reason": "", + "warnings": null, + "source_type": "", + "url": "", + "title": "", + "provider_metadata": null + }` + + var streamPart StreamPart + err := json.Unmarshal([]byte(jsonData), &streamPart) + if err != nil { + t.Fatalf("failed to unmarshal stream part: %v", err) + } + + // Verify the stream part type + if streamPart.Type != StreamPartTypeError { + t.Errorf("type mismatch: got %v, want %v", streamPart.Type, StreamPartTypeError) + } + + // Verify error exists + if streamPart.Error == nil { + t.Fatal("expected error to be present, got nil") + } + + // Verify error message + expectedMsg := "unauthorized: Incorrect API key provided" + if streamPart.Error.Error() != expectedMsg { + t.Errorf("error message mismatch: got %q, want %q", streamPart.Error.Error(), expectedMsg) + } + }) +} diff --git a/model_json.go b/model_json.go index 90a8c78520977ae8dc797c2c0575f168297917ca..2fd883a2fc714b0be61b4e47809a70698ef6e6a0 100644 --- a/model_json.go +++ b/model_json.go @@ -102,42 +102,46 @@ func (r *Response) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements json.Marshaler for StreamPart. +func (s StreamPart) MarshalJSON() ([]byte, error) { + type alias StreamPart + aux := struct { + alias + Error string `json:"error,omitempty"` + }{ + alias: (alias)(s), + } + + // Marshal error to string + if s.Error != nil { + aux.Error = s.Error.Error() + } + + // Clear the original Error field to avoid duplicate marshaling + aux.alias.Error = nil + + return json.Marshal(aux) +} + // UnmarshalJSON implements json.Unmarshaler for StreamPart. func (s *StreamPart) UnmarshalJSON(data []byte) error { - var aux struct { - Type StreamPartType `json:"type"` - ID string `json:"id"` - ToolCallName string `json:"tool_call_name"` - ToolCallInput string `json:"tool_call_input"` - Delta string `json:"delta"` - ProviderExecuted bool `json:"provider_executed"` - Usage Usage `json:"usage"` - FinishReason FinishReason `json:"finish_reason"` - Error error `json:"error"` - Warnings []CallWarning `json:"warnings"` - SourceType SourceType `json:"source_type"` - URL string `json:"url"` - Title string `json:"title"` + type alias StreamPart + aux := struct { + *alias + Error string `json:"error"` ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"` + }{ + alias: (*alias)(s), } if err := json.Unmarshal(data, &aux); err != nil { return err } - s.Type = aux.Type - s.ID = aux.ID - s.ToolCallName = aux.ToolCallName - s.ToolCallInput = aux.ToolCallInput - s.Delta = aux.Delta - s.ProviderExecuted = aux.ProviderExecuted - s.Usage = aux.Usage - s.FinishReason = aux.FinishReason - s.Error = aux.Error - s.Warnings = aux.Warnings - s.SourceType = aux.SourceType - s.URL = aux.URL - s.Title = aux.Title + // Unmarshal error string back to error type + if aux.Error != "" { + s.Error = fmt.Errorf("%s", aux.Error) + } // Unmarshal ProviderMetadata if len(aux.ProviderMetadata) > 0 { diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index e1acf0036fb0664883b16bada360d14d2d4c1de7..55227bd92a263841eef3d15c1be8a8ab198c0f11 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -557,7 +557,11 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } func isReasoningModel(modelID string) bool { - return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat") + return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") || + strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") || + strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") || + strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") || + strings.Contains(modelID, "gpt-5") || strings.Contains(modelID, "gpt-5-chat") } func isSearchPreviewModel(modelID string) bool { @@ -565,13 +569,14 @@ func isSearchPreviewModel(modelID string) bool { } func supportsFlexProcessing(modelID string) bool { - return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5") + return strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") || + strings.Contains(modelID, "o4-mini") || strings.Contains(modelID, "gpt-5") } func supportsPriorityProcessing(modelID string) bool { - return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") || - strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") || - strings.HasPrefix(modelID, "o4-mini") + return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") || + strings.Contains(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") || + strings.Contains(modelID, "-o3") || strings.Contains(modelID, "o4-mini") } func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) { diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 13a9da50a26069aafa4e62788b02f99305147007..7a30139ac31080566fa43393512bcc3b8ba085b6 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -56,16 +56,17 @@ type responsesModelConfig struct { func getResponsesModelConfig(modelID string) responsesModelConfig { supportsFlexProcessing := strings.HasPrefix(modelID, "o3") || - strings.HasPrefix(modelID, "o4-mini") || - (strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat")) - - supportsPriorityProcessing := strings.HasPrefix(modelID, "gpt-4") || - strings.HasPrefix(modelID, "gpt-5-mini") || - (strings.HasPrefix(modelID, "gpt-5") && - !strings.HasPrefix(modelID, "gpt-5-nano") && - !strings.HasPrefix(modelID, "gpt-5-chat")) || + strings.Contains(modelID, "-o3") || strings.Contains(modelID, "o4-mini") || + (strings.Contains(modelID, "gpt-5") && !strings.Contains(modelID, "gpt-5-chat")) + + supportsPriorityProcessing := strings.Contains(modelID, "gpt-4") || + strings.Contains(modelID, "gpt-5-mini") || + (strings.Contains(modelID, "gpt-5") && + !strings.Contains(modelID, "gpt-5-nano") && + !strings.Contains(modelID, "gpt-5-chat")) || strings.HasPrefix(modelID, "o3") || - strings.HasPrefix(modelID, "o4-mini") + strings.Contains(modelID, "-o3") || + strings.Contains(modelID, "o4-mini") defaults := responsesModelConfig{ requiredAutoTruncation: false, @@ -74,7 +75,7 @@ func getResponsesModelConfig(modelID string) responsesModelConfig { supportsPriorityProcessing: supportsPriorityProcessing, } - if strings.HasPrefix(modelID, "gpt-5-chat") { + if strings.Contains(modelID, "gpt-5-chat") { return responsesModelConfig{ isReasoningModel: false, systemMessageMode: defaults.systemMessageMode, @@ -84,11 +85,13 @@ func getResponsesModelConfig(modelID string) responsesModelConfig { } } - if strings.HasPrefix(modelID, "o") || - strings.HasPrefix(modelID, "gpt-5") || - strings.HasPrefix(modelID, "codex-") || - strings.HasPrefix(modelID, "computer-use") { - if strings.HasPrefix(modelID, "o1-mini") || strings.HasPrefix(modelID, "o1-preview") { + if strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") || + strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") || + strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") || + strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") || + strings.Contains(modelID, "gpt-5") || strings.Contains(modelID, "codex-") || + strings.Contains(modelID, "computer-use") { + if strings.Contains(modelID, "o1-mini") || strings.Contains(modelID, "o1-preview") { return responsesModelConfig{ isReasoningModel: true, systemMessageMode: "remove", diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index 88dba7f42ae5af4aa4304aff64ccb2bb15c86525..935d97a1d78dcc56c80064ee0334de93cb9a17aa 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -151,6 +151,9 @@ var responsesReasoningModelIDs = []string{ "gpt-5-nano", "gpt-5-nano-2025-08-07", "gpt-5-codex", + "gpt-5.1", + "gpt-5.1-codex", + "gpt-5.1-codex-mini", } // responsesModelIds lists all model IDs for OpenAI Responses API.