diff --git a/content_json.go b/content_json.go index bde7393ec85a92d4a7939e0e2bb1f22368fa6bb7..1f2d53ed0d1ca57125436a730802c3c45ae91e62 100644 --- a/content_json.go +++ b/content_json.go @@ -393,11 +393,8 @@ func (t *ToolResultContent) UnmarshalJSON(data []byte) error { // MarshalJSON implements json.Marshaler for ToolResultOutputContentText. func (t ToolResultOutputContentText) MarshalJSON() ([]byte, error) { - dataBytes, err := json.Marshal(struct { - Text string `json:"text"` - }{ - Text: t.Text, - }) + type alias ToolResultOutputContentText + dataBytes, err := json.Marshal(alias(t)) if err != nil { return nil, err } @@ -415,15 +412,14 @@ func (t *ToolResultOutputContentText) UnmarshalJSON(data []byte) error { return err } - var temp struct { - Text string `json:"text"` - } + type alias ToolResultOutputContentText + var temp alias if err := json.Unmarshal(tr.Data, &temp); err != nil { return err } - t.Text = temp.Text + *t = ToolResultOutputContentText(temp) return nil } @@ -470,13 +466,8 @@ func (t *ToolResultOutputContentError) UnmarshalJSON(data []byte) error { // MarshalJSON implements json.Marshaler for ToolResultOutputContentMedia. func (t ToolResultOutputContentMedia) MarshalJSON() ([]byte, error) { - dataBytes, err := json.Marshal(struct { - Data string `json:"data"` - MediaType string `json:"media_type"` - }{ - Data: t.Data, - MediaType: t.MediaType, - }) + type alias ToolResultOutputContentMedia + dataBytes, err := json.Marshal(alias(t)) if err != nil { return nil, err } @@ -494,17 +485,14 @@ func (t *ToolResultOutputContentMedia) UnmarshalJSON(data []byte) error { return err } - var temp struct { - Data string `json:"data"` - MediaType string `json:"media_type"` - } + type alias ToolResultOutputContentMedia + var temp alias if err := json.Unmarshal(tr.Data, &temp); err != nil { return err } - t.Data = temp.Data - t.MediaType = temp.MediaType + *t = ToolResultOutputContentMedia(temp) return nil } @@ -870,15 +858,8 @@ func (f *FunctionTool) UnmarshalJSON(data []byte) error { // MarshalJSON implements json.Marshaler for ProviderDefinedTool. func (p ProviderDefinedTool) MarshalJSON() ([]byte, error) { - dataBytes, err := json.Marshal(struct { - ID string `json:"id"` - Name string `json:"name"` - Args map[string]any `json:"args"` - }{ - ID: p.ID, - Name: p.Name, - Args: p.Args, - }) + type alias ProviderDefinedTool + dataBytes, err := json.Marshal(alias(p)) if err != nil { return nil, err } @@ -896,19 +877,14 @@ func (p *ProviderDefinedTool) UnmarshalJSON(data []byte) error { return err } - var aux struct { - ID string `json:"id"` - Name string `json:"name"` - Args map[string]any `json:"args"` - } + type alias ProviderDefinedTool + var aux alias if err := json.Unmarshal(tj.Data, &aux); err != nil { return err } - p.ID = aux.ID - p.Name = aux.Name - p.Args = aux.Args + *p = ProviderDefinedTool(aux) return nil } diff --git a/json_test.go b/json_test.go index da9f491407e284dbc988c59595d79aa631dfd524..b1088434a40b7a28006b1a2b36f2ffc4db9a4753 100644 --- a/json_test.go +++ b/json_test.go @@ -645,3 +645,108 @@ func TestPromptSerialization(t *testing.T) { } }) } + +func TestStreamPartErrorSerialization(t *testing.T) { + t.Run("stream part with ProviderError containing OpenAI API error", func(t *testing.T) { + // Create a mock OpenAI API error + openaiErr := errors.New("invalid_api_key: Incorrect API key provided") + + // Wrap in ProviderError + providerErr := &ProviderError{ + Title: "unauthorized", + Message: "Incorrect API key provided", + Cause: openaiErr, + URL: "https://api.openai.com/v1/chat/completions", + StatusCode: 401, + RequestBody: []byte(`{"model":"gpt-4","messages":[]}`), + ResponseHeaders: map[string]string{ + "content-type": "application/json", + }, + ResponseBody: []byte(`{"error":{"message":"Incorrect API key provided","type":"invalid_request_error"}}`), + } + + // Create StreamPart with error + streamPart := StreamPart{ + Type: StreamPartTypeError, + Error: providerErr, + } + + // Marshal the stream part + data, err := json.Marshal(streamPart) + if err != nil { + t.Fatalf("failed to marshal stream part: %v", err) + } + + // Unmarshal back + var decoded StreamPart + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("failed to unmarshal stream part: %v", err) + } + + // Verify the stream part type + if decoded.Type != StreamPartTypeError { + t.Errorf("type mismatch: got %v, want %v", decoded.Type, StreamPartTypeError) + } + + // Verify error exists + if decoded.Error == nil { + t.Fatal("expected error to be present, got nil") + } + + // Verify error message + expectedMsg := "unauthorized: Incorrect API key provided" + if decoded.Error.Error() != expectedMsg { + t.Errorf("error message mismatch: got %q, want %q", decoded.Error.Error(), expectedMsg) + } + }) + + t.Run("unmarshal stream part with error from JSON", func(t *testing.T) { + // JSON representing a StreamPart with an error + jsonData := `{ + "type": "error", + "error": "unauthorized: Incorrect API key provided", + "id": "", + "tool_call_name": "", + "tool_call_input": "", + "delta": "", + "provider_executed": false, + "usage": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "reasoning_tokens": 0, + "cache_creation_tokens": 0, + "cache_read_tokens": 0 + }, + "finish_reason": "", + "warnings": null, + "source_type": "", + "url": "", + "title": "", + "provider_metadata": null + }` + + var streamPart StreamPart + err := json.Unmarshal([]byte(jsonData), &streamPart) + if err != nil { + t.Fatalf("failed to unmarshal stream part: %v", err) + } + + // Verify the stream part type + if streamPart.Type != StreamPartTypeError { + t.Errorf("type mismatch: got %v, want %v", streamPart.Type, StreamPartTypeError) + } + + // Verify error exists + if streamPart.Error == nil { + t.Fatal("expected error to be present, got nil") + } + + // Verify error message + expectedMsg := "unauthorized: Incorrect API key provided" + if streamPart.Error.Error() != expectedMsg { + t.Errorf("error message mismatch: got %q, want %q", streamPart.Error.Error(), expectedMsg) + } + }) +} diff --git a/model_json.go b/model_json.go index 90a8c78520977ae8dc797c2c0575f168297917ca..2fd883a2fc714b0be61b4e47809a70698ef6e6a0 100644 --- a/model_json.go +++ b/model_json.go @@ -102,42 +102,46 @@ func (r *Response) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements json.Marshaler for StreamPart. +func (s StreamPart) MarshalJSON() ([]byte, error) { + type alias StreamPart + aux := struct { + alias + Error string `json:"error,omitempty"` + }{ + alias: (alias)(s), + } + + // Marshal error to string + if s.Error != nil { + aux.Error = s.Error.Error() + } + + // Clear the original Error field to avoid duplicate marshaling + aux.alias.Error = nil + + return json.Marshal(aux) +} + // UnmarshalJSON implements json.Unmarshaler for StreamPart. func (s *StreamPart) UnmarshalJSON(data []byte) error { - var aux struct { - Type StreamPartType `json:"type"` - ID string `json:"id"` - ToolCallName string `json:"tool_call_name"` - ToolCallInput string `json:"tool_call_input"` - Delta string `json:"delta"` - ProviderExecuted bool `json:"provider_executed"` - Usage Usage `json:"usage"` - FinishReason FinishReason `json:"finish_reason"` - Error error `json:"error"` - Warnings []CallWarning `json:"warnings"` - SourceType SourceType `json:"source_type"` - URL string `json:"url"` - Title string `json:"title"` + type alias StreamPart + aux := struct { + *alias + Error string `json:"error"` ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"` + }{ + alias: (*alias)(s), } if err := json.Unmarshal(data, &aux); err != nil { return err } - s.Type = aux.Type - s.ID = aux.ID - s.ToolCallName = aux.ToolCallName - s.ToolCallInput = aux.ToolCallInput - s.Delta = aux.Delta - s.ProviderExecuted = aux.ProviderExecuted - s.Usage = aux.Usage - s.FinishReason = aux.FinishReason - s.Error = aux.Error - s.Warnings = aux.Warnings - s.SourceType = aux.SourceType - s.URL = aux.URL - s.Title = aux.Title + // Unmarshal error string back to error type + if aux.Error != "" { + s.Error = fmt.Errorf("%s", aux.Error) + } // Unmarshal ProviderMetadata if len(aux.ProviderMetadata) > 0 {