From 7ec05815fc3996e77c9518490350709743583be7 Mon Sep 17 00:00:00 2001 From: kujtimiihoxha Date: Wed, 22 Oct 2025 13:15:58 +0200 Subject: [PATCH] feat: make structs serializable --- content.go | 48 + content_json.go | 1022 ++++++++++++++++++++ json_test.go | 647 +++++++++++++ model_json.go | 149 +++ provider_registry.go | 70 ++ providers/anthropic/provider_options.go | 122 ++- providers/google/provider_options.go | 50 +- providers/openai/provider_options.go | 118 +++ providers/openai/responses_options.go | 87 +- providers/openaicompat/provider_options.go | 35 + providers/openrouter/provider_options.go | 82 ++ providertests/provider_registry_test.go | 140 +++ 12 files changed, 2565 insertions(+), 5 deletions(-) create mode 100644 content_json.go create mode 100644 json_test.go create mode 100644 model_json.go create mode 100644 provider_registry.go create mode 100644 providertests/provider_registry_test.go diff --git a/content.go b/content.go index 0dda6c5c296a0e8ca9a0320a74c03a3c18ab3f5e..2a873b35e2e552244a199c071ada71c6fb6141a4 100644 --- a/content.go +++ b/content.go @@ -1,8 +1,56 @@ package fantasy +import "encoding/json" + // ProviderOptionsData is an interface for provider-specific options data. +// All implementations MUST also implement encoding/json.Marshaler and +// encoding/json.Unmarshaler interfaces to ensure proper JSON serialization +// with the provider registry system. +// +// Required implementation pattern: +// +// type MyProviderOptions struct { +// Field string `json:"field"` +// } +// +// // Implement ProviderOptionsData +// func (*MyProviderOptions) Options() {} +// +// // Implement json.Marshaler - use fantasy.MarshalProviderData +// func (m MyProviderOptions) MarshalJSON() ([]byte, error) { +// return fantasy.MarshalProviderData(&m, "provider.type") +// } +// +// // Implement json.Unmarshaler - use fantasy.UnmarshalProviderData +// func (m *MyProviderOptions) UnmarshalJSON(data []byte) error { +// providerData, err := fantasy.UnmarshalProviderData(data) +// if err != nil { +// return err +// } +// opts, ok := providerData.(*MyProviderOptions) +// if !ok { +// return fmt.Errorf("invalid type") +// } +// *m = *opts +// return nil +// } +// +// Additionally, register the type in init(): +// +// func init() { +// fantasy.RegisterProviderType("provider.type", func(data []byte) (fantasy.ProviderOptionsData, error) { +// var opts MyProviderOptions +// if err := json.Unmarshal(data, &opts); err != nil { +// return nil, err +// } +// return &opts, nil +// }) +// } type ProviderOptionsData interface { + // Options is a marker method that identifies types implementing this interface. Options() + json.Marshaler + json.Unmarshaler } // ProviderMetadata represents additional provider-specific metadata. diff --git a/content_json.go b/content_json.go new file mode 100644 index 0000000000000000000000000000000000000000..bfeb45cbee6922f1290f22b3d01e3032a21d5cb0 --- /dev/null +++ b/content_json.go @@ -0,0 +1,1022 @@ +package fantasy + +import ( + "encoding/json" + "errors" + "fmt" +) + +// contentJSON is a helper type for JSON serialization of Content in Response. +type contentJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// messagePartJSON is a helper type for JSON serialization of MessagePart. +type messagePartJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// toolResultOutputJSON is a helper type for JSON serialization of ToolResultOutputContent. +type toolResultOutputJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// toolJSON is a helper type for JSON serialization of Tool. +type toolJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +func (t TextContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + Text: t.Text, + ProviderMetadata: t.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeText), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *TextContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + t.Text = aux.Text + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + t.ProviderMetadata = metadata + } + + return nil +} + +func (r ReasoningContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + Text: r.Text, + ProviderMetadata: r.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeReasoning), + Data: json.RawMessage(dataBytes), + }) +} + +func (r *ReasoningContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + r.Text = aux.Text + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + r.ProviderMetadata = metadata + } + + return nil +} + +func (f FileContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + MediaType string `json:"media_type"` + Data []byte `json:"data"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + MediaType: f.MediaType, + Data: f.Data, + ProviderMetadata: f.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeFile), + Data: json.RawMessage(dataBytes), + }) +} + +func (f *FileContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + MediaType string `json:"media_type"` + Data []byte `json:"data"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + f.MediaType = aux.MediaType + f.Data = aux.Data + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + f.ProviderMetadata = metadata + } + + return nil +} + +func (s SourceContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + SourceType SourceType `json:"source_type"` + ID string `json:"id"` + URL string `json:"url,omitempty"` + Title string `json:"title,omitempty"` + MediaType string `json:"media_type,omitempty"` + Filename string `json:"filename,omitempty"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + SourceType: s.SourceType, + ID: s.ID, + URL: s.URL, + Title: s.Title, + MediaType: s.MediaType, + Filename: s.Filename, + ProviderMetadata: s.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeSource), + Data: json.RawMessage(dataBytes), + }) +} + +func (s *SourceContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + SourceType SourceType `json:"source_type"` + ID string `json:"id"` + URL string `json:"url,omitempty"` + Title string `json:"title,omitempty"` + MediaType string `json:"media_type,omitempty"` + Filename string `json:"filename,omitempty"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + s.SourceType = aux.SourceType + s.ID = aux.ID + s.URL = aux.URL + s.Title = aux.Title + s.MediaType = aux.MediaType + s.Filename = aux.Filename + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + s.ProviderMetadata = metadata + } + + return nil +} + +func (t ToolCallContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + Invalid bool `json:"invalid,omitempty"` + ValidationError error `json:"validation_error,omitempty"` + }{ + ToolCallID: t.ToolCallID, + ToolName: t.ToolName, + Input: t.Input, + ProviderExecuted: t.ProviderExecuted, + ProviderMetadata: t.ProviderMetadata, + Invalid: t.Invalid, + ValidationError: t.ValidationError, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeToolCall), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *ToolCallContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + Invalid bool `json:"invalid,omitempty"` + ValidationError error `json:"validation_error,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + t.ToolName = aux.ToolName + t.Input = aux.Input + t.ProviderExecuted = aux.ProviderExecuted + t.Invalid = aux.Invalid + t.ValidationError = aux.ValidationError + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + t.ProviderMetadata = metadata + } + + return nil +} + +func (t ToolResultContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Result ToolResultOutputContent `json:"result"` + ClientMetadata string `json:"client_metadata,omitempty"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + ToolCallID: t.ToolCallID, + ToolName: t.ToolName, + Result: t.Result, + ClientMetadata: t.ClientMetadata, + ProviderExecuted: t.ProviderExecuted, + ProviderMetadata: t.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeToolResult), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *ToolResultContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Result json.RawMessage `json:"result"` + ClientMetadata string `json:"client_metadata,omitempty"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + t.ToolName = aux.ToolName + t.ClientMetadata = aux.ClientMetadata + t.ProviderExecuted = aux.ProviderExecuted + + // Unmarshal the Result field + result, err := UnmarshalToolResultOutputContent(aux.Result) + if err != nil { + return fmt.Errorf("failed to unmarshal tool result output: %w", err) + } + t.Result = result + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + t.ProviderMetadata = metadata + } + + return nil +} + +func (t ToolResultOutputContentText) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + }{ + Text: t.Text, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolResultOutputJSON{ + Type: string(ToolResultContentTypeText), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *ToolResultOutputContentText) UnmarshalJSON(data []byte) error { + var tr toolResultOutputJSON + if err := json.Unmarshal(data, &tr); err != nil { + return err + } + + var temp struct { + Text string `json:"text"` + } + + if err := json.Unmarshal(tr.Data, &temp); err != nil { + return err + } + + t.Text = temp.Text + return nil +} + +func (t ToolResultOutputContentError) MarshalJSON() ([]byte, error) { + errMsg := "" + if t.Error != nil { + errMsg = t.Error.Error() + } + dataBytes, err := json.Marshal(struct { + Error string `json:"error"` + }{ + Error: errMsg, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolResultOutputJSON{ + Type: string(ToolResultContentTypeError), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *ToolResultOutputContentError) UnmarshalJSON(data []byte) error { + var tr toolResultOutputJSON + if err := json.Unmarshal(data, &tr); err != nil { + return err + } + + var temp struct { + Error string `json:"error"` + } + + if err := json.Unmarshal(tr.Data, &temp); err != nil { + return err + } + if temp.Error != "" { + t.Error = errors.New(temp.Error) + } + return nil +} + +func (t ToolResultOutputContentMedia) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Data string `json:"data"` + MediaType string `json:"media_type"` + }{ + Data: t.Data, + MediaType: t.MediaType, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolResultOutputJSON{ + Type: string(ToolResultContentTypeMedia), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *ToolResultOutputContentMedia) UnmarshalJSON(data []byte) error { + var tr toolResultOutputJSON + if err := json.Unmarshal(data, &tr); err != nil { + return err + } + + var temp struct { + Data string `json:"data"` + MediaType string `json:"media_type"` + } + + if err := json.Unmarshal(tr.Data, &temp); err != nil { + return err + } + + t.Data = temp.Data + t.MediaType = temp.MediaType + return nil +} + +func (t TextPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Text: t.Text, + ProviderOptions: t.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeText), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *TextPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + t.Text = aux.Text + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + t.ProviderOptions = options + } + + return nil +} + +func (r ReasoningPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Text: r.Text, + ProviderOptions: r.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeReasoning), + Data: json.RawMessage(dataBytes), + }) +} + +func (r *ReasoningPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + r.Text = aux.Text + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + r.ProviderOptions = options + } + + return nil +} + +func (f FilePart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Filename string `json:"filename"` + Data []byte `json:"data"` + MediaType string `json:"media_type"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Filename: f.Filename, + Data: f.Data, + MediaType: f.MediaType, + ProviderOptions: f.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeFile), + Data: json.RawMessage(dataBytes), + }) +} + +func (f *FilePart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + Filename string `json:"filename"` + Data []byte `json:"data"` + MediaType string `json:"media_type"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + f.Filename = aux.Filename + f.Data = aux.Data + f.MediaType = aux.MediaType + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + f.ProviderOptions = options + } + + return nil +} + +func (t ToolCallPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + ToolCallID: t.ToolCallID, + ToolName: t.ToolName, + Input: t.Input, + ProviderExecuted: t.ProviderExecuted, + ProviderOptions: t.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeToolCall), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *ToolCallPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + t.ToolName = aux.ToolName + t.Input = aux.Input + t.ProviderExecuted = aux.ProviderExecuted + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + t.ProviderOptions = options + } + + return nil +} + +func (t ToolResultPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + Output ToolResultOutputContent `json:"output"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + ToolCallID: t.ToolCallID, + Output: t.Output, + ProviderOptions: t.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeToolResult), + Data: json.RawMessage(dataBytes), + }) +} + +func (t *ToolResultPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + Output json.RawMessage `json:"output"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + + // Unmarshal the Output field + output, err := UnmarshalToolResultOutputContent(aux.Output) + if err != nil { + return fmt.Errorf("failed to unmarshal tool result output: %w", err) + } + t.Output = output + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + t.ProviderOptions = options + } + + return nil +} + +func (m *Message) UnmarshalJSON(data []byte) error { + var aux struct { + Role MessageRole `json:"role"` + Content []json.RawMessage `json:"content"` + ProviderOptions map[string]json.RawMessage `json:"provider_options"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + m.Role = aux.Role + + m.Content = make([]MessagePart, len(aux.Content)) + for i, rawPart := range aux.Content { + part, err := UnmarshalMessagePart(rawPart) + if err != nil { + return fmt.Errorf("failed to unmarshal message part at index %d: %w", i, err) + } + m.Content[i] = part + } + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + m.ProviderOptions = options + } + + return nil +} + +func (f FunctionTool) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Name: f.Name, + Description: f.Description, + InputSchema: f.InputSchema, + ProviderOptions: f.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolJSON{ + Type: string(ToolTypeFunction), + Data: json.RawMessage(dataBytes), + }) +} + +func (f *FunctionTool) UnmarshalJSON(data []byte) error { + var tj toolJSON + if err := json.Unmarshal(data, &tj); err != nil { + return err + } + + var aux struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(tj.Data, &aux); err != nil { + return err + } + + f.Name = aux.Name + f.Description = aux.Description + f.InputSchema = aux.InputSchema + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + f.ProviderOptions = options + } + + return nil +} + +func (p ProviderDefinedTool) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ID string `json:"id"` + Name string `json:"name"` + Args map[string]any `json:"args"` + }{ + ID: p.ID, + Name: p.Name, + Args: p.Args, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolJSON{ + Type: string(ToolTypeProviderDefined), + Data: json.RawMessage(dataBytes), + }) +} + +func (p *ProviderDefinedTool) UnmarshalJSON(data []byte) error { + var tj toolJSON + if err := json.Unmarshal(data, &tj); err != nil { + return err + } + + var aux struct { + ID string `json:"id"` + Name string `json:"name"` + Args map[string]any `json:"args"` + } + + if err := json.Unmarshal(tj.Data, &aux); err != nil { + return err + } + + p.ID = aux.ID + p.Name = aux.Name + p.Args = aux.Args + + return nil +} + +// UnmarshalTool unmarshals JSON into the appropriate Tool type +func UnmarshalTool(data []byte) (Tool, error) { + var tj toolJSON + if err := json.Unmarshal(data, &tj); err != nil { + return nil, err + } + + switch ToolType(tj.Type) { + case ToolTypeFunction: + var tool FunctionTool + if err := tool.UnmarshalJSON(data); err != nil { + return nil, err + } + return tool, nil + case ToolTypeProviderDefined: + var tool ProviderDefinedTool + if err := tool.UnmarshalJSON(data); err != nil { + return nil, err + } + return tool, nil + default: + return nil, fmt.Errorf("unknown tool type: %s", tj.Type) + } +} + +// UnmarshalContent unmarshals JSON into the appropriate Content type +func UnmarshalContent(data []byte) (Content, error) { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return nil, err + } + + switch ContentType(cj.Type) { + case ContentTypeText: + var content TextContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeReasoning: + var content ReasoningContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeFile: + var content FileContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeSource: + var content SourceContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeToolCall: + var content ToolCallContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeToolResult: + var content ToolResultContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + default: + return nil, fmt.Errorf("unknown content type: %s", cj.Type) + } +} + +// UnmarshalMessagePart unmarshals JSON into the appropriate MessagePart type +func UnmarshalMessagePart(data []byte) (MessagePart, error) { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return nil, err + } + + switch ContentType(mpj.Type) { + case ContentTypeText: + var part TextPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeReasoning: + var part ReasoningPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeFile: + var part FilePart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeToolCall: + var part ToolCallPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeToolResult: + var part ToolResultPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + default: + return nil, fmt.Errorf("unknown message part type: %s", mpj.Type) + } +} + +// UnmarshalToolResultOutputContent unmarshals JSON into the appropriate ToolResultOutputContent type +func UnmarshalToolResultOutputContent(data []byte) (ToolResultOutputContent, error) { + var troj toolResultOutputJSON + if err := json.Unmarshal(data, &troj); err != nil { + return nil, err + } + + switch ToolResultContentType(troj.Type) { + case ToolResultContentTypeText: + var content ToolResultOutputContentText + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ToolResultContentTypeError: + var content ToolResultOutputContentError + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ToolResultContentTypeMedia: + var content ToolResultOutputContentMedia + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + default: + return nil, fmt.Errorf("unknown tool result output content type: %s", troj.Type) + } +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 0000000000000000000000000000000000000000..da9f491407e284dbc988c59595d79aa631dfd524 --- /dev/null +++ b/json_test.go @@ -0,0 +1,647 @@ +package fantasy + +import ( + "encoding/json" + "errors" + "reflect" + "testing" +) + +func TestMessageJSONSerialization(t *testing.T) { + tests := []struct { + name string + message Message + }{ + { + name: "simple text message", + message: Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: "Hello, world!"}, + }, + }, + }, + { + name: "message with multiple text parts", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + TextPart{Text: "First part"}, + TextPart{Text: "Second part"}, + TextPart{Text: "Third part"}, + }, + }, + }, + { + name: "message with reasoning part", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + ReasoningPart{Text: "Let me think about this..."}, + TextPart{Text: "Here's my answer"}, + }, + }, + }, + { + name: "message with file part", + message: Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: "Here's an image:"}, + FilePart{ + Filename: "test.png", + Data: []byte{0x89, 0x50, 0x4E, 0x47}, // PNG header + MediaType: "image/png", + }, + }, + }, + }, + { + name: "message with tool call", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + ToolCallPart{ + ToolCallID: "call_123", + ToolName: "get_weather", + Input: `{"location": "San Francisco"}`, + ProviderExecuted: false, + }, + }, + }, + }, + { + name: "message with tool result - text output", + message: Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_123", + Output: ToolResultOutputContentText{ + Text: "The weather is sunny, 72°F", + }, + }, + }, + }, + }, + { + name: "message with tool result - error output", + message: Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_456", + Output: ToolResultOutputContentError{ + Error: errors.New("API rate limit exceeded"), + }, + }, + }, + }, + }, + { + name: "message with tool result - media output", + message: Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_789", + Output: ToolResultOutputContentMedia{ + Data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", + MediaType: "image/png", + }, + }, + }, + }, + }, + { + name: "complex message with mixed content", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + TextPart{Text: "I'll analyze this image and call some tools."}, + ReasoningPart{Text: "First, I need to identify the objects..."}, + ToolCallPart{ + ToolCallID: "call_001", + ToolName: "analyze_image", + Input: `{"image_id": "img_123"}`, + ProviderExecuted: false, + }, + ToolCallPart{ + ToolCallID: "call_002", + ToolName: "get_context", + Input: `{"query": "similar images"}`, + ProviderExecuted: true, + }, + }, + }, + }, + { + name: "system message", + message: Message{ + Role: MessageRoleSystem, + Content: []MessagePart{ + TextPart{Text: "You are a helpful assistant."}, + }, + }, + }, + { + name: "empty content", + message: Message{ + Role: MessageRoleUser, + Content: []MessagePart{}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal the message + data, err := json.Marshal(tt.message) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + + // Unmarshal back + var decoded Message + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("failed to unmarshal message: %v", err) + } + + // Compare roles + if decoded.Role != tt.message.Role { + t.Errorf("role mismatch: got %v, want %v", decoded.Role, tt.message.Role) + } + + // Compare content length + if len(decoded.Content) != len(tt.message.Content) { + t.Fatalf("content length mismatch: got %d, want %d", len(decoded.Content), len(tt.message.Content)) + } + + // Compare each content part + for i := range tt.message.Content { + original := tt.message.Content[i] + decodedPart := decoded.Content[i] + + if original.GetType() != decodedPart.GetType() { + t.Errorf("content[%d] type mismatch: got %v, want %v", i, decodedPart.GetType(), original.GetType()) + continue + } + + compareMessagePart(t, i, original, decodedPart) + } + }) + } +} + +func compareMessagePart(t *testing.T, index int, original, decoded MessagePart) { + switch original.GetType() { + case ContentTypeText: + orig := original.(TextPart) + dec := decoded.(TextPart) + if orig.Text != dec.Text { + t.Errorf("content[%d] text mismatch: got %q, want %q", index, dec.Text, orig.Text) + } + + case ContentTypeReasoning: + orig := original.(ReasoningPart) + dec := decoded.(ReasoningPart) + if orig.Text != dec.Text { + t.Errorf("content[%d] reasoning text mismatch: got %q, want %q", index, dec.Text, orig.Text) + } + + case ContentTypeFile: + orig := original.(FilePart) + dec := decoded.(FilePart) + if orig.Filename != dec.Filename { + t.Errorf("content[%d] filename mismatch: got %q, want %q", index, dec.Filename, orig.Filename) + } + if orig.MediaType != dec.MediaType { + t.Errorf("content[%d] media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType) + } + if !reflect.DeepEqual(orig.Data, dec.Data) { + t.Errorf("content[%d] file data mismatch", index) + } + + case ContentTypeToolCall: + orig := original.(ToolCallPart) + dec := decoded.(ToolCallPart) + if orig.ToolCallID != dec.ToolCallID { + t.Errorf("content[%d] tool call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID) + } + if orig.ToolName != dec.ToolName { + t.Errorf("content[%d] tool name mismatch: got %q, want %q", index, dec.ToolName, orig.ToolName) + } + if orig.Input != dec.Input { + t.Errorf("content[%d] tool input mismatch: got %q, want %q", index, dec.Input, orig.Input) + } + if orig.ProviderExecuted != dec.ProviderExecuted { + t.Errorf("content[%d] provider executed mismatch: got %v, want %v", index, dec.ProviderExecuted, orig.ProviderExecuted) + } + + case ContentTypeToolResult: + orig := original.(ToolResultPart) + dec := decoded.(ToolResultPart) + if orig.ToolCallID != dec.ToolCallID { + t.Errorf("content[%d] tool result call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID) + } + compareToolResultOutput(t, index, orig.Output, dec.Output) + } +} + +func compareToolResultOutput(t *testing.T, index int, original, decoded ToolResultOutputContent) { + if original.GetType() != decoded.GetType() { + t.Errorf("content[%d] tool result output type mismatch: got %v, want %v", index, decoded.GetType(), original.GetType()) + return + } + + switch original.GetType() { + case ToolResultContentTypeText: + orig := original.(ToolResultOutputContentText) + dec := decoded.(ToolResultOutputContentText) + if orig.Text != dec.Text { + t.Errorf("content[%d] tool result text mismatch: got %q, want %q", index, dec.Text, orig.Text) + } + + case ToolResultContentTypeError: + orig := original.(ToolResultOutputContentError) + dec := decoded.(ToolResultOutputContentError) + if orig.Error.Error() != dec.Error.Error() { + t.Errorf("content[%d] tool result error mismatch: got %q, want %q", index, dec.Error.Error(), orig.Error.Error()) + } + + case ToolResultContentTypeMedia: + orig := original.(ToolResultOutputContentMedia) + dec := decoded.(ToolResultOutputContentMedia) + if orig.Data != dec.Data { + t.Errorf("content[%d] tool result media data mismatch", index) + } + if orig.MediaType != dec.MediaType { + t.Errorf("content[%d] tool result media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType) + } + } +} + +func TestHelperFunctions(t *testing.T) { + t.Run("NewUserMessage - text only", func(t *testing.T) { + msg := NewUserMessage("Hello") + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded.Role != MessageRoleUser { + t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleUser) + } + + if len(decoded.Content) != 1 { + t.Fatalf("expected 1 content part, got %d", len(decoded.Content)) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "Hello" { + t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Hello") + } + }) + + t.Run("NewUserMessage - with files", func(t *testing.T) { + msg := NewUserMessage("Check this image", + FilePart{ + Filename: "image1.jpg", + Data: []byte{0xFF, 0xD8, 0xFF}, + MediaType: "image/jpeg", + }, + FilePart{ + Filename: "image2.png", + Data: []byte{0x89, 0x50, 0x4E, 0x47}, + MediaType: "image/png", + }, + ) + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(decoded.Content) != 3 { + t.Fatalf("expected 3 content parts, got %d", len(decoded.Content)) + } + + // Check text part + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "Check this image" { + t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Check this image") + } + + // Check first file + file1 := decoded.Content[1].(FilePart) + if file1.Filename != "image1.jpg" { + t.Errorf("file1 name mismatch: got %q, want %q", file1.Filename, "image1.jpg") + } + + // Check second file + file2 := decoded.Content[2].(FilePart) + if file2.Filename != "image2.png" { + t.Errorf("file2 name mismatch: got %q, want %q", file2.Filename, "image2.png") + } + }) + + t.Run("NewSystemMessage - single prompt", func(t *testing.T) { + msg := NewSystemMessage("You are a helpful assistant.") + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded.Role != MessageRoleSystem { + t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleSystem) + } + + if len(decoded.Content) != 1 { + t.Fatalf("expected 1 content part, got %d", len(decoded.Content)) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "You are a helpful assistant." { + t.Errorf("text mismatch: got %q, want %q", textPart.Text, "You are a helpful assistant.") + } + }) + + t.Run("NewSystemMessage - multiple prompts", func(t *testing.T) { + msg := NewSystemMessage("First instruction", "Second instruction", "Third instruction") + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(decoded.Content) != 3 { + t.Fatalf("expected 3 content parts, got %d", len(decoded.Content)) + } + + expected := []string{"First instruction", "Second instruction", "Third instruction"} + for i, exp := range expected { + textPart := decoded.Content[i].(TextPart) + if textPart.Text != exp { + t.Errorf("content[%d] text mismatch: got %q, want %q", i, textPart.Text, exp) + } + } + }) +} + +func TestEdgeCases(t *testing.T) { + t.Run("empty text part", func(t *testing.T) { + msg := Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: ""}, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "" { + t.Errorf("expected empty text, got %q", textPart.Text) + } + }) + + t.Run("nil error in tool result", func(t *testing.T) { + msg := Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_123", + Output: ToolResultOutputContentError{ + Error: nil, + }, + }, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + toolResult := decoded.Content[0].(ToolResultPart) + errorOutput := toolResult.Output.(ToolResultOutputContentError) + if errorOutput.Error != nil { + t.Errorf("expected nil error, got %v", errorOutput.Error) + } + }) + + t.Run("empty file data", func(t *testing.T) { + msg := Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + FilePart{ + Filename: "empty.txt", + Data: []byte{}, + MediaType: "text/plain", + }, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + filePart := decoded.Content[0].(FilePart) + if len(filePart.Data) != 0 { + t.Errorf("expected empty data, got %d bytes", len(filePart.Data)) + } + }) + + t.Run("unicode in text", func(t *testing.T) { + msg := Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: "Hello 世界! 🌍 Привет"}, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "Hello 世界! 🌍 Привет" { + t.Errorf("unicode text mismatch: got %q, want %q", textPart.Text, "Hello 世界! 🌍 Привет") + } + }) +} + +func TestInvalidJSONHandling(t *testing.T) { + t.Run("unknown message part type", func(t *testing.T) { + invalidJSON := `{ + "role": "user", + "content": [ + { + "type": "unknown-type", + "data": {} + } + ], + "provider_options": null + }` + + var msg Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + if err == nil { + t.Error("expected error for unknown message part type, got nil") + } + }) + + t.Run("unknown tool result output type", func(t *testing.T) { + invalidJSON := `{ + "role": "tool", + "content": [ + { + "type": "tool-result", + "data": { + "tool_call_id": "call_123", + "output": { + "type": "unknown-output-type", + "data": {} + }, + "provider_options": null + } + } + ], + "provider_options": null + }` + + var msg Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + if err == nil { + t.Error("expected error for unknown tool result output type, got nil") + } + }) + + t.Run("malformed JSON", func(t *testing.T) { + invalidJSON := `{"role": "user", "content": [` + + var msg Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + if err == nil { + t.Error("expected error for malformed JSON, got nil") + } + }) +} + +// Mock provider data for testing provider options +type mockProviderData struct { + Key string `json:"key"` +} + +func (m mockProviderData) Options() {} +func (m mockProviderData) Type() string { return "mock" } +func (m mockProviderData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + mockProviderData + }{ + Type: "mock", + mockProviderData: m, + }) +} + +func (m *mockProviderData) UnmarshalJSON(data []byte) error { + var aux struct { + Type string `json:"type"` + mockProviderData + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *m = aux.mockProviderData + return nil +} + +func TestPromptSerialization(t *testing.T) { + t.Run("serialize prompt (message slice)", func(t *testing.T) { + prompt := Prompt{ + NewSystemMessage("You are helpful"), + NewUserMessage("Hello"), + Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + TextPart{Text: "Hi there!"}, + }, + }, + } + + data, err := json.Marshal(prompt) + if err != nil { + t.Fatalf("failed to marshal prompt: %v", err) + } + + var decoded Prompt + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal prompt: %v", err) + } + + if len(decoded) != 3 { + t.Fatalf("expected 3 messages, got %d", len(decoded)) + } + + if decoded[0].Role != MessageRoleSystem { + t.Errorf("message 0 role mismatch: got %v, want %v", decoded[0].Role, MessageRoleSystem) + } + + if decoded[1].Role != MessageRoleUser { + t.Errorf("message 1 role mismatch: got %v, want %v", decoded[1].Role, MessageRoleUser) + } + + if decoded[2].Role != MessageRoleAssistant { + t.Errorf("message 2 role mismatch: got %v, want %v", decoded[2].Role, MessageRoleAssistant) + } + }) +} diff --git a/model_json.go b/model_json.go new file mode 100644 index 0000000000000000000000000000000000000000..cfcc35d01cf90e383aaa40662eeabccb89f4ad39 --- /dev/null +++ b/model_json.go @@ -0,0 +1,149 @@ +package fantasy + +import ( + "encoding/json" + "fmt" +) + +func (c *Call) UnmarshalJSON(data []byte) error { + var aux struct { + Prompt Prompt `json:"prompt"` + MaxOutputTokens *int64 `json:"max_output_tokens"` + Temperature *float64 `json:"temperature"` + TopP *float64 `json:"top_p"` + TopK *int64 `json:"top_k"` + PresencePenalty *float64 `json:"presence_penalty"` + FrequencyPenalty *float64 `json:"frequency_penalty"` + Tools []json.RawMessage `json:"tools"` + ToolChoice *ToolChoice `json:"tool_choice"` + ProviderOptions map[string]json.RawMessage `json:"provider_options"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + c.Prompt = aux.Prompt + c.MaxOutputTokens = aux.MaxOutputTokens + c.Temperature = aux.Temperature + c.TopP = aux.TopP + c.TopK = aux.TopK + c.PresencePenalty = aux.PresencePenalty + c.FrequencyPenalty = aux.FrequencyPenalty + c.ToolChoice = aux.ToolChoice + + // Unmarshal Tools slice + c.Tools = make([]Tool, len(aux.Tools)) + for i, rawTool := range aux.Tools { + tool, err := UnmarshalTool(rawTool) + if err != nil { + return fmt.Errorf("failed to unmarshal tool at index %d: %w", i, err) + } + c.Tools[i] = tool + } + + // Unmarshal ProviderOptions + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + c.ProviderOptions = options + } + + return nil +} + +func (r *Response) UnmarshalJSON(data []byte) error { + var aux struct { + Content json.RawMessage `json:"content"` + FinishReason FinishReason `json:"finish_reason"` + Usage Usage `json:"usage"` + Warnings []CallWarning `json:"warnings"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + r.FinishReason = aux.FinishReason + r.Usage = aux.Usage + r.Warnings = aux.Warnings + + // Unmarshal ResponseContent (need to know the type definition) + // If ResponseContent is []Content: + var rawContent []json.RawMessage + if err := json.Unmarshal(aux.Content, &rawContent); err != nil { + return err + } + + content := make([]Content, len(rawContent)) + for i, rawItem := range rawContent { + item, err := UnmarshalContent(rawItem) + if err != nil { + return fmt.Errorf("failed to unmarshal content at index %d: %w", i, err) + } + content[i] = item + } + r.Content = content + + // Unmarshal ProviderMetadata + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + r.ProviderMetadata = metadata + } + + return nil +} + +func (s *StreamPart) UnmarshalJSON(data []byte) error { + var aux struct { + Type StreamPartType `json:"type"` + ID string `json:"id"` + ToolCallName string `json:"tool_call_name"` + ToolCallInput string `json:"tool_call_input"` + Delta string `json:"delta"` + ProviderExecuted bool `json:"provider_executed"` + Usage Usage `json:"usage"` + FinishReason FinishReason `json:"finish_reason"` + Error error `json:"error"` + Warnings []CallWarning `json:"warnings"` + SourceType SourceType `json:"source_type"` + URL string `json:"url"` + Title string `json:"title"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + s.Type = aux.Type + s.ID = aux.ID + s.ToolCallName = aux.ToolCallName + s.ToolCallInput = aux.ToolCallInput + s.Delta = aux.Delta + s.ProviderExecuted = aux.ProviderExecuted + s.Usage = aux.Usage + s.FinishReason = aux.FinishReason + s.Error = aux.Error + s.Warnings = aux.Warnings + s.SourceType = aux.SourceType + s.URL = aux.URL + s.Title = aux.Title + + // Unmarshal ProviderMetadata + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + s.ProviderMetadata = metadata + } + + return nil +} diff --git a/provider_registry.go b/provider_registry.go new file mode 100644 index 0000000000000000000000000000000000000000..19c8e7840346ba21d1015575ccd66a02edf2ffaa --- /dev/null +++ b/provider_registry.go @@ -0,0 +1,70 @@ +package fantasy + +import ( + "encoding/json" + "fmt" + "sync" +) + +// providerDataJSON is the serialized wrapper used by the registry. +type providerDataJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// UnmarshalFunc converts raw JSON into a ProviderOptionsData implementation. +type UnmarshalFunc func([]byte) (ProviderOptionsData, error) + +var ( + providerRegistry = make(map[string]UnmarshalFunc) + registryMutex sync.RWMutex +) + +// RegisterProviderType registers a provider type ID with its unmarshal function. +// Type IDs must be globally unique (e.g. "openai.options"). +func RegisterProviderType(typeID string, unmarshalFn UnmarshalFunc) { + registryMutex.Lock() + defer registryMutex.Unlock() + providerRegistry[typeID] = unmarshalFn +} + +// unmarshalProviderData routes a typed payload to the correct constructor. +func unmarshalProviderData(data []byte) (ProviderOptionsData, error) { + var pj providerDataJSON + if err := json.Unmarshal(data, &pj); err != nil { + return nil, err + } + + registryMutex.RLock() + unmarshalFn, exists := providerRegistry[pj.Type] + registryMutex.RUnlock() + + if !exists { + return nil, fmt.Errorf("unknown provider data type: %s", pj.Type) + } + + return unmarshalFn(pj.Data) +} + +// unmarshalProviderDataMap is a helper for unmarshaling maps of provider data. +func unmarshalProviderDataMap(data map[string]json.RawMessage) (map[string]ProviderOptionsData, error) { + result := make(map[string]ProviderOptionsData) + for provider, rawData := range data { + providerData, err := unmarshalProviderData(rawData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal provider data for %s: %w", provider, err) + } + result[provider] = providerData + } + return result, nil +} + +// UnmarshalProviderOptions unmarshals a map of provider options by type. +func UnmarshalProviderOptions(data map[string]json.RawMessage) (ProviderOptions, error) { + return unmarshalProviderDataMap(data) +} + +// UnmarshalProviderMetadata unmarshals a map of provider metadata by type. +func UnmarshalProviderMetadata(data map[string]json.RawMessage) (ProviderMetadata, error) { + return unmarshalProviderDataMap(data) +} diff --git a/providers/anthropic/provider_options.go b/providers/anthropic/provider_options.go index 905a4bdbead91b8f6622889745b1caf1255e4897..d7c90770a83518785cd46a601acd6370f0e97f1c 100644 --- a/providers/anthropic/provider_options.go +++ b/providers/anthropic/provider_options.go @@ -1,7 +1,18 @@ // Package anthropic provides an implementation of the fantasy AI SDK for Anthropic's language models. package anthropic -import "charm.land/fantasy" +import ( + "encoding/json" + + "charm.land/fantasy" +) + +// Global type identifiers for Anthropic-specific provider data. +const ( + TypeProviderOptions = Name + ".options" + TypeReasoningOptionMetadata = Name + ".reasoning_metadata" + TypeProviderCacheControl = Name + ".cache_control_options" +) // ProviderOptions represents additional options for the Anthropic provider. type ProviderOptions struct { @@ -13,6 +24,34 @@ type ProviderOptions struct { // Options implements the ProviderOptions interface. func (o *ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderOptions, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var oo plain + err := json.Unmarshal(data, &oo) + if err != nil { + return err + } + *o = ProviderOptions(oo) + return nil +} + // ThinkingProviderOption represents thinking options for the Anthropic provider. type ThinkingProviderOption struct { BudgetTokens int64 `json:"budget_tokens"` @@ -27,6 +66,34 @@ type ReasoningOptionMetadata struct { // Options implements the ProviderOptions interface. func (*ReasoningOptionMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ReasoningOptionMetadata. +func (m ReasoningOptionMetadata) MarshalJSON() ([]byte, error) { + type plain ReasoningOptionMetadata + raw, err := json.Marshal(plain(m)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeReasoningOptionMetadata, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ReasoningOptionMetadata. +func (m *ReasoningOptionMetadata) UnmarshalJSON(data []byte) error { + type plain ReasoningOptionMetadata + var rm plain + err := json.Unmarshal(data, &rm) + if err != nil { + return err + } + *m = ReasoningOptionMetadata(rm) + return nil +} + // ProviderCacheControlOptions represents cache control options for the Anthropic provider. type ProviderCacheControlOptions struct { CacheControl CacheControl `json:"cache_control"` @@ -35,6 +102,34 @@ type ProviderCacheControlOptions struct { // Options implements the ProviderOptions interface. func (*ProviderCacheControlOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderCacheControlOptions. +func (o ProviderCacheControlOptions) MarshalJSON() ([]byte, error) { + type plain ProviderCacheControlOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderCacheControl, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderCacheControlOptions. +func (o *ProviderCacheControlOptions) UnmarshalJSON(data []byte) error { + type plain ProviderCacheControlOptions + var cc plain + err := json.Unmarshal(data, &cc) + if err != nil { + return err + } + *o = ProviderCacheControlOptions(cc) + return nil +} + // CacheControl represents cache control settings for the Anthropic provider. type CacheControl struct { Type string `json:"type"` @@ -62,3 +157,28 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) { } return &options, nil } + +// Register Anthropic provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeReasoningOptionMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ReasoningOptionMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderCacheControl, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderCacheControlOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} diff --git a/providers/google/provider_options.go b/providers/google/provider_options.go index 4c645eeff255abec27cfe7d3a3aa128472fc81e7..1e9c399d07c0a8cf1752d95939ad6572fa309def 100644 --- a/providers/google/provider_options.go +++ b/providers/google/provider_options.go @@ -1,7 +1,16 @@ // Package google provides an implementation of the fantasy AI SDK for Google's language models. package google -import "charm.land/fantasy" +import ( + "encoding/json" + + "charm.land/fantasy" +) + +// Global type identifiers for Google-specific provider data. +const ( + TypeProviderOptions = Name + ".options" +) // ThinkingConfig represents thinking configuration for the Google provider. type ThinkingConfig struct { @@ -51,6 +60,34 @@ type ProviderOptions struct { // Options implements the ProviderOptionsData interface for ProviderOptions. func (o *ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderOptions, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var oo plain + err := json.Unmarshal(data, &oo) + if err != nil { + return err + } + *o = ProviderOptions(oo) + return nil +} + // ParseOptions parses provider options from a map for the Google provider. func ParseOptions(data map[string]any) (*ProviderOptions, error) { var options ProviderOptions @@ -59,3 +96,14 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) { } return &options, nil } + +// Register Google provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} diff --git a/providers/openai/provider_options.go b/providers/openai/provider_options.go index 9217c66277dc59da6ad53aecacf9efc976a8f052..04d3fbe2e30eadddf79f47417bb8548181787abb 100644 --- a/providers/openai/provider_options.go +++ b/providers/openai/provider_options.go @@ -2,6 +2,8 @@ package openai import ( + "encoding/json" + "charm.land/fantasy" "github.com/openai/openai-go/v2" ) @@ -20,6 +22,13 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) +// Global type identifiers for OpenAI-specific provider data. +const ( + TypeProviderOptions = Name + ".options" + TypeProviderFileOptions = Name + ".file_options" + TypeProviderMetadata = Name + ".metadata" +) + // ProviderMetadata represents additional metadata from OpenAI provider. type ProviderMetadata struct { Logprobs []openai.ChatCompletionTokenLogprob `json:"logprobs"` @@ -30,6 +39,34 @@ type ProviderMetadata struct { // Options implements the ProviderOptions interface. func (*ProviderMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderMetadata. +func (m ProviderMetadata) MarshalJSON() ([]byte, error) { + type plain ProviderMetadata + raw, err := json.Marshal(plain(m)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderMetadata, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderMetadata. +func (m *ProviderMetadata) UnmarshalJSON(data []byte) error { + type plain ProviderMetadata + var pm plain + err := json.Unmarshal(data, &pm) + if err != nil { + return err + } + *m = ProviderMetadata(pm) + return nil +} + // ProviderOptions represents additional options for OpenAI provider. type ProviderOptions struct { LogitBias map[string]int64 `json:"logit_bias"` @@ -52,6 +89,34 @@ type ProviderOptions struct { // Options implements the ProviderOptions interface. func (*ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderOptions, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var oo plain + err := json.Unmarshal(data, &oo) + if err != nil { + return err + } + *o = ProviderOptions(oo) + return nil +} + // ProviderFileOptions represents file options for OpenAI provider. type ProviderFileOptions struct { ImageDetail string `json:"image_detail"` @@ -60,6 +125,34 @@ type ProviderFileOptions struct { // Options implements the ProviderOptions interface. func (*ProviderFileOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderFileOptions. +func (o ProviderFileOptions) MarshalJSON() ([]byte, error) { + type plain ProviderFileOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderFileOptions, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderFileOptions. +func (o *ProviderFileOptions) UnmarshalJSON(data []byte) error { + type plain ProviderFileOptions + var of plain + err := json.Unmarshal(data, &of) + if err != nil { + return err + } + *o = ProviderFileOptions(of) + return nil +} + // ReasoningEffortOption creates a pointer to a ReasoningEffort value. func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort { return &e @@ -87,3 +180,28 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) { } return &options, nil } + +// Register OpenAI provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderFileOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderFileOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index e81dcb20292e3de43bd1b9df0b885efe6fee4a73..d00dc8b6fa00a6c4c650765984d6618ba673c9b5 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -2,11 +2,18 @@ package openai import ( + "encoding/json" "slices" "charm.land/fantasy" ) +// Global type identifiers for OpenAI Responses API-specific data. +const ( + TypeResponsesProviderOptions = Name + ".responses.options" + TypeResponsesReasoningMetadata = Name + ".responses.reasoning_metadata" +) + // ResponsesReasoningMetadata represents reasoning metadata for OpenAI Responses API. type ResponsesReasoningMetadata struct { ItemID string `json:"item_id"` @@ -17,6 +24,34 @@ type ResponsesReasoningMetadata struct { // Options implements the ProviderOptions interface. func (*ResponsesReasoningMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ResponsesReasoningMetadata. +func (m ResponsesReasoningMetadata) MarshalJSON() ([]byte, error) { + type plain ResponsesReasoningMetadata + raw, err := json.Marshal(plain(m)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeResponsesReasoningMetadata, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ResponsesReasoningMetadata. +func (m *ResponsesReasoningMetadata) UnmarshalJSON(data []byte) error { + type plain ResponsesReasoningMetadata + var rm plain + err := json.Unmarshal(data, &rm) + if err != nil { + return err + } + *m = ResponsesReasoningMetadata(rm) + return nil +} + // IncludeType represents the type of content to include for OpenAI Responses API. type IncludeType string @@ -71,6 +106,37 @@ type ResponsesProviderOptions struct { User *string `json:"user"` } +// Options implements the ProviderOptions interface. +func (*ResponsesProviderOptions) Options() {} + +// MarshalJSON implements custom JSON marshaling with type info for ResponsesProviderOptions. +func (o ResponsesProviderOptions) MarshalJSON() ([]byte, error) { + type plain ResponsesProviderOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeResponsesProviderOptions, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ResponsesProviderOptions. +func (o *ResponsesProviderOptions) UnmarshalJSON(data []byte) error { + type plain ResponsesProviderOptions + var ro plain + err := json.Unmarshal(data, &ro) + if err != nil { + return err + } + *o = ResponsesProviderOptions(ro) + return nil +} + // responsesReasoningModelIds lists the model IDs that support reasoning for OpenAI Responses API. var responsesReasoningModelIDs = []string{ "o1", @@ -121,9 +187,6 @@ var responsesModelIDs = append([]string{ "gpt-5-chat-latest", }, responsesReasoningModelIDs...) -// Options implements the ProviderOptions interface. -func (*ResponsesProviderOptions) Options() {} - // NewResponsesProviderOptions creates new provider options for OpenAI Responses API. func NewResponsesProviderOptions(opts *ResponsesProviderOptions) fantasy.ProviderOptions { return fantasy.ProviderOptions{ @@ -149,3 +212,21 @@ func IsResponsesModel(modelID string) bool { func IsResponsesReasoningModel(modelID string) bool { return slices.Contains(responsesReasoningModelIDs, modelID) } + +// Register OpenAI Responses API-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeResponsesProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ResponsesProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeResponsesReasoningMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ResponsesReasoningMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} diff --git a/providers/openaicompat/provider_options.go b/providers/openaicompat/provider_options.go index 89dfc61b9a7be1eccb512eb3c682131b8963d299..1c81e92369982c286c10aae8fe896a0d2fdfe21c 100644 --- a/providers/openaicompat/provider_options.go +++ b/providers/openaicompat/provider_options.go @@ -2,10 +2,17 @@ package openaicompat import ( + "encoding/json" + "charm.land/fantasy" "charm.land/fantasy/providers/openai" ) +// Global type identifiers for OpenRouter-specific provider data. +const ( + TypeProviderOptions = Name + ".options" +) + // ProviderOptions represents additional options for the OpenAI-compatible provider. type ProviderOptions struct { User *string `json:"user"` @@ -20,6 +27,34 @@ type ReasoningData struct { // Options implements the ProviderOptions interface. func (*ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderOptions, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var oo plain + err := json.Unmarshal(data, &oo) + if err != nil { + return err + } + *o = ProviderOptions(oo) + return nil +} + // NewProviderOptions creates new provider options for the OpenAI-compatible provider. func NewProviderOptions(opts *ProviderOptions) fantasy.ProviderOptions { return fantasy.ProviderOptions{ diff --git a/providers/openrouter/provider_options.go b/providers/openrouter/provider_options.go index 6e8b513cdacb962956385a73ac4493b43fd1ca71..876853ed93d77d78bdb047aa65d5dc3392adeac9 100644 --- a/providers/openrouter/provider_options.go +++ b/providers/openrouter/provider_options.go @@ -2,6 +2,8 @@ package openrouter import ( + "encoding/json" + "charm.land/fantasy" ) @@ -17,6 +19,12 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) +// Global type identifiers for OpenRouter-specific provider data. +const ( + TypeProviderOptions = Name + ".options" + TypeProviderMetadata = Name + ".metadata" +) + // PromptTokensDetails represents details about prompt tokens for OpenRouter. type PromptTokensDetails struct { CachedTokens int64 @@ -54,6 +62,34 @@ type ProviderMetadata struct { // Options implements the ProviderOptionsData interface for ProviderMetadata. func (*ProviderMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderMetadata. +func (m ProviderMetadata) MarshalJSON() ([]byte, error) { + type plain ProviderMetadata + raw, err := json.Marshal(plain(m)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderMetadata, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderMetadata. +func (m *ProviderMetadata) UnmarshalJSON(data []byte) error { + type plain ProviderMetadata + var pm plain + err := json.Unmarshal(data, &pm) + if err != nil { + return err + } + *m = ProviderMetadata(pm) + return nil +} + // ReasoningOptions represents reasoning options for OpenRouter. type ReasoningOptions struct { // Whether reasoning is enabled @@ -110,6 +146,34 @@ type ProviderOptions struct { // Options implements the ProviderOptionsData interface for ProviderOptions. func (*ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + raw, err := json.Marshal(plain(o)) + if err != nil { + return nil, err + } + return json.Marshal(struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + }{ + Type: TypeProviderOptions, + Data: raw, + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var oo plain + err := json.Unmarshal(data, &oo) + if err != nil { + return err + } + *o = ProviderOptions(oo) + return nil +} + // ReasoningDetail represents a reasoning detail for OpenRouter. type ReasoningDetail struct { ID string `json:"id,omitempty"` @@ -148,3 +212,21 @@ func ParseOptions(data map[string]any) (*ProviderOptions, error) { } return &options, nil } + +// Register OpenRouter provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} diff --git a/providertests/provider_registry_test.go b/providertests/provider_registry_test.go new file mode 100644 index 0000000000000000000000000000000000000000..55bd72674e9f9b7bc3da2ab68728b974c827da4c --- /dev/null +++ b/providertests/provider_registry_test.go @@ -0,0 +1,140 @@ +package providertests + +import ( + "encoding/json" + "testing" + + "charm.land/fantasy" + "charm.land/fantasy/providers/openai" + "github.com/stretchr/testify/require" +) + +func TestProviderRegistry_Serialization_OpenAIOptions(t *testing.T) { + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hi"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + openai.Name: &openai.ProviderOptions{User: fantasy.Opt("tester")}, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var raw struct { + ProviderOptions map[string]map[string]any `json:"provider_options"` + } + require.NoError(t, json.Unmarshal(data, &raw)) + + po, ok := raw.ProviderOptions[openai.Name] + require.True(t, ok) + require.Equal(t, openai.TypeProviderOptions, po["type"]) // no magic strings + // ensure inner data has the field we set + inner, ok := po["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, "tester", inner["user"]) + + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + got, ok := decoded.ProviderOptions[openai.Name] + require.True(t, ok) + opt, ok := got.(*openai.ProviderOptions) + require.True(t, ok) + require.NotNil(t, opt.User) + require.Equal(t, "tester", *opt.User) +} + +func TestProviderRegistry_Serialization_OpenAIResponses(t *testing.T) { + // Use ResponsesProviderOptions in provider options + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hello"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + openai.Name: &openai.ResponsesProviderOptions{ + PromptCacheKey: fantasy.Opt("cache-key-1"), + ParallelToolCalls: fantasy.Opt(true), + }, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + // JSON should include the typed wrapper with constant TypeResponsesProviderOptions + var raw struct { + ProviderOptions map[string]map[string]any `json:"provider_options"` + } + require.NoError(t, json.Unmarshal(data, &raw)) + + po := raw.ProviderOptions[openai.Name] + require.Equal(t, openai.TypeResponsesProviderOptions, po["type"]) // no magic strings + inner, ok := po["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, "cache-key-1", inner["prompt_cache_key"]) + require.Equal(t, true, inner["parallel_tool_calls"]) + + // Unmarshal back and assert concrete type + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + got := decoded.ProviderOptions[openai.Name] + reqOpts, ok := got.(*openai.ResponsesProviderOptions) + require.True(t, ok) + require.NotNil(t, reqOpts.PromptCacheKey) + require.Equal(t, "cache-key-1", *reqOpts.PromptCacheKey) + require.NotNil(t, reqOpts.ParallelToolCalls) + require.Equal(t, true, *reqOpts.ParallelToolCalls) +} + +func TestProviderRegistry_Serialization_OpenAIResponsesReasoningMetadata(t *testing.T) { + resp := fantasy.Response{ + Content: []fantasy.Content{ + fantasy.TextContent{ + Text: "", + ProviderMetadata: fantasy.ProviderMetadata{ + openai.Name: &openai.ResponsesReasoningMetadata{ + ItemID: "item-123", + Summary: []string{"part1", "part2"}, + }, + }, + }, + }, + } + + data, err := json.Marshal(resp) + require.NoError(t, err) + + // Ensure the provider metadata is wrapped with type using constant + var raw struct { + Content []struct { + Type string `json:"type"` + Data map[string]any `json:"data"` + } `json:"content"` + } + require.NoError(t, json.Unmarshal(data, &raw)) + require.Greater(t, len(raw.Content), 0) + tc := raw.Content[0] + pm, ok := tc.Data["provider_metadata"].(map[string]any) + require.True(t, ok) + om, ok := pm[openai.Name].(map[string]any) + require.True(t, ok) + require.Equal(t, openai.TypeResponsesReasoningMetadata, om["type"]) // no magic strings + inner, ok := om["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, "item-123", inner["item_id"]) + + // Unmarshal back + var decoded fantasy.Response + require.NoError(t, json.Unmarshal(data, &decoded)) + pmDecoded := decoded.Content[0].(fantasy.TextContent).ProviderMetadata + val, ok := pmDecoded[openai.Name] + require.True(t, ok) + meta, ok := val.(*openai.ResponsesReasoningMetadata) + require.True(t, ok) + require.Equal(t, "item-123", meta.ItemID) + require.Equal(t, []string{"part1", "part2"}, meta.Summary) +}