Detailed changes
@@ -1,8 +1,56 @@
package fantasy
+import "encoding/json"
+
// ProviderOptionsData is an interface for provider-specific options data.
+// All implementations MUST also implement encoding/json.Marshaler and
+// encoding/json.Unmarshaler interfaces to ensure proper JSON serialization
+// with the provider registry system.
+//
+// Required implementation pattern:
+//
+// type MyProviderOptions struct {
+// Field string `json:"field"`
+// }
+//
+// // Implement ProviderOptionsData
+// func (*MyProviderOptions) Options() {}
+//
+// // Implement json.Marshaler - use fantasy.MarshalProviderData
+// func (m MyProviderOptions) MarshalJSON() ([]byte, error) {
+// return fantasy.MarshalProviderData(&m, "provider.type")
+// }
+//
+// // Implement json.Unmarshaler - use fantasy.UnmarshalProviderData
+// func (m *MyProviderOptions) UnmarshalJSON(data []byte) error {
+// providerData, err := fantasy.UnmarshalProviderData(data)
+// if err != nil {
+// return err
+// }
+// opts, ok := providerData.(*MyProviderOptions)
+// if !ok {
+// return fmt.Errorf("invalid type")
+// }
+// *m = *opts
+// return nil
+// }
+//
+// Additionally, register the type in init():
+//
+// func init() {
+// fantasy.RegisterProviderType("provider.type", func(data []byte) (fantasy.ProviderOptionsData, error) {
+// var opts MyProviderOptions
+// if err := json.Unmarshal(data, &opts); err != nil {
+// return nil, err
+// }
+// return &opts, nil
+// })
+// }
type ProviderOptionsData interface {
+ // Options is a marker method that identifies types implementing this interface.
Options()
+ json.Marshaler
+ json.Unmarshaler
}
// ProviderMetadata represents additional provider-specific metadata.
@@ -0,0 +1,1022 @@
+package fantasy
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+)
+
+// contentJSON is a helper type for JSON serialization of Content in Response.
+type contentJSON struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+}
+
+// messagePartJSON is a helper type for JSON serialization of MessagePart.
+type messagePartJSON struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+}
+
+// toolResultOutputJSON is a helper type for JSON serialization of ToolResultOutputContent.
+type toolResultOutputJSON struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+}
+
+// toolJSON is a helper type for JSON serialization of Tool.
+type toolJSON struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+}
+
+func (t TextContent) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Text string `json:"text"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"`
+ }{
+ Text: t.Text,
+ ProviderMetadata: t.ProviderMetadata,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(contentJSON{
+ Type: string(ContentTypeText),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *TextContent) UnmarshalJSON(data []byte) error {
+ var cj contentJSON
+ if err := json.Unmarshal(data, &cj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ Text string `json:"text"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"`
+ }
+
+ if err := json.Unmarshal(cj.Data, &aux); err != nil {
+ return err
+ }
+
+ t.Text = aux.Text
+
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ t.ProviderMetadata = metadata
+ }
+
+ return nil
+}
+
+func (r ReasoningContent) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Text string `json:"text"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"`
+ }{
+ Text: r.Text,
+ ProviderMetadata: r.ProviderMetadata,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(contentJSON{
+ Type: string(ContentTypeReasoning),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (r *ReasoningContent) UnmarshalJSON(data []byte) error {
+ var cj contentJSON
+ if err := json.Unmarshal(data, &cj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ Text string `json:"text"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"`
+ }
+
+ if err := json.Unmarshal(cj.Data, &aux); err != nil {
+ return err
+ }
+
+ r.Text = aux.Text
+
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ r.ProviderMetadata = metadata
+ }
+
+ return nil
+}
+
+func (f FileContent) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ MediaType string `json:"media_type"`
+ Data []byte `json:"data"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"`
+ }{
+ MediaType: f.MediaType,
+ Data: f.Data,
+ ProviderMetadata: f.ProviderMetadata,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(contentJSON{
+ Type: string(ContentTypeFile),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (f *FileContent) UnmarshalJSON(data []byte) error {
+ var cj contentJSON
+ if err := json.Unmarshal(data, &cj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ MediaType string `json:"media_type"`
+ Data []byte `json:"data"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"`
+ }
+
+ if err := json.Unmarshal(cj.Data, &aux); err != nil {
+ return err
+ }
+
+ f.MediaType = aux.MediaType
+ f.Data = aux.Data
+
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ f.ProviderMetadata = metadata
+ }
+
+ return nil
+}
+
+func (s SourceContent) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ SourceType SourceType `json:"source_type"`
+ ID string `json:"id"`
+ URL string `json:"url,omitempty"`
+ Title string `json:"title,omitempty"`
+ MediaType string `json:"media_type,omitempty"`
+ Filename string `json:"filename,omitempty"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"`
+ }{
+ SourceType: s.SourceType,
+ ID: s.ID,
+ URL: s.URL,
+ Title: s.Title,
+ MediaType: s.MediaType,
+ Filename: s.Filename,
+ ProviderMetadata: s.ProviderMetadata,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(contentJSON{
+ Type: string(ContentTypeSource),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (s *SourceContent) UnmarshalJSON(data []byte) error {
+ var cj contentJSON
+ if err := json.Unmarshal(data, &cj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ SourceType SourceType `json:"source_type"`
+ ID string `json:"id"`
+ URL string `json:"url,omitempty"`
+ Title string `json:"title,omitempty"`
+ MediaType string `json:"media_type,omitempty"`
+ Filename string `json:"filename,omitempty"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"`
+ }
+
+ if err := json.Unmarshal(cj.Data, &aux); err != nil {
+ return err
+ }
+
+ s.SourceType = aux.SourceType
+ s.ID = aux.ID
+ s.URL = aux.URL
+ s.Title = aux.Title
+ s.MediaType = aux.MediaType
+ s.Filename = aux.Filename
+
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ s.ProviderMetadata = metadata
+ }
+
+ return nil
+}
+
+func (t ToolCallContent) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Input string `json:"input"`
+ ProviderExecuted bool `json:"provider_executed"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"`
+ Invalid bool `json:"invalid,omitempty"`
+ ValidationError error `json:"validation_error,omitempty"`
+ }{
+ ToolCallID: t.ToolCallID,
+ ToolName: t.ToolName,
+ Input: t.Input,
+ ProviderExecuted: t.ProviderExecuted,
+ ProviderMetadata: t.ProviderMetadata,
+ Invalid: t.Invalid,
+ ValidationError: t.ValidationError,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(contentJSON{
+ Type: string(ContentTypeToolCall),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *ToolCallContent) UnmarshalJSON(data []byte) error {
+ var cj contentJSON
+ if err := json.Unmarshal(data, &cj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Input string `json:"input"`
+ ProviderExecuted bool `json:"provider_executed"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"`
+ Invalid bool `json:"invalid,omitempty"`
+ ValidationError error `json:"validation_error,omitempty"`
+ }
+
+ if err := json.Unmarshal(cj.Data, &aux); err != nil {
+ return err
+ }
+
+ t.ToolCallID = aux.ToolCallID
+ t.ToolName = aux.ToolName
+ t.Input = aux.Input
+ t.ProviderExecuted = aux.ProviderExecuted
+ t.Invalid = aux.Invalid
+ t.ValidationError = aux.ValidationError
+
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ t.ProviderMetadata = metadata
+ }
+
+ return nil
+}
+
+func (t ToolResultContent) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Result ToolResultOutputContent `json:"result"`
+ ClientMetadata string `json:"client_metadata,omitempty"`
+ ProviderExecuted bool `json:"provider_executed"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"`
+ }{
+ ToolCallID: t.ToolCallID,
+ ToolName: t.ToolName,
+ Result: t.Result,
+ ClientMetadata: t.ClientMetadata,
+ ProviderExecuted: t.ProviderExecuted,
+ ProviderMetadata: t.ProviderMetadata,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(contentJSON{
+ Type: string(ContentTypeToolResult),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *ToolResultContent) UnmarshalJSON(data []byte) error {
+ var cj contentJSON
+ if err := json.Unmarshal(data, &cj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Result json.RawMessage `json:"result"`
+ ClientMetadata string `json:"client_metadata,omitempty"`
+ ProviderExecuted bool `json:"provider_executed"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"`
+ }
+
+ if err := json.Unmarshal(cj.Data, &aux); err != nil {
+ return err
+ }
+
+ t.ToolCallID = aux.ToolCallID
+ t.ToolName = aux.ToolName
+ t.ClientMetadata = aux.ClientMetadata
+ t.ProviderExecuted = aux.ProviderExecuted
+
+ // Unmarshal the Result field
+ result, err := UnmarshalToolResultOutputContent(aux.Result)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal tool result output: %w", err)
+ }
+ t.Result = result
+
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ t.ProviderMetadata = metadata
+ }
+
+ return nil
+}
+
+func (t ToolResultOutputContentText) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Text string `json:"text"`
+ }{
+ Text: t.Text,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(toolResultOutputJSON{
+ Type: string(ToolResultContentTypeText),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *ToolResultOutputContentText) UnmarshalJSON(data []byte) error {
+ var tr toolResultOutputJSON
+ if err := json.Unmarshal(data, &tr); err != nil {
+ return err
+ }
+
+ var temp struct {
+ Text string `json:"text"`
+ }
+
+ if err := json.Unmarshal(tr.Data, &temp); err != nil {
+ return err
+ }
+
+ t.Text = temp.Text
+ return nil
+}
+
+func (t ToolResultOutputContentError) MarshalJSON() ([]byte, error) {
+ errMsg := ""
+ if t.Error != nil {
+ errMsg = t.Error.Error()
+ }
+ dataBytes, err := json.Marshal(struct {
+ Error string `json:"error"`
+ }{
+ Error: errMsg,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(toolResultOutputJSON{
+ Type: string(ToolResultContentTypeError),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *ToolResultOutputContentError) UnmarshalJSON(data []byte) error {
+ var tr toolResultOutputJSON
+ if err := json.Unmarshal(data, &tr); err != nil {
+ return err
+ }
+
+ var temp struct {
+ Error string `json:"error"`
+ }
+
+ if err := json.Unmarshal(tr.Data, &temp); err != nil {
+ return err
+ }
+ if temp.Error != "" {
+ t.Error = errors.New(temp.Error)
+ }
+ return nil
+}
+
+func (t ToolResultOutputContentMedia) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Data string `json:"data"`
+ MediaType string `json:"media_type"`
+ }{
+ Data: t.Data,
+ MediaType: t.MediaType,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(toolResultOutputJSON{
+ Type: string(ToolResultContentTypeMedia),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *ToolResultOutputContentMedia) UnmarshalJSON(data []byte) error {
+ var tr toolResultOutputJSON
+ if err := json.Unmarshal(data, &tr); err != nil {
+ return err
+ }
+
+ var temp struct {
+ Data string `json:"data"`
+ MediaType string `json:"media_type"`
+ }
+
+ if err := json.Unmarshal(tr.Data, &temp); err != nil {
+ return err
+ }
+
+ t.Data = temp.Data
+ t.MediaType = temp.MediaType
+ return nil
+}
+
+func (t TextPart) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Text string `json:"text"`
+ ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
+ }{
+ Text: t.Text,
+ ProviderOptions: t.ProviderOptions,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(messagePartJSON{
+ Type: string(ContentTypeText),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *TextPart) UnmarshalJSON(data []byte) error {
+ var mpj messagePartJSON
+ if err := json.Unmarshal(data, &mpj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ Text string `json:"text"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
+ }
+
+ if err := json.Unmarshal(mpj.Data, &aux); err != nil {
+ return err
+ }
+
+ t.Text = aux.Text
+
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ t.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (r ReasoningPart) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Text string `json:"text"`
+ ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
+ }{
+ Text: r.Text,
+ ProviderOptions: r.ProviderOptions,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(messagePartJSON{
+ Type: string(ContentTypeReasoning),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (r *ReasoningPart) UnmarshalJSON(data []byte) error {
+ var mpj messagePartJSON
+ if err := json.Unmarshal(data, &mpj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ Text string `json:"text"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
+ }
+
+ if err := json.Unmarshal(mpj.Data, &aux); err != nil {
+ return err
+ }
+
+ r.Text = aux.Text
+
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ r.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (f FilePart) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Filename string `json:"filename"`
+ Data []byte `json:"data"`
+ MediaType string `json:"media_type"`
+ ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
+ }{
+ Filename: f.Filename,
+ Data: f.Data,
+ MediaType: f.MediaType,
+ ProviderOptions: f.ProviderOptions,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(messagePartJSON{
+ Type: string(ContentTypeFile),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (f *FilePart) UnmarshalJSON(data []byte) error {
+ var mpj messagePartJSON
+ if err := json.Unmarshal(data, &mpj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ Filename string `json:"filename"`
+ Data []byte `json:"data"`
+ MediaType string `json:"media_type"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
+ }
+
+ if err := json.Unmarshal(mpj.Data, &aux); err != nil {
+ return err
+ }
+
+ f.Filename = aux.Filename
+ f.Data = aux.Data
+ f.MediaType = aux.MediaType
+
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ f.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (t ToolCallPart) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Input string `json:"input"`
+ ProviderExecuted bool `json:"provider_executed"`
+ ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
+ }{
+ ToolCallID: t.ToolCallID,
+ ToolName: t.ToolName,
+ Input: t.Input,
+ ProviderExecuted: t.ProviderExecuted,
+ ProviderOptions: t.ProviderOptions,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(messagePartJSON{
+ Type: string(ContentTypeToolCall),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *ToolCallPart) UnmarshalJSON(data []byte) error {
+ var mpj messagePartJSON
+ if err := json.Unmarshal(data, &mpj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Input string `json:"input"`
+ ProviderExecuted bool `json:"provider_executed"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
+ }
+
+ if err := json.Unmarshal(mpj.Data, &aux); err != nil {
+ return err
+ }
+
+ t.ToolCallID = aux.ToolCallID
+ t.ToolName = aux.ToolName
+ t.Input = aux.Input
+ t.ProviderExecuted = aux.ProviderExecuted
+
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ t.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (t ToolResultPart) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ ToolCallID string `json:"tool_call_id"`
+ Output ToolResultOutputContent `json:"output"`
+ ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
+ }{
+ ToolCallID: t.ToolCallID,
+ Output: t.Output,
+ ProviderOptions: t.ProviderOptions,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(messagePartJSON{
+ Type: string(ContentTypeToolResult),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (t *ToolResultPart) UnmarshalJSON(data []byte) error {
+ var mpj messagePartJSON
+ if err := json.Unmarshal(data, &mpj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ ToolCallID string `json:"tool_call_id"`
+ Output json.RawMessage `json:"output"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
+ }
+
+ if err := json.Unmarshal(mpj.Data, &aux); err != nil {
+ return err
+ }
+
+ t.ToolCallID = aux.ToolCallID
+
+ // Unmarshal the Output field
+ output, err := UnmarshalToolResultOutputContent(aux.Output)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal tool result output: %w", err)
+ }
+ t.Output = output
+
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ t.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (m *Message) UnmarshalJSON(data []byte) error {
+ var aux struct {
+ Role MessageRole `json:"role"`
+ Content []json.RawMessage `json:"content"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options"`
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ m.Role = aux.Role
+
+ m.Content = make([]MessagePart, len(aux.Content))
+ for i, rawPart := range aux.Content {
+ part, err := UnmarshalMessagePart(rawPart)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal message part at index %d: %w", i, err)
+ }
+ m.Content[i] = part
+ }
+
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ m.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (f FunctionTool) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema map[string]any `json:"input_schema"`
+ ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
+ }{
+ Name: f.Name,
+ Description: f.Description,
+ InputSchema: f.InputSchema,
+ ProviderOptions: f.ProviderOptions,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(toolJSON{
+ Type: string(ToolTypeFunction),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (f *FunctionTool) UnmarshalJSON(data []byte) error {
+ var tj toolJSON
+ if err := json.Unmarshal(data, &tj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema map[string]any `json:"input_schema"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
+ }
+
+ if err := json.Unmarshal(tj.Data, &aux); err != nil {
+ return err
+ }
+
+ f.Name = aux.Name
+ f.Description = aux.Description
+ f.InputSchema = aux.InputSchema
+
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ f.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (p ProviderDefinedTool) MarshalJSON() ([]byte, error) {
+ dataBytes, err := json.Marshal(struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Args map[string]any `json:"args"`
+ }{
+ ID: p.ID,
+ Name: p.Name,
+ Args: p.Args,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(toolJSON{
+ Type: string(ToolTypeProviderDefined),
+ Data: json.RawMessage(dataBytes),
+ })
+}
+
+func (p *ProviderDefinedTool) UnmarshalJSON(data []byte) error {
+ var tj toolJSON
+ if err := json.Unmarshal(data, &tj); err != nil {
+ return err
+ }
+
+ var aux struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Args map[string]any `json:"args"`
+ }
+
+ if err := json.Unmarshal(tj.Data, &aux); err != nil {
+ return err
+ }
+
+ p.ID = aux.ID
+ p.Name = aux.Name
+ p.Args = aux.Args
+
+ return nil
+}
+
+// UnmarshalTool unmarshals JSON into the appropriate Tool type
+func UnmarshalTool(data []byte) (Tool, error) {
+ var tj toolJSON
+ if err := json.Unmarshal(data, &tj); err != nil {
+ return nil, err
+ }
+
+ switch ToolType(tj.Type) {
+ case ToolTypeFunction:
+ var tool FunctionTool
+ if err := tool.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return tool, nil
+ case ToolTypeProviderDefined:
+ var tool ProviderDefinedTool
+ if err := tool.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return tool, nil
+ default:
+ return nil, fmt.Errorf("unknown tool type: %s", tj.Type)
+ }
+}
+
+// UnmarshalContent unmarshals JSON into the appropriate Content type
+func UnmarshalContent(data []byte) (Content, error) {
+ var cj contentJSON
+ if err := json.Unmarshal(data, &cj); err != nil {
+ return nil, err
+ }
+
+ switch ContentType(cj.Type) {
+ case ContentTypeText:
+ var content TextContent
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ case ContentTypeReasoning:
+ var content ReasoningContent
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ case ContentTypeFile:
+ var content FileContent
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ case ContentTypeSource:
+ var content SourceContent
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ case ContentTypeToolCall:
+ var content ToolCallContent
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ case ContentTypeToolResult:
+ var content ToolResultContent
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ default:
+ return nil, fmt.Errorf("unknown content type: %s", cj.Type)
+ }
+}
+
+// UnmarshalMessagePart unmarshals JSON into the appropriate MessagePart type
+func UnmarshalMessagePart(data []byte) (MessagePart, error) {
+ var mpj messagePartJSON
+ if err := json.Unmarshal(data, &mpj); err != nil {
+ return nil, err
+ }
+
+ switch ContentType(mpj.Type) {
+ case ContentTypeText:
+ var part TextPart
+ if err := part.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return part, nil
+ case ContentTypeReasoning:
+ var part ReasoningPart
+ if err := part.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return part, nil
+ case ContentTypeFile:
+ var part FilePart
+ if err := part.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return part, nil
+ case ContentTypeToolCall:
+ var part ToolCallPart
+ if err := part.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return part, nil
+ case ContentTypeToolResult:
+ var part ToolResultPart
+ if err := part.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return part, nil
+ default:
+ return nil, fmt.Errorf("unknown message part type: %s", mpj.Type)
+ }
+}
+
+// UnmarshalToolResultOutputContent unmarshals JSON into the appropriate ToolResultOutputContent type
+func UnmarshalToolResultOutputContent(data []byte) (ToolResultOutputContent, error) {
+ var troj toolResultOutputJSON
+ if err := json.Unmarshal(data, &troj); err != nil {
+ return nil, err
+ }
+
+ switch ToolResultContentType(troj.Type) {
+ case ToolResultContentTypeText:
+ var content ToolResultOutputContentText
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ case ToolResultContentTypeError:
+ var content ToolResultOutputContentError
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ case ToolResultContentTypeMedia:
+ var content ToolResultOutputContentMedia
+ if err := content.UnmarshalJSON(data); err != nil {
+ return nil, err
+ }
+ return content, nil
+ default:
+ return nil, fmt.Errorf("unknown tool result output content type: %s", troj.Type)
+ }
+}
@@ -0,0 +1,647 @@
+package fantasy
+
+import (
+ "encoding/json"
+ "errors"
+ "reflect"
+ "testing"
+)
+
+func TestMessageJSONSerialization(t *testing.T) {
+ tests := []struct {
+ name string
+ message Message
+ }{
+ {
+ name: "simple text message",
+ message: Message{
+ Role: MessageRoleUser,
+ Content: []MessagePart{
+ TextPart{Text: "Hello, world!"},
+ },
+ },
+ },
+ {
+ name: "message with multiple text parts",
+ message: Message{
+ Role: MessageRoleAssistant,
+ Content: []MessagePart{
+ TextPart{Text: "First part"},
+ TextPart{Text: "Second part"},
+ TextPart{Text: "Third part"},
+ },
+ },
+ },
+ {
+ name: "message with reasoning part",
+ message: Message{
+ Role: MessageRoleAssistant,
+ Content: []MessagePart{
+ ReasoningPart{Text: "Let me think about this..."},
+ TextPart{Text: "Here's my answer"},
+ },
+ },
+ },
+ {
+ name: "message with file part",
+ message: Message{
+ Role: MessageRoleUser,
+ Content: []MessagePart{
+ TextPart{Text: "Here's an image:"},
+ FilePart{
+ Filename: "test.png",
+ Data: []byte{0x89, 0x50, 0x4E, 0x47}, // PNG header
+ MediaType: "image/png",
+ },
+ },
+ },
+ },
+ {
+ name: "message with tool call",
+ message: Message{
+ Role: MessageRoleAssistant,
+ Content: []MessagePart{
+ ToolCallPart{
+ ToolCallID: "call_123",
+ ToolName: "get_weather",
+ Input: `{"location": "San Francisco"}`,
+ ProviderExecuted: false,
+ },
+ },
+ },
+ },
+ {
+ name: "message with tool result - text output",
+ message: Message{
+ Role: MessageRoleTool,
+ Content: []MessagePart{
+ ToolResultPart{
+ ToolCallID: "call_123",
+ Output: ToolResultOutputContentText{
+ Text: "The weather is sunny, 72Β°F",
+ },
+ },
+ },
+ },
+ },
+ {
+ name: "message with tool result - error output",
+ message: Message{
+ Role: MessageRoleTool,
+ Content: []MessagePart{
+ ToolResultPart{
+ ToolCallID: "call_456",
+ Output: ToolResultOutputContentError{
+ Error: errors.New("API rate limit exceeded"),
+ },
+ },
+ },
+ },
+ },
+ {
+ name: "message with tool result - media output",
+ message: Message{
+ Role: MessageRoleTool,
+ Content: []MessagePart{
+ ToolResultPart{
+ ToolCallID: "call_789",
+ Output: ToolResultOutputContentMedia{
+ Data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
+ MediaType: "image/png",
+ },
+ },
+ },
+ },
+ },
+ {
+ name: "complex message with mixed content",
+ message: Message{
+ Role: MessageRoleAssistant,
+ Content: []MessagePart{
+ TextPart{Text: "I'll analyze this image and call some tools."},
+ ReasoningPart{Text: "First, I need to identify the objects..."},
+ ToolCallPart{
+ ToolCallID: "call_001",
+ ToolName: "analyze_image",
+ Input: `{"image_id": "img_123"}`,
+ ProviderExecuted: false,
+ },
+ ToolCallPart{
+ ToolCallID: "call_002",
+ ToolName: "get_context",
+ Input: `{"query": "similar images"}`,
+ ProviderExecuted: true,
+ },
+ },
+ },
+ },
+ {
+ name: "system message",
+ message: Message{
+ Role: MessageRoleSystem,
+ Content: []MessagePart{
+ TextPart{Text: "You are a helpful assistant."},
+ },
+ },
+ },
+ {
+ name: "empty content",
+ message: Message{
+ Role: MessageRoleUser,
+ Content: []MessagePart{},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Marshal the message
+ data, err := json.Marshal(tt.message)
+ if err != nil {
+ t.Fatalf("failed to marshal message: %v", err)
+ }
+
+ // Unmarshal back
+ var decoded Message
+ err = json.Unmarshal(data, &decoded)
+ if err != nil {
+ t.Fatalf("failed to unmarshal message: %v", err)
+ }
+
+ // Compare roles
+ if decoded.Role != tt.message.Role {
+ t.Errorf("role mismatch: got %v, want %v", decoded.Role, tt.message.Role)
+ }
+
+ // Compare content length
+ if len(decoded.Content) != len(tt.message.Content) {
+ t.Fatalf("content length mismatch: got %d, want %d", len(decoded.Content), len(tt.message.Content))
+ }
+
+ // Compare each content part
+ for i := range tt.message.Content {
+ original := tt.message.Content[i]
+ decodedPart := decoded.Content[i]
+
+ if original.GetType() != decodedPart.GetType() {
+ t.Errorf("content[%d] type mismatch: got %v, want %v", i, decodedPart.GetType(), original.GetType())
+ continue
+ }
+
+ compareMessagePart(t, i, original, decodedPart)
+ }
+ })
+ }
+}
+
+func compareMessagePart(t *testing.T, index int, original, decoded MessagePart) {
+ switch original.GetType() {
+ case ContentTypeText:
+ orig := original.(TextPart)
+ dec := decoded.(TextPart)
+ if orig.Text != dec.Text {
+ t.Errorf("content[%d] text mismatch: got %q, want %q", index, dec.Text, orig.Text)
+ }
+
+ case ContentTypeReasoning:
+ orig := original.(ReasoningPart)
+ dec := decoded.(ReasoningPart)
+ if orig.Text != dec.Text {
+ t.Errorf("content[%d] reasoning text mismatch: got %q, want %q", index, dec.Text, orig.Text)
+ }
+
+ case ContentTypeFile:
+ orig := original.(FilePart)
+ dec := decoded.(FilePart)
+ if orig.Filename != dec.Filename {
+ t.Errorf("content[%d] filename mismatch: got %q, want %q", index, dec.Filename, orig.Filename)
+ }
+ if orig.MediaType != dec.MediaType {
+ t.Errorf("content[%d] media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType)
+ }
+ if !reflect.DeepEqual(orig.Data, dec.Data) {
+ t.Errorf("content[%d] file data mismatch", index)
+ }
+
+ case ContentTypeToolCall:
+ orig := original.(ToolCallPart)
+ dec := decoded.(ToolCallPart)
+ if orig.ToolCallID != dec.ToolCallID {
+ t.Errorf("content[%d] tool call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID)
+ }
+ if orig.ToolName != dec.ToolName {
+ t.Errorf("content[%d] tool name mismatch: got %q, want %q", index, dec.ToolName, orig.ToolName)
+ }
+ if orig.Input != dec.Input {
+ t.Errorf("content[%d] tool input mismatch: got %q, want %q", index, dec.Input, orig.Input)
+ }
+ if orig.ProviderExecuted != dec.ProviderExecuted {
+ t.Errorf("content[%d] provider executed mismatch: got %v, want %v", index, dec.ProviderExecuted, orig.ProviderExecuted)
+ }
+
+ case ContentTypeToolResult:
+ orig := original.(ToolResultPart)
+ dec := decoded.(ToolResultPart)
+ if orig.ToolCallID != dec.ToolCallID {
+ t.Errorf("content[%d] tool result call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID)
+ }
+ compareToolResultOutput(t, index, orig.Output, dec.Output)
+ }
+}
+
+func compareToolResultOutput(t *testing.T, index int, original, decoded ToolResultOutputContent) {
+ if original.GetType() != decoded.GetType() {
+ t.Errorf("content[%d] tool result output type mismatch: got %v, want %v", index, decoded.GetType(), original.GetType())
+ return
+ }
+
+ switch original.GetType() {
+ case ToolResultContentTypeText:
+ orig := original.(ToolResultOutputContentText)
+ dec := decoded.(ToolResultOutputContentText)
+ if orig.Text != dec.Text {
+ t.Errorf("content[%d] tool result text mismatch: got %q, want %q", index, dec.Text, orig.Text)
+ }
+
+ case ToolResultContentTypeError:
+ orig := original.(ToolResultOutputContentError)
+ dec := decoded.(ToolResultOutputContentError)
+ if orig.Error.Error() != dec.Error.Error() {
+ t.Errorf("content[%d] tool result error mismatch: got %q, want %q", index, dec.Error.Error(), orig.Error.Error())
+ }
+
+ case ToolResultContentTypeMedia:
+ orig := original.(ToolResultOutputContentMedia)
+ dec := decoded.(ToolResultOutputContentMedia)
+ if orig.Data != dec.Data {
+ t.Errorf("content[%d] tool result media data mismatch", index)
+ }
+ if orig.MediaType != dec.MediaType {
+ t.Errorf("content[%d] tool result media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType)
+ }
+ }
+}
+
+func TestHelperFunctions(t *testing.T) {
+ t.Run("NewUserMessage - text only", func(t *testing.T) {
+ msg := NewUserMessage("Hello")
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ if decoded.Role != MessageRoleUser {
+ t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleUser)
+ }
+
+ if len(decoded.Content) != 1 {
+ t.Fatalf("expected 1 content part, got %d", len(decoded.Content))
+ }
+
+ textPart := decoded.Content[0].(TextPart)
+ if textPart.Text != "Hello" {
+ t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Hello")
+ }
+ })
+
+ t.Run("NewUserMessage - with files", func(t *testing.T) {
+ msg := NewUserMessage("Check this image",
+ FilePart{
+ Filename: "image1.jpg",
+ Data: []byte{0xFF, 0xD8, 0xFF},
+ MediaType: "image/jpeg",
+ },
+ FilePart{
+ Filename: "image2.png",
+ Data: []byte{0x89, 0x50, 0x4E, 0x47},
+ MediaType: "image/png",
+ },
+ )
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ if len(decoded.Content) != 3 {
+ t.Fatalf("expected 3 content parts, got %d", len(decoded.Content))
+ }
+
+ // Check text part
+ textPart := decoded.Content[0].(TextPart)
+ if textPart.Text != "Check this image" {
+ t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Check this image")
+ }
+
+ // Check first file
+ file1 := decoded.Content[1].(FilePart)
+ if file1.Filename != "image1.jpg" {
+ t.Errorf("file1 name mismatch: got %q, want %q", file1.Filename, "image1.jpg")
+ }
+
+ // Check second file
+ file2 := decoded.Content[2].(FilePart)
+ if file2.Filename != "image2.png" {
+ t.Errorf("file2 name mismatch: got %q, want %q", file2.Filename, "image2.png")
+ }
+ })
+
+ t.Run("NewSystemMessage - single prompt", func(t *testing.T) {
+ msg := NewSystemMessage("You are a helpful assistant.")
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ if decoded.Role != MessageRoleSystem {
+ t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleSystem)
+ }
+
+ if len(decoded.Content) != 1 {
+ t.Fatalf("expected 1 content part, got %d", len(decoded.Content))
+ }
+
+ textPart := decoded.Content[0].(TextPart)
+ if textPart.Text != "You are a helpful assistant." {
+ t.Errorf("text mismatch: got %q, want %q", textPart.Text, "You are a helpful assistant.")
+ }
+ })
+
+ t.Run("NewSystemMessage - multiple prompts", func(t *testing.T) {
+ msg := NewSystemMessage("First instruction", "Second instruction", "Third instruction")
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ if len(decoded.Content) != 3 {
+ t.Fatalf("expected 3 content parts, got %d", len(decoded.Content))
+ }
+
+ expected := []string{"First instruction", "Second instruction", "Third instruction"}
+ for i, exp := range expected {
+ textPart := decoded.Content[i].(TextPart)
+ if textPart.Text != exp {
+ t.Errorf("content[%d] text mismatch: got %q, want %q", i, textPart.Text, exp)
+ }
+ }
+ })
+}
+
+func TestEdgeCases(t *testing.T) {
+ t.Run("empty text part", func(t *testing.T) {
+ msg := Message{
+ Role: MessageRoleUser,
+ Content: []MessagePart{
+ TextPart{Text: ""},
+ },
+ }
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ textPart := decoded.Content[0].(TextPart)
+ if textPart.Text != "" {
+ t.Errorf("expected empty text, got %q", textPart.Text)
+ }
+ })
+
+ t.Run("nil error in tool result", func(t *testing.T) {
+ msg := Message{
+ Role: MessageRoleTool,
+ Content: []MessagePart{
+ ToolResultPart{
+ ToolCallID: "call_123",
+ Output: ToolResultOutputContentError{
+ Error: nil,
+ },
+ },
+ },
+ }
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ toolResult := decoded.Content[0].(ToolResultPart)
+ errorOutput := toolResult.Output.(ToolResultOutputContentError)
+ if errorOutput.Error != nil {
+ t.Errorf("expected nil error, got %v", errorOutput.Error)
+ }
+ })
+
+ t.Run("empty file data", func(t *testing.T) {
+ msg := Message{
+ Role: MessageRoleUser,
+ Content: []MessagePart{
+ FilePart{
+ Filename: "empty.txt",
+ Data: []byte{},
+ MediaType: "text/plain",
+ },
+ },
+ }
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ filePart := decoded.Content[0].(FilePart)
+ if len(filePart.Data) != 0 {
+ t.Errorf("expected empty data, got %d bytes", len(filePart.Data))
+ }
+ })
+
+ t.Run("unicode in text", func(t *testing.T) {
+ msg := Message{
+ Role: MessageRoleUser,
+ Content: []MessagePart{
+ TextPart{Text: "Hello δΈη! π ΠΡΠΈΠ²Π΅Ρ"},
+ },
+ }
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ var decoded Message
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ textPart := decoded.Content[0].(TextPart)
+ if textPart.Text != "Hello δΈη! π ΠΡΠΈΠ²Π΅Ρ" {
+ t.Errorf("unicode text mismatch: got %q, want %q", textPart.Text, "Hello δΈη! π ΠΡΠΈΠ²Π΅Ρ")
+ }
+ })
+}
+
+func TestInvalidJSONHandling(t *testing.T) {
+ t.Run("unknown message part type", func(t *testing.T) {
+ invalidJSON := `{
+ "role": "user",
+ "content": [
+ {
+ "type": "unknown-type",
+ "data": {}
+ }
+ ],
+ "provider_options": null
+ }`
+
+ var msg Message
+ err := json.Unmarshal([]byte(invalidJSON), &msg)
+ if err == nil {
+ t.Error("expected error for unknown message part type, got nil")
+ }
+ })
+
+ t.Run("unknown tool result output type", func(t *testing.T) {
+ invalidJSON := `{
+ "role": "tool",
+ "content": [
+ {
+ "type": "tool-result",
+ "data": {
+ "tool_call_id": "call_123",
+ "output": {
+ "type": "unknown-output-type",
+ "data": {}
+ },
+ "provider_options": null
+ }
+ }
+ ],
+ "provider_options": null
+ }`
+
+ var msg Message
+ err := json.Unmarshal([]byte(invalidJSON), &msg)
+ if err == nil {
+ t.Error("expected error for unknown tool result output type, got nil")
+ }
+ })
+
+ t.Run("malformed JSON", func(t *testing.T) {
+ invalidJSON := `{"role": "user", "content": [`
+
+ var msg Message
+ err := json.Unmarshal([]byte(invalidJSON), &msg)
+ if err == nil {
+ t.Error("expected error for malformed JSON, got nil")
+ }
+ })
+}
+
+// Mock provider data for testing provider options
+type mockProviderData struct {
+ Key string `json:"key"`
+}
+
+func (m mockProviderData) Options() {}
+func (m mockProviderData) Type() string { return "mock" }
+func (m mockProviderData) MarshalJSON() ([]byte, error) {
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ mockProviderData
+ }{
+ Type: "mock",
+ mockProviderData: m,
+ })
+}
+
+func (m *mockProviderData) UnmarshalJSON(data []byte) error {
+ var aux struct {
+ Type string `json:"type"`
+ mockProviderData
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ *m = aux.mockProviderData
+ return nil
+}
+
+func TestPromptSerialization(t *testing.T) {
+ t.Run("serialize prompt (message slice)", func(t *testing.T) {
+ prompt := Prompt{
+ NewSystemMessage("You are helpful"),
+ NewUserMessage("Hello"),
+ Message{
+ Role: MessageRoleAssistant,
+ Content: []MessagePart{
+ TextPart{Text: "Hi there!"},
+ },
+ },
+ }
+
+ data, err := json.Marshal(prompt)
+ if err != nil {
+ t.Fatalf("failed to marshal prompt: %v", err)
+ }
+
+ var decoded Prompt
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("failed to unmarshal prompt: %v", err)
+ }
+
+ if len(decoded) != 3 {
+ t.Fatalf("expected 3 messages, got %d", len(decoded))
+ }
+
+ if decoded[0].Role != MessageRoleSystem {
+ t.Errorf("message 0 role mismatch: got %v, want %v", decoded[0].Role, MessageRoleSystem)
+ }
+
+ if decoded[1].Role != MessageRoleUser {
+ t.Errorf("message 1 role mismatch: got %v, want %v", decoded[1].Role, MessageRoleUser)
+ }
+
+ if decoded[2].Role != MessageRoleAssistant {
+ t.Errorf("message 2 role mismatch: got %v, want %v", decoded[2].Role, MessageRoleAssistant)
+ }
+ })
+}
@@ -0,0 +1,149 @@
+package fantasy
+
+import (
+ "encoding/json"
+ "fmt"
+)
+
+func (c *Call) UnmarshalJSON(data []byte) error {
+ var aux struct {
+ Prompt Prompt `json:"prompt"`
+ MaxOutputTokens *int64 `json:"max_output_tokens"`
+ Temperature *float64 `json:"temperature"`
+ TopP *float64 `json:"top_p"`
+ TopK *int64 `json:"top_k"`
+ PresencePenalty *float64 `json:"presence_penalty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty"`
+ Tools []json.RawMessage `json:"tools"`
+ ToolChoice *ToolChoice `json:"tool_choice"`
+ ProviderOptions map[string]json.RawMessage `json:"provider_options"`
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ c.Prompt = aux.Prompt
+ c.MaxOutputTokens = aux.MaxOutputTokens
+ c.Temperature = aux.Temperature
+ c.TopP = aux.TopP
+ c.TopK = aux.TopK
+ c.PresencePenalty = aux.PresencePenalty
+ c.FrequencyPenalty = aux.FrequencyPenalty
+ c.ToolChoice = aux.ToolChoice
+
+ // Unmarshal Tools slice
+ c.Tools = make([]Tool, len(aux.Tools))
+ for i, rawTool := range aux.Tools {
+ tool, err := UnmarshalTool(rawTool)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal tool at index %d: %w", i, err)
+ }
+ c.Tools[i] = tool
+ }
+
+ // Unmarshal ProviderOptions
+ if len(aux.ProviderOptions) > 0 {
+ options, err := UnmarshalProviderOptions(aux.ProviderOptions)
+ if err != nil {
+ return err
+ }
+ c.ProviderOptions = options
+ }
+
+ return nil
+}
+
+func (r *Response) UnmarshalJSON(data []byte) error {
+ var aux struct {
+ Content json.RawMessage `json:"content"`
+ FinishReason FinishReason `json:"finish_reason"`
+ Usage Usage `json:"usage"`
+ Warnings []CallWarning `json:"warnings"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ r.FinishReason = aux.FinishReason
+ r.Usage = aux.Usage
+ r.Warnings = aux.Warnings
+
+ // Unmarshal ResponseContent (need to know the type definition)
+ // If ResponseContent is []Content:
+ var rawContent []json.RawMessage
+ if err := json.Unmarshal(aux.Content, &rawContent); err != nil {
+ return err
+ }
+
+ content := make([]Content, len(rawContent))
+ for i, rawItem := range rawContent {
+ item, err := UnmarshalContent(rawItem)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal content at index %d: %w", i, err)
+ }
+ content[i] = item
+ }
+ r.Content = content
+
+ // Unmarshal ProviderMetadata
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ r.ProviderMetadata = metadata
+ }
+
+ return nil
+}
+
+func (s *StreamPart) UnmarshalJSON(data []byte) error {
+ var aux struct {
+ Type StreamPartType `json:"type"`
+ ID string `json:"id"`
+ ToolCallName string `json:"tool_call_name"`
+ ToolCallInput string `json:"tool_call_input"`
+ Delta string `json:"delta"`
+ ProviderExecuted bool `json:"provider_executed"`
+ Usage Usage `json:"usage"`
+ FinishReason FinishReason `json:"finish_reason"`
+ Error error `json:"error"`
+ Warnings []CallWarning `json:"warnings"`
+ SourceType SourceType `json:"source_type"`
+ URL string `json:"url"`
+ Title string `json:"title"`
+ ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ s.Type = aux.Type
+ s.ID = aux.ID
+ s.ToolCallName = aux.ToolCallName
+ s.ToolCallInput = aux.ToolCallInput
+ s.Delta = aux.Delta
+ s.ProviderExecuted = aux.ProviderExecuted
+ s.Usage = aux.Usage
+ s.FinishReason = aux.FinishReason
+ s.Error = aux.Error
+ s.Warnings = aux.Warnings
+ s.SourceType = aux.SourceType
+ s.URL = aux.URL
+ s.Title = aux.Title
+
+ // Unmarshal ProviderMetadata
+ if len(aux.ProviderMetadata) > 0 {
+ metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
+ if err != nil {
+ return err
+ }
+ s.ProviderMetadata = metadata
+ }
+
+ return nil
+}
@@ -0,0 +1,70 @@
+package fantasy
+
+import (
+ "encoding/json"
+ "fmt"
+ "sync"
+)
+
+// providerDataJSON is the serialized wrapper used by the registry.
+type providerDataJSON struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+}
+
+// UnmarshalFunc converts raw JSON into a ProviderOptionsData implementation.
+type UnmarshalFunc func([]byte) (ProviderOptionsData, error)
+
+var (
+ providerRegistry = make(map[string]UnmarshalFunc)
+ registryMutex sync.RWMutex
+)
+
+// RegisterProviderType registers a provider type ID with its unmarshal function.
+// Type IDs must be globally unique (e.g. "openai.options").
+func RegisterProviderType(typeID string, unmarshalFn UnmarshalFunc) {
+ registryMutex.Lock()
+ defer registryMutex.Unlock()
+ providerRegistry[typeID] = unmarshalFn
+}
+
+// unmarshalProviderData routes a typed payload to the correct constructor.
+func unmarshalProviderData(data []byte) (ProviderOptionsData, error) {
+ var pj providerDataJSON
+ if err := json.Unmarshal(data, &pj); err != nil {
+ return nil, err
+ }
+
+ registryMutex.RLock()
+ unmarshalFn, exists := providerRegistry[pj.Type]
+ registryMutex.RUnlock()
+
+ if !exists {
+ return nil, fmt.Errorf("unknown provider data type: %s", pj.Type)
+ }
+
+ return unmarshalFn(pj.Data)
+}
+
+// unmarshalProviderDataMap is a helper for unmarshaling maps of provider data.
+func unmarshalProviderDataMap(data map[string]json.RawMessage) (map[string]ProviderOptionsData, error) {
+ result := make(map[string]ProviderOptionsData)
+ for provider, rawData := range data {
+ providerData, err := unmarshalProviderData(rawData)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal provider data for %s: %w", provider, err)
+ }
+ result[provider] = providerData
+ }
+ return result, nil
+}
+
+// UnmarshalProviderOptions unmarshals a map of provider options by type.
+func UnmarshalProviderOptions(data map[string]json.RawMessage) (ProviderOptions, error) {
+ return unmarshalProviderDataMap(data)
+}
+
+// UnmarshalProviderMetadata unmarshals a map of provider metadata by type.
+func UnmarshalProviderMetadata(data map[string]json.RawMessage) (ProviderMetadata, error) {
+ return unmarshalProviderDataMap(data)
+}
@@ -1,7 +1,18 @@
// Package anthropic provides an implementation of the fantasy AI SDK for Anthropic's language models.
package anthropic
-import "charm.land/fantasy"
+import (
+ "encoding/json"
+
+ "charm.land/fantasy"
+)
+
+// Global type identifiers for Anthropic-specific provider data.
+const (
+ TypeProviderOptions = Name + ".options"
+ TypeReasoningOptionMetadata = Name + ".reasoning_metadata"
+ TypeProviderCacheControl = Name + ".cache_control_options"
+)
// ProviderOptions represents additional options for the Anthropic provider.
type ProviderOptions struct {
@@ -13,6 +24,34 @@ type ProviderOptions struct {
// Options implements the ProviderOptions interface.
func (o *ProviderOptions) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions.
+func (o ProviderOptions) MarshalJSON() ([]byte, error) {
+ type plain ProviderOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderOptions,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions.
+func (o *ProviderOptions) UnmarshalJSON(data []byte) error {
+ type plain ProviderOptions
+ var oo plain
+ err := json.Unmarshal(data, &oo)
+ if err != nil {
+ return err
+ }
+ *o = ProviderOptions(oo)
+ return nil
+}
+
// ThinkingProviderOption represents thinking options for the Anthropic provider.
type ThinkingProviderOption struct {
BudgetTokens int64 `json:"budget_tokens"`
@@ -27,6 +66,34 @@ type ReasoningOptionMetadata struct {
// Options implements the ProviderOptions interface.
func (*ReasoningOptionMetadata) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ReasoningOptionMetadata.
+func (m ReasoningOptionMetadata) MarshalJSON() ([]byte, error) {
+ type plain ReasoningOptionMetadata
+ raw, err := json.Marshal(plain(m))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeReasoningOptionMetadata,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ReasoningOptionMetadata.
+func (m *ReasoningOptionMetadata) UnmarshalJSON(data []byte) error {
+ type plain ReasoningOptionMetadata
+ var rm plain
+ err := json.Unmarshal(data, &rm)
+ if err != nil {
+ return err
+ }
+ *m = ReasoningOptionMetadata(rm)
+ return nil
+}
+
// ProviderCacheControlOptions represents cache control options for the Anthropic provider.
type ProviderCacheControlOptions struct {
CacheControl CacheControl `json:"cache_control"`
@@ -35,6 +102,34 @@ type ProviderCacheControlOptions struct {
// Options implements the ProviderOptions interface.
func (*ProviderCacheControlOptions) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderCacheControlOptions.
+func (o ProviderCacheControlOptions) MarshalJSON() ([]byte, error) {
+ type plain ProviderCacheControlOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderCacheControl,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderCacheControlOptions.
+func (o *ProviderCacheControlOptions) UnmarshalJSON(data []byte) error {
+ type plain ProviderCacheControlOptions
+ var cc plain
+ err := json.Unmarshal(data, &cc)
+ if err != nil {
+ return err
+ }
+ *o = ProviderCacheControlOptions(cc)
+ return nil
+}
+
// CacheControl represents cache control settings for the Anthropic provider.
type CacheControl struct {
Type string `json:"type"`
@@ -62,3 +157,28 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) {
}
return &options, nil
}
+
+// Register Anthropic provider-specific types with the global registry.
+func init() {
+ fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderOptions
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+ fantasy.RegisterProviderType(TypeReasoningOptionMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ReasoningOptionMetadata
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+ fantasy.RegisterProviderType(TypeProviderCacheControl, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderCacheControlOptions
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+}
@@ -1,7 +1,16 @@
// Package google provides an implementation of the fantasy AI SDK for Google's language models.
package google
-import "charm.land/fantasy"
+import (
+ "encoding/json"
+
+ "charm.land/fantasy"
+)
+
+// Global type identifiers for Google-specific provider data.
+const (
+ TypeProviderOptions = Name + ".options"
+)
// ThinkingConfig represents thinking configuration for the Google provider.
type ThinkingConfig struct {
@@ -51,6 +60,34 @@ type ProviderOptions struct {
// Options implements the ProviderOptionsData interface for ProviderOptions.
func (o *ProviderOptions) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions.
+func (o ProviderOptions) MarshalJSON() ([]byte, error) {
+ type plain ProviderOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderOptions,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions.
+func (o *ProviderOptions) UnmarshalJSON(data []byte) error {
+ type plain ProviderOptions
+ var oo plain
+ err := json.Unmarshal(data, &oo)
+ if err != nil {
+ return err
+ }
+ *o = ProviderOptions(oo)
+ return nil
+}
+
// ParseOptions parses provider options from a map for the Google provider.
func ParseOptions(data map[string]any) (*ProviderOptions, error) {
var options ProviderOptions
@@ -59,3 +96,14 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) {
}
return &options, nil
}
+
+// Register Google provider-specific types with the global registry.
+func init() {
+ fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderOptions
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+}
@@ -2,6 +2,8 @@
package openai
import (
+ "encoding/json"
+
"charm.land/fantasy"
"github.com/openai/openai-go/v2"
)
@@ -20,6 +22,13 @@ const (
ReasoningEffortHigh ReasoningEffort = "high"
)
+// Global type identifiers for OpenAI-specific provider data.
+const (
+ TypeProviderOptions = Name + ".options"
+ TypeProviderFileOptions = Name + ".file_options"
+ TypeProviderMetadata = Name + ".metadata"
+)
+
// ProviderMetadata represents additional metadata from OpenAI provider.
type ProviderMetadata struct {
Logprobs []openai.ChatCompletionTokenLogprob `json:"logprobs"`
@@ -30,6 +39,34 @@ type ProviderMetadata struct {
// Options implements the ProviderOptions interface.
func (*ProviderMetadata) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderMetadata.
+func (m ProviderMetadata) MarshalJSON() ([]byte, error) {
+ type plain ProviderMetadata
+ raw, err := json.Marshal(plain(m))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderMetadata,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderMetadata.
+func (m *ProviderMetadata) UnmarshalJSON(data []byte) error {
+ type plain ProviderMetadata
+ var pm plain
+ err := json.Unmarshal(data, &pm)
+ if err != nil {
+ return err
+ }
+ *m = ProviderMetadata(pm)
+ return nil
+}
+
// ProviderOptions represents additional options for OpenAI provider.
type ProviderOptions struct {
LogitBias map[string]int64 `json:"logit_bias"`
@@ -52,6 +89,34 @@ type ProviderOptions struct {
// Options implements the ProviderOptions interface.
func (*ProviderOptions) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions.
+func (o ProviderOptions) MarshalJSON() ([]byte, error) {
+ type plain ProviderOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderOptions,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions.
+func (o *ProviderOptions) UnmarshalJSON(data []byte) error {
+ type plain ProviderOptions
+ var oo plain
+ err := json.Unmarshal(data, &oo)
+ if err != nil {
+ return err
+ }
+ *o = ProviderOptions(oo)
+ return nil
+}
+
// ProviderFileOptions represents file options for OpenAI provider.
type ProviderFileOptions struct {
ImageDetail string `json:"image_detail"`
@@ -60,6 +125,34 @@ type ProviderFileOptions struct {
// Options implements the ProviderOptions interface.
func (*ProviderFileOptions) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderFileOptions.
+func (o ProviderFileOptions) MarshalJSON() ([]byte, error) {
+ type plain ProviderFileOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderFileOptions,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderFileOptions.
+func (o *ProviderFileOptions) UnmarshalJSON(data []byte) error {
+ type plain ProviderFileOptions
+ var of plain
+ err := json.Unmarshal(data, &of)
+ if err != nil {
+ return err
+ }
+ *o = ProviderFileOptions(of)
+ return nil
+}
+
// ReasoningEffortOption creates a pointer to a ReasoningEffort value.
func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort {
return &e
@@ -87,3 +180,28 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) {
}
return &options, nil
}
+
+// Register OpenAI provider-specific types with the global registry.
+func init() {
+ fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderOptions
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+ fantasy.RegisterProviderType(TypeProviderFileOptions, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderFileOptions
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+ fantasy.RegisterProviderType(TypeProviderMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderMetadata
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+}
@@ -2,11 +2,18 @@
package openai
import (
+ "encoding/json"
"slices"
"charm.land/fantasy"
)
+// Global type identifiers for OpenAI Responses API-specific data.
+const (
+ TypeResponsesProviderOptions = Name + ".responses.options"
+ TypeResponsesReasoningMetadata = Name + ".responses.reasoning_metadata"
+)
+
// ResponsesReasoningMetadata represents reasoning metadata for OpenAI Responses API.
type ResponsesReasoningMetadata struct {
ItemID string `json:"item_id"`
@@ -17,6 +24,34 @@ type ResponsesReasoningMetadata struct {
// Options implements the ProviderOptions interface.
func (*ResponsesReasoningMetadata) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ResponsesReasoningMetadata.
+func (m ResponsesReasoningMetadata) MarshalJSON() ([]byte, error) {
+ type plain ResponsesReasoningMetadata
+ raw, err := json.Marshal(plain(m))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeResponsesReasoningMetadata,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ResponsesReasoningMetadata.
+func (m *ResponsesReasoningMetadata) UnmarshalJSON(data []byte) error {
+ type plain ResponsesReasoningMetadata
+ var rm plain
+ err := json.Unmarshal(data, &rm)
+ if err != nil {
+ return err
+ }
+ *m = ResponsesReasoningMetadata(rm)
+ return nil
+}
+
// IncludeType represents the type of content to include for OpenAI Responses API.
type IncludeType string
@@ -71,6 +106,37 @@ type ResponsesProviderOptions struct {
User *string `json:"user"`
}
+// Options implements the ProviderOptions interface.
+func (*ResponsesProviderOptions) Options() {}
+
+// MarshalJSON implements custom JSON marshaling with type info for ResponsesProviderOptions.
+func (o ResponsesProviderOptions) MarshalJSON() ([]byte, error) {
+ type plain ResponsesProviderOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeResponsesProviderOptions,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ResponsesProviderOptions.
+func (o *ResponsesProviderOptions) UnmarshalJSON(data []byte) error {
+ type plain ResponsesProviderOptions
+ var ro plain
+ err := json.Unmarshal(data, &ro)
+ if err != nil {
+ return err
+ }
+ *o = ResponsesProviderOptions(ro)
+ return nil
+}
+
// responsesReasoningModelIds lists the model IDs that support reasoning for OpenAI Responses API.
var responsesReasoningModelIDs = []string{
"o1",
@@ -121,9 +187,6 @@ var responsesModelIDs = append([]string{
"gpt-5-chat-latest",
}, responsesReasoningModelIDs...)
-// Options implements the ProviderOptions interface.
-func (*ResponsesProviderOptions) Options() {}
-
// NewResponsesProviderOptions creates new provider options for OpenAI Responses API.
func NewResponsesProviderOptions(opts *ResponsesProviderOptions) fantasy.ProviderOptions {
return fantasy.ProviderOptions{
@@ -149,3 +212,21 @@ func IsResponsesModel(modelID string) bool {
func IsResponsesReasoningModel(modelID string) bool {
return slices.Contains(responsesReasoningModelIDs, modelID)
}
+
+// Register OpenAI Responses API-specific types with the global registry.
+func init() {
+ fantasy.RegisterProviderType(TypeResponsesProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ResponsesProviderOptions
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+ fantasy.RegisterProviderType(TypeResponsesReasoningMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ResponsesReasoningMetadata
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+}
@@ -2,10 +2,17 @@
package openaicompat
import (
+ "encoding/json"
+
"charm.land/fantasy"
"charm.land/fantasy/providers/openai"
)
+// Global type identifiers for OpenRouter-specific provider data.
+const (
+ TypeProviderOptions = Name + ".options"
+)
+
// ProviderOptions represents additional options for the OpenAI-compatible provider.
type ProviderOptions struct {
User *string `json:"user"`
@@ -20,6 +27,34 @@ type ReasoningData struct {
// Options implements the ProviderOptions interface.
func (*ProviderOptions) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions.
+func (o ProviderOptions) MarshalJSON() ([]byte, error) {
+ type plain ProviderOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderOptions,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions.
+func (o *ProviderOptions) UnmarshalJSON(data []byte) error {
+ type plain ProviderOptions
+ var oo plain
+ err := json.Unmarshal(data, &oo)
+ if err != nil {
+ return err
+ }
+ *o = ProviderOptions(oo)
+ return nil
+}
+
// NewProviderOptions creates new provider options for the OpenAI-compatible provider.
func NewProviderOptions(opts *ProviderOptions) fantasy.ProviderOptions {
return fantasy.ProviderOptions{
@@ -2,6 +2,8 @@
package openrouter
import (
+ "encoding/json"
+
"charm.land/fantasy"
)
@@ -17,6 +19,12 @@ const (
ReasoningEffortHigh ReasoningEffort = "high"
)
+// Global type identifiers for OpenRouter-specific provider data.
+const (
+ TypeProviderOptions = Name + ".options"
+ TypeProviderMetadata = Name + ".metadata"
+)
+
// PromptTokensDetails represents details about prompt tokens for OpenRouter.
type PromptTokensDetails struct {
CachedTokens int64
@@ -54,6 +62,34 @@ type ProviderMetadata struct {
// Options implements the ProviderOptionsData interface for ProviderMetadata.
func (*ProviderMetadata) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderMetadata.
+func (m ProviderMetadata) MarshalJSON() ([]byte, error) {
+ type plain ProviderMetadata
+ raw, err := json.Marshal(plain(m))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderMetadata,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderMetadata.
+func (m *ProviderMetadata) UnmarshalJSON(data []byte) error {
+ type plain ProviderMetadata
+ var pm plain
+ err := json.Unmarshal(data, &pm)
+ if err != nil {
+ return err
+ }
+ *m = ProviderMetadata(pm)
+ return nil
+}
+
// ReasoningOptions represents reasoning options for OpenRouter.
type ReasoningOptions struct {
// Whether reasoning is enabled
@@ -110,6 +146,34 @@ type ProviderOptions struct {
// Options implements the ProviderOptionsData interface for ProviderOptions.
func (*ProviderOptions) Options() {}
+// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions.
+func (o ProviderOptions) MarshalJSON() ([]byte, error) {
+ type plain ProviderOptions
+ raw, err := json.Marshal(plain(o))
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }{
+ Type: TypeProviderOptions,
+ Data: raw,
+ })
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions.
+func (o *ProviderOptions) UnmarshalJSON(data []byte) error {
+ type plain ProviderOptions
+ var oo plain
+ err := json.Unmarshal(data, &oo)
+ if err != nil {
+ return err
+ }
+ *o = ProviderOptions(oo)
+ return nil
+}
+
// ReasoningDetail represents a reasoning detail for OpenRouter.
type ReasoningDetail struct {
ID string `json:"id,omitempty"`
@@ -148,3 +212,21 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) {
}
return &options, nil
}
+
+// Register OpenRouter provider-specific types with the global registry.
+func init() {
+ fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderOptions
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+ fantasy.RegisterProviderType(TypeProviderMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) {
+ var v ProviderMetadata
+ if err := json.Unmarshal(data, &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
+ })
+}
@@ -0,0 +1,140 @@
+package providertests
+
+import (
+ "encoding/json"
+ "testing"
+
+ "charm.land/fantasy"
+ "charm.land/fantasy/providers/openai"
+ "github.com/stretchr/testify/require"
+)
+
+func TestProviderRegistry_Serialization_OpenAIOptions(t *testing.T) {
+ msg := fantasy.Message{
+ Role: fantasy.MessageRoleUser,
+ Content: []fantasy.MessagePart{
+ fantasy.TextPart{Text: "hi"},
+ },
+ ProviderOptions: fantasy.ProviderOptions{
+ openai.Name: &openai.ProviderOptions{User: fantasy.Opt("tester")},
+ },
+ }
+
+ data, err := json.Marshal(msg)
+ require.NoError(t, err)
+
+ var raw struct {
+ ProviderOptions map[string]map[string]any `json:"provider_options"`
+ }
+ require.NoError(t, json.Unmarshal(data, &raw))
+
+ po, ok := raw.ProviderOptions[openai.Name]
+ require.True(t, ok)
+ require.Equal(t, openai.TypeProviderOptions, po["type"]) // no magic strings
+ // ensure inner data has the field we set
+ inner, ok := po["data"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "tester", inner["user"])
+
+ var decoded fantasy.Message
+ require.NoError(t, json.Unmarshal(data, &decoded))
+
+ got, ok := decoded.ProviderOptions[openai.Name]
+ require.True(t, ok)
+ opt, ok := got.(*openai.ProviderOptions)
+ require.True(t, ok)
+ require.NotNil(t, opt.User)
+ require.Equal(t, "tester", *opt.User)
+}
+
+func TestProviderRegistry_Serialization_OpenAIResponses(t *testing.T) {
+ // Use ResponsesProviderOptions in provider options
+ msg := fantasy.Message{
+ Role: fantasy.MessageRoleUser,
+ Content: []fantasy.MessagePart{
+ fantasy.TextPart{Text: "hello"},
+ },
+ ProviderOptions: fantasy.ProviderOptions{
+ openai.Name: &openai.ResponsesProviderOptions{
+ PromptCacheKey: fantasy.Opt("cache-key-1"),
+ ParallelToolCalls: fantasy.Opt(true),
+ },
+ },
+ }
+
+ data, err := json.Marshal(msg)
+ require.NoError(t, err)
+
+ // JSON should include the typed wrapper with constant TypeResponsesProviderOptions
+ var raw struct {
+ ProviderOptions map[string]map[string]any `json:"provider_options"`
+ }
+ require.NoError(t, json.Unmarshal(data, &raw))
+
+ po := raw.ProviderOptions[openai.Name]
+ require.Equal(t, openai.TypeResponsesProviderOptions, po["type"]) // no magic strings
+ inner, ok := po["data"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "cache-key-1", inner["prompt_cache_key"])
+ require.Equal(t, true, inner["parallel_tool_calls"])
+
+ // Unmarshal back and assert concrete type
+ var decoded fantasy.Message
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ got := decoded.ProviderOptions[openai.Name]
+ reqOpts, ok := got.(*openai.ResponsesProviderOptions)
+ require.True(t, ok)
+ require.NotNil(t, reqOpts.PromptCacheKey)
+ require.Equal(t, "cache-key-1", *reqOpts.PromptCacheKey)
+ require.NotNil(t, reqOpts.ParallelToolCalls)
+ require.Equal(t, true, *reqOpts.ParallelToolCalls)
+}
+
+func TestProviderRegistry_Serialization_OpenAIResponsesReasoningMetadata(t *testing.T) {
+ resp := fantasy.Response{
+ Content: []fantasy.Content{
+ fantasy.TextContent{
+ Text: "",
+ ProviderMetadata: fantasy.ProviderMetadata{
+ openai.Name: &openai.ResponsesReasoningMetadata{
+ ItemID: "item-123",
+ Summary: []string{"part1", "part2"},
+ },
+ },
+ },
+ },
+ }
+
+ data, err := json.Marshal(resp)
+ require.NoError(t, err)
+
+ // Ensure the provider metadata is wrapped with type using constant
+ var raw struct {
+ Content []struct {
+ Type string `json:"type"`
+ Data map[string]any `json:"data"`
+ } `json:"content"`
+ }
+ require.NoError(t, json.Unmarshal(data, &raw))
+ require.Greater(t, len(raw.Content), 0)
+ tc := raw.Content[0]
+ pm, ok := tc.Data["provider_metadata"].(map[string]any)
+ require.True(t, ok)
+ om, ok := pm[openai.Name].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, openai.TypeResponsesReasoningMetadata, om["type"]) // no magic strings
+ inner, ok := om["data"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "item-123", inner["item_id"])
+
+ // Unmarshal back
+ var decoded fantasy.Response
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ pmDecoded := decoded.Content[0].(fantasy.TextContent).ProviderMetadata
+ val, ok := pmDecoded[openai.Name]
+ require.True(t, ok)
+ meta, ok := val.(*openai.ResponsesReasoningMetadata)
+ require.True(t, ok)
+ require.Equal(t, "item-123", meta.ItemID)
+ require.Equal(t, []string{"part1", "part2"}, meta.Summary)
+}