fix: marshal/unmarshal ProviderError when Cause is openai.APIError (#78)

Carlos Alexandro Becker created

Change summary

content_json.go |  54 +++++++------------------
json_test.go    | 105 +++++++++++++++++++++++++++++++++++++++++++++++++++
model_json.go   |  58 +++++++++++++++-------------
3 files changed, 151 insertions(+), 66 deletions(-)

Detailed changes

content_json.go 🔗

@@ -393,11 +393,8 @@ func (t *ToolResultContent) UnmarshalJSON(data []byte) error {
 
 // MarshalJSON implements json.Marshaler for ToolResultOutputContentText.
 func (t ToolResultOutputContentText) MarshalJSON() ([]byte, error) {
-	dataBytes, err := json.Marshal(struct {
-		Text string `json:"text"`
-	}{
-		Text: t.Text,
-	})
+	type alias ToolResultOutputContentText
+	dataBytes, err := json.Marshal(alias(t))
 	if err != nil {
 		return nil, err
 	}
@@ -415,15 +412,14 @@ func (t *ToolResultOutputContentText) UnmarshalJSON(data []byte) error {
 		return err
 	}
 
-	var temp struct {
-		Text string `json:"text"`
-	}
+	type alias ToolResultOutputContentText
+	var temp alias
 
 	if err := json.Unmarshal(tr.Data, &temp); err != nil {
 		return err
 	}
 
-	t.Text = temp.Text
+	*t = ToolResultOutputContentText(temp)
 	return nil
 }
 
@@ -470,13 +466,8 @@ func (t *ToolResultOutputContentError) UnmarshalJSON(data []byte) error {
 
 // MarshalJSON implements json.Marshaler for ToolResultOutputContentMedia.
 func (t ToolResultOutputContentMedia) MarshalJSON() ([]byte, error) {
-	dataBytes, err := json.Marshal(struct {
-		Data      string `json:"data"`
-		MediaType string `json:"media_type"`
-	}{
-		Data:      t.Data,
-		MediaType: t.MediaType,
-	})
+	type alias ToolResultOutputContentMedia
+	dataBytes, err := json.Marshal(alias(t))
 	if err != nil {
 		return nil, err
 	}
@@ -494,17 +485,14 @@ func (t *ToolResultOutputContentMedia) UnmarshalJSON(data []byte) error {
 		return err
 	}
 
-	var temp struct {
-		Data      string `json:"data"`
-		MediaType string `json:"media_type"`
-	}
+	type alias ToolResultOutputContentMedia
+	var temp alias
 
 	if err := json.Unmarshal(tr.Data, &temp); err != nil {
 		return err
 	}
 
-	t.Data = temp.Data
-	t.MediaType = temp.MediaType
+	*t = ToolResultOutputContentMedia(temp)
 	return nil
 }
 
@@ -870,15 +858,8 @@ func (f *FunctionTool) UnmarshalJSON(data []byte) error {
 
 // MarshalJSON implements json.Marshaler for ProviderDefinedTool.
 func (p ProviderDefinedTool) MarshalJSON() ([]byte, error) {
-	dataBytes, err := json.Marshal(struct {
-		ID   string         `json:"id"`
-		Name string         `json:"name"`
-		Args map[string]any `json:"args"`
-	}{
-		ID:   p.ID,
-		Name: p.Name,
-		Args: p.Args,
-	})
+	type alias ProviderDefinedTool
+	dataBytes, err := json.Marshal(alias(p))
 	if err != nil {
 		return nil, err
 	}
@@ -896,19 +877,14 @@ func (p *ProviderDefinedTool) UnmarshalJSON(data []byte) error {
 		return err
 	}
 
-	var aux struct {
-		ID   string         `json:"id"`
-		Name string         `json:"name"`
-		Args map[string]any `json:"args"`
-	}
+	type alias ProviderDefinedTool
+	var aux alias
 
 	if err := json.Unmarshal(tj.Data, &aux); err != nil {
 		return err
 	}
 
-	p.ID = aux.ID
-	p.Name = aux.Name
-	p.Args = aux.Args
+	*p = ProviderDefinedTool(aux)
 
 	return nil
 }

json_test.go 🔗

@@ -645,3 +645,108 @@ func TestPromptSerialization(t *testing.T) {
 		}
 	})
 }
+
+func TestStreamPartErrorSerialization(t *testing.T) {
+	t.Run("stream part with ProviderError containing OpenAI API error", func(t *testing.T) {
+		// Create a mock OpenAI API error
+		openaiErr := errors.New("invalid_api_key: Incorrect API key provided")
+
+		// Wrap in ProviderError
+		providerErr := &ProviderError{
+			Title:       "unauthorized",
+			Message:     "Incorrect API key provided",
+			Cause:       openaiErr,
+			URL:         "https://api.openai.com/v1/chat/completions",
+			StatusCode:  401,
+			RequestBody: []byte(`{"model":"gpt-4","messages":[]}`),
+			ResponseHeaders: map[string]string{
+				"content-type": "application/json",
+			},
+			ResponseBody: []byte(`{"error":{"message":"Incorrect API key provided","type":"invalid_request_error"}}`),
+		}
+
+		// Create StreamPart with error
+		streamPart := StreamPart{
+			Type:  StreamPartTypeError,
+			Error: providerErr,
+		}
+
+		// Marshal the stream part
+		data, err := json.Marshal(streamPart)
+		if err != nil {
+			t.Fatalf("failed to marshal stream part: %v", err)
+		}
+
+		// Unmarshal back
+		var decoded StreamPart
+		err = json.Unmarshal(data, &decoded)
+		if err != nil {
+			t.Fatalf("failed to unmarshal stream part: %v", err)
+		}
+
+		// Verify the stream part type
+		if decoded.Type != StreamPartTypeError {
+			t.Errorf("type mismatch: got %v, want %v", decoded.Type, StreamPartTypeError)
+		}
+
+		// Verify error exists
+		if decoded.Error == nil {
+			t.Fatal("expected error to be present, got nil")
+		}
+
+		// Verify error message
+		expectedMsg := "unauthorized: Incorrect API key provided"
+		if decoded.Error.Error() != expectedMsg {
+			t.Errorf("error message mismatch: got %q, want %q", decoded.Error.Error(), expectedMsg)
+		}
+	})
+
+	t.Run("unmarshal stream part with error from JSON", func(t *testing.T) {
+		// JSON representing a StreamPart with an error
+		jsonData := `{
+			"type": "error",
+			"error": "unauthorized: Incorrect API key provided",
+			"id": "",
+			"tool_call_name": "",
+			"tool_call_input": "",
+			"delta": "",
+			"provider_executed": false,
+			"usage": {
+				"input_tokens": 0,
+				"output_tokens": 0,
+				"total_tokens": 0,
+				"reasoning_tokens": 0,
+				"cache_creation_tokens": 0,
+				"cache_read_tokens": 0
+			},
+			"finish_reason": "",
+			"warnings": null,
+			"source_type": "",
+			"url": "",
+			"title": "",
+			"provider_metadata": null
+		}`
+
+		var streamPart StreamPart
+		err := json.Unmarshal([]byte(jsonData), &streamPart)
+		if err != nil {
+			t.Fatalf("failed to unmarshal stream part: %v", err)
+		}
+
+		// Verify the stream part type
+		if streamPart.Type != StreamPartTypeError {
+			t.Errorf("type mismatch: got %v, want %v", streamPart.Type, StreamPartTypeError)
+		}
+
+		// Verify error exists
+		if streamPart.Error == nil {
+			t.Fatal("expected error to be present, got nil")
+		}
+
+		// Verify error message
+		expectedMsg := "unauthorized: Incorrect API key provided"
+		if streamPart.Error.Error() != expectedMsg {
+			t.Errorf("error message mismatch: got %q, want %q", streamPart.Error.Error(), expectedMsg)
+		}
+	})
+}

model_json.go 🔗

@@ -102,42 +102,46 @@ func (r *Response) UnmarshalJSON(data []byte) error {
 	return nil
 }
 
+// MarshalJSON implements json.Marshaler for StreamPart.
+func (s StreamPart) MarshalJSON() ([]byte, error) {
+	type alias StreamPart
+	aux := struct {
+		alias
+		Error string `json:"error,omitempty"`
+	}{
+		alias: (alias)(s),
+	}
+
+	// Marshal error to string
+	if s.Error != nil {
+		aux.Error = s.Error.Error()
+	}
+
+	// Clear the original Error field to avoid duplicate marshaling
+	aux.alias.Error = nil
+
+	return json.Marshal(aux)
+}
+
 // UnmarshalJSON implements json.Unmarshaler for StreamPart.
 func (s *StreamPart) UnmarshalJSON(data []byte) error {
-	var aux struct {
-		Type             StreamPartType             `json:"type"`
-		ID               string                     `json:"id"`
-		ToolCallName     string                     `json:"tool_call_name"`
-		ToolCallInput    string                     `json:"tool_call_input"`
-		Delta            string                     `json:"delta"`
-		ProviderExecuted bool                       `json:"provider_executed"`
-		Usage            Usage                      `json:"usage"`
-		FinishReason     FinishReason               `json:"finish_reason"`
-		Error            error                      `json:"error"`
-		Warnings         []CallWarning              `json:"warnings"`
-		SourceType       SourceType                 `json:"source_type"`
-		URL              string                     `json:"url"`
-		Title            string                     `json:"title"`
+	type alias StreamPart
+	aux := struct {
+		*alias
+		Error            string                     `json:"error"`
 		ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
+	}{
+		alias: (*alias)(s),
 	}
 
 	if err := json.Unmarshal(data, &aux); err != nil {
 		return err
 	}
 
-	s.Type = aux.Type
-	s.ID = aux.ID
-	s.ToolCallName = aux.ToolCallName
-	s.ToolCallInput = aux.ToolCallInput
-	s.Delta = aux.Delta
-	s.ProviderExecuted = aux.ProviderExecuted
-	s.Usage = aux.Usage
-	s.FinishReason = aux.FinishReason
-	s.Error = aux.Error
-	s.Warnings = aux.Warnings
-	s.SourceType = aux.SourceType
-	s.URL = aux.URL
-	s.Title = aux.Title
+	// Unmarshal error string back to error type
+	if aux.Error != "" {
+		s.Error = fmt.Errorf("%s", aux.Error)
+	}
 
 	// Unmarshal ProviderMetadata
 	if len(aux.ProviderMetadata) > 0 {