From c999bfd68de475e48fbff6ebb939be1b38290bfe Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 13 Nov 2025 11:21:38 +0100 Subject: [PATCH] feat: make structs serializable (#34) --- content.go | 49 + content_json.go | 1062 ++++++++++++++++++++ json_test.go | 647 ++++++++++++ model_json.go | 152 +++ provider_registry.go | 106 ++ providers/anthropic/provider_options.go | 89 +- providers/google/provider_options.go | 64 +- providers/openai/provider_options.go | 85 ++ providers/openai/responses_options.go | 65 +- providers/openaicompat/provider_options.go | 35 + providers/openrouter/provider_options.go | 64 +- providertests/provider_registry_test.go | 421 ++++++++ tool.go | 8 +- 13 files changed, 2836 insertions(+), 11 deletions(-) create mode 100644 content_json.go create mode 100644 json_test.go create mode 100644 model_json.go create mode 100644 provider_registry.go create mode 100644 providertests/provider_registry_test.go diff --git a/content.go b/content.go index 0dda6c5c296a0e8ca9a0320a74c03a3c18ab3f5e..93dc7d1ca87962a9351062def5fb616d022c3c36 100644 --- a/content.go +++ b/content.go @@ -1,8 +1,57 @@ package fantasy +import "encoding/json" + // ProviderOptionsData is an interface for provider-specific options data. +// All implementations MUST also implement encoding/json.Marshaler and +// encoding/json.Unmarshaler interfaces to ensure proper JSON serialization +// with the provider registry system. +// +// Recommended implementation pattern using generic helpers: +// +// // Define type constants at the top of your file +// const TypeMyProviderOptions = "myprovider.options" +// +// type MyProviderOptions struct { +// Field string `json:"field"` +// } +// +// // Register the type in init() - place at top of file after constants +// func init() { +// fantasy.RegisterProviderType(TypeMyProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { +// var opts MyProviderOptions +// if err := json.Unmarshal(data, &opts); err != nil { +// return nil, err +// } +// return &opts, nil +// }) +// } +// +// // Implement ProviderOptionsData interface +// func (*MyProviderOptions) Options() {} +// +// // Implement json.Marshaler using the generic helper +// func (m MyProviderOptions) MarshalJSON() ([]byte, error) { +// type plain MyProviderOptions +// return fantasy.MarshalProviderType(TypeMyProviderOptions, plain(m)) +// } +// +// // Implement json.Unmarshaler using the generic helper +// // Note: Receives inner data after type routing by the registry. +// func (m *MyProviderOptions) UnmarshalJSON(data []byte) error { +// type plain MyProviderOptions +// var p plain +// if err := fantasy.UnmarshalProviderType(data, &p); err != nil { +// return err +// } +// *m = MyProviderOptions(p) +// return nil +// } type ProviderOptionsData interface { + // Options is a marker method that identifies types implementing this interface. Options() + json.Marshaler + json.Unmarshaler } // ProviderMetadata represents additional provider-specific metadata. diff --git a/content_json.go b/content_json.go new file mode 100644 index 0000000000000000000000000000000000000000..bde7393ec85a92d4a7939e0e2bb1f22368fa6bb7 --- /dev/null +++ b/content_json.go @@ -0,0 +1,1062 @@ +package fantasy + +import ( + "encoding/json" + "errors" + "fmt" +) + +// contentJSON is a helper type for JSON serialization of Content in Response. +type contentJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// messagePartJSON is a helper type for JSON serialization of MessagePart. +type messagePartJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// toolResultOutputJSON is a helper type for JSON serialization of ToolResultOutputContent. +type toolResultOutputJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// toolJSON is a helper type for JSON serialization of Tool. +type toolJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// MarshalJSON implements json.Marshaler for TextContent. +func (t TextContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + Text: t.Text, + ProviderMetadata: t.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeText), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for TextContent. +func (t *TextContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + t.Text = aux.Text + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + t.ProviderMetadata = metadata + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ReasoningContent. +func (r ReasoningContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + Text: r.Text, + ProviderMetadata: r.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeReasoning), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ReasoningContent. +func (r *ReasoningContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + r.Text = aux.Text + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + r.ProviderMetadata = metadata + } + + return nil +} + +// MarshalJSON implements json.Marshaler for FileContent. +func (f FileContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + MediaType string `json:"media_type"` + Data []byte `json:"data"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + MediaType: f.MediaType, + Data: f.Data, + ProviderMetadata: f.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeFile), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for FileContent. +func (f *FileContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + MediaType string `json:"media_type"` + Data []byte `json:"data"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + f.MediaType = aux.MediaType + f.Data = aux.Data + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + f.ProviderMetadata = metadata + } + + return nil +} + +// MarshalJSON implements json.Marshaler for SourceContent. +func (s SourceContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + SourceType SourceType `json:"source_type"` + ID string `json:"id"` + URL string `json:"url,omitempty"` + Title string `json:"title,omitempty"` + MediaType string `json:"media_type,omitempty"` + Filename string `json:"filename,omitempty"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + SourceType: s.SourceType, + ID: s.ID, + URL: s.URL, + Title: s.Title, + MediaType: s.MediaType, + Filename: s.Filename, + ProviderMetadata: s.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeSource), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for SourceContent. +func (s *SourceContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + SourceType SourceType `json:"source_type"` + ID string `json:"id"` + URL string `json:"url,omitempty"` + Title string `json:"title,omitempty"` + MediaType string `json:"media_type,omitempty"` + Filename string `json:"filename,omitempty"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + s.SourceType = aux.SourceType + s.ID = aux.ID + s.URL = aux.URL + s.Title = aux.Title + s.MediaType = aux.MediaType + s.Filename = aux.Filename + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + s.ProviderMetadata = metadata + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ToolCallContent. +func (t ToolCallContent) MarshalJSON() ([]byte, error) { + var validationErrMsg *string + if t.ValidationError != nil { + msg := t.ValidationError.Error() + validationErrMsg = &msg + } + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + Invalid bool `json:"invalid,omitempty"` + ValidationError *string `json:"validation_error,omitempty"` + }{ + ToolCallID: t.ToolCallID, + ToolName: t.ToolName, + Input: t.Input, + ProviderExecuted: t.ProviderExecuted, + ProviderMetadata: t.ProviderMetadata, + Invalid: t.Invalid, + ValidationError: validationErrMsg, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeToolCall), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ToolCallContent. +func (t *ToolCallContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + Invalid bool `json:"invalid,omitempty"` + ValidationError *string `json:"validation_error,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + t.ToolName = aux.ToolName + t.Input = aux.Input + t.ProviderExecuted = aux.ProviderExecuted + t.Invalid = aux.Invalid + if aux.ValidationError != nil { + t.ValidationError = errors.New(*aux.ValidationError) + } + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + t.ProviderMetadata = metadata + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ToolResultContent. +func (t ToolResultContent) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Result ToolResultOutputContent `json:"result"` + ClientMetadata string `json:"client_metadata,omitempty"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata ProviderMetadata `json:"provider_metadata,omitempty"` + }{ + ToolCallID: t.ToolCallID, + ToolName: t.ToolName, + Result: t.Result, + ClientMetadata: t.ClientMetadata, + ProviderExecuted: t.ProviderExecuted, + ProviderMetadata: t.ProviderMetadata, + }) + if err != nil { + return nil, err + } + + return json.Marshal(contentJSON{ + Type: string(ContentTypeToolResult), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ToolResultContent. +func (t *ToolResultContent) UnmarshalJSON(data []byte) error { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Result json.RawMessage `json:"result"` + ClientMetadata string `json:"client_metadata,omitempty"` + ProviderExecuted bool `json:"provider_executed"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata,omitempty"` + } + + if err := json.Unmarshal(cj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + t.ToolName = aux.ToolName + t.ClientMetadata = aux.ClientMetadata + t.ProviderExecuted = aux.ProviderExecuted + + // Unmarshal the Result field + result, err := UnmarshalToolResultOutputContent(aux.Result) + if err != nil { + return fmt.Errorf("failed to unmarshal tool result output: %w", err) + } + t.Result = result + + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + t.ProviderMetadata = metadata + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ToolResultOutputContentText. +func (t ToolResultOutputContentText) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + }{ + Text: t.Text, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolResultOutputJSON{ + Type: string(ToolResultContentTypeText), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ToolResultOutputContentText. +func (t *ToolResultOutputContentText) UnmarshalJSON(data []byte) error { + var tr toolResultOutputJSON + if err := json.Unmarshal(data, &tr); err != nil { + return err + } + + var temp struct { + Text string `json:"text"` + } + + if err := json.Unmarshal(tr.Data, &temp); err != nil { + return err + } + + t.Text = temp.Text + return nil +} + +// MarshalJSON implements json.Marshaler for ToolResultOutputContentError. +func (t ToolResultOutputContentError) MarshalJSON() ([]byte, error) { + errMsg := "" + if t.Error != nil { + errMsg = t.Error.Error() + } + dataBytes, err := json.Marshal(struct { + Error string `json:"error"` + }{ + Error: errMsg, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolResultOutputJSON{ + Type: string(ToolResultContentTypeError), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ToolResultOutputContentError. +func (t *ToolResultOutputContentError) UnmarshalJSON(data []byte) error { + var tr toolResultOutputJSON + if err := json.Unmarshal(data, &tr); err != nil { + return err + } + + var temp struct { + Error string `json:"error"` + } + + if err := json.Unmarshal(tr.Data, &temp); err != nil { + return err + } + if temp.Error != "" { + t.Error = errors.New(temp.Error) + } + return nil +} + +// MarshalJSON implements json.Marshaler for ToolResultOutputContentMedia. +func (t ToolResultOutputContentMedia) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Data string `json:"data"` + MediaType string `json:"media_type"` + }{ + Data: t.Data, + MediaType: t.MediaType, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolResultOutputJSON{ + Type: string(ToolResultContentTypeMedia), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ToolResultOutputContentMedia. +func (t *ToolResultOutputContentMedia) UnmarshalJSON(data []byte) error { + var tr toolResultOutputJSON + if err := json.Unmarshal(data, &tr); err != nil { + return err + } + + var temp struct { + Data string `json:"data"` + MediaType string `json:"media_type"` + } + + if err := json.Unmarshal(tr.Data, &temp); err != nil { + return err + } + + t.Data = temp.Data + t.MediaType = temp.MediaType + return nil +} + +// MarshalJSON implements json.Marshaler for TextPart. +func (t TextPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Text: t.Text, + ProviderOptions: t.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeText), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for TextPart. +func (t *TextPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + t.Text = aux.Text + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + t.ProviderOptions = options + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ReasoningPart. +func (r ReasoningPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Text string `json:"text"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Text: r.Text, + ProviderOptions: r.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeReasoning), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ReasoningPart. +func (r *ReasoningPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + Text string `json:"text"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + r.Text = aux.Text + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + r.ProviderOptions = options + } + + return nil +} + +// MarshalJSON implements json.Marshaler for FilePart. +func (f FilePart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Filename string `json:"filename"` + Data []byte `json:"data"` + MediaType string `json:"media_type"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Filename: f.Filename, + Data: f.Data, + MediaType: f.MediaType, + ProviderOptions: f.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeFile), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for FilePart. +func (f *FilePart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + Filename string `json:"filename"` + Data []byte `json:"data"` + MediaType string `json:"media_type"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + f.Filename = aux.Filename + f.Data = aux.Data + f.MediaType = aux.MediaType + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + f.ProviderOptions = options + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ToolCallPart. +func (t ToolCallPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + ToolCallID: t.ToolCallID, + ToolName: t.ToolName, + Input: t.Input, + ProviderExecuted: t.ProviderExecuted, + ProviderOptions: t.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeToolCall), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ToolCallPart. +func (t *ToolCallPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + t.ToolName = aux.ToolName + t.Input = aux.Input + t.ProviderExecuted = aux.ProviderExecuted + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + t.ProviderOptions = options + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ToolResultPart. +func (t ToolResultPart) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ToolCallID string `json:"tool_call_id"` + Output ToolResultOutputContent `json:"output"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + ToolCallID: t.ToolCallID, + Output: t.Output, + ProviderOptions: t.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(messagePartJSON{ + Type: string(ContentTypeToolResult), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ToolResultPart. +func (t *ToolResultPart) UnmarshalJSON(data []byte) error { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return err + } + + var aux struct { + ToolCallID string `json:"tool_call_id"` + Output json.RawMessage `json:"output"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(mpj.Data, &aux); err != nil { + return err + } + + t.ToolCallID = aux.ToolCallID + + // Unmarshal the Output field + output, err := UnmarshalToolResultOutputContent(aux.Output) + if err != nil { + return fmt.Errorf("failed to unmarshal tool result output: %w", err) + } + t.Output = output + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + t.ProviderOptions = options + } + + return nil +} + +// UnmarshalJSON implements json.Unmarshaler for Message. +func (m *Message) UnmarshalJSON(data []byte) error { + var aux struct { + Role MessageRole `json:"role"` + Content []json.RawMessage `json:"content"` + ProviderOptions map[string]json.RawMessage `json:"provider_options"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + m.Role = aux.Role + + m.Content = make([]MessagePart, len(aux.Content)) + for i, rawPart := range aux.Content { + part, err := UnmarshalMessagePart(rawPart) + if err != nil { + return fmt.Errorf("failed to unmarshal message part at index %d: %w", i, err) + } + m.Content[i] = part + } + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + m.ProviderOptions = options + } + + return nil +} + +// MarshalJSON implements json.Marshaler for FunctionTool. +func (f FunctionTool) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` + ProviderOptions ProviderOptions `json:"provider_options,omitempty"` + }{ + Name: f.Name, + Description: f.Description, + InputSchema: f.InputSchema, + ProviderOptions: f.ProviderOptions, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolJSON{ + Type: string(ToolTypeFunction), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for FunctionTool. +func (f *FunctionTool) UnmarshalJSON(data []byte) error { + var tj toolJSON + if err := json.Unmarshal(data, &tj); err != nil { + return err + } + + var aux struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` + ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"` + } + + if err := json.Unmarshal(tj.Data, &aux); err != nil { + return err + } + + f.Name = aux.Name + f.Description = aux.Description + f.InputSchema = aux.InputSchema + + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + f.ProviderOptions = options + } + + return nil +} + +// MarshalJSON implements json.Marshaler for ProviderDefinedTool. +func (p ProviderDefinedTool) MarshalJSON() ([]byte, error) { + dataBytes, err := json.Marshal(struct { + ID string `json:"id"` + Name string `json:"name"` + Args map[string]any `json:"args"` + }{ + ID: p.ID, + Name: p.Name, + Args: p.Args, + }) + if err != nil { + return nil, err + } + + return json.Marshal(toolJSON{ + Type: string(ToolTypeProviderDefined), + Data: json.RawMessage(dataBytes), + }) +} + +// UnmarshalJSON implements json.Unmarshaler for ProviderDefinedTool. +func (p *ProviderDefinedTool) UnmarshalJSON(data []byte) error { + var tj toolJSON + if err := json.Unmarshal(data, &tj); err != nil { + return err + } + + var aux struct { + ID string `json:"id"` + Name string `json:"name"` + Args map[string]any `json:"args"` + } + + if err := json.Unmarshal(tj.Data, &aux); err != nil { + return err + } + + p.ID = aux.ID + p.Name = aux.Name + p.Args = aux.Args + + return nil +} + +// UnmarshalTool unmarshals JSON into the appropriate Tool type. +func UnmarshalTool(data []byte) (Tool, error) { + var tj toolJSON + if err := json.Unmarshal(data, &tj); err != nil { + return nil, err + } + + switch ToolType(tj.Type) { + case ToolTypeFunction: + var tool FunctionTool + if err := tool.UnmarshalJSON(data); err != nil { + return nil, err + } + return tool, nil + case ToolTypeProviderDefined: + var tool ProviderDefinedTool + if err := tool.UnmarshalJSON(data); err != nil { + return nil, err + } + return tool, nil + default: + return nil, fmt.Errorf("unknown tool type: %s", tj.Type) + } +} + +// UnmarshalContent unmarshals JSON into the appropriate Content type. +func UnmarshalContent(data []byte) (Content, error) { + var cj contentJSON + if err := json.Unmarshal(data, &cj); err != nil { + return nil, err + } + + switch ContentType(cj.Type) { + case ContentTypeText: + var content TextContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeReasoning: + var content ReasoningContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeFile: + var content FileContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeSource: + var content SourceContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeToolCall: + var content ToolCallContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ContentTypeToolResult: + var content ToolResultContent + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + default: + return nil, fmt.Errorf("unknown content type: %s", cj.Type) + } +} + +// UnmarshalMessagePart unmarshals JSON into the appropriate MessagePart type. +func UnmarshalMessagePart(data []byte) (MessagePart, error) { + var mpj messagePartJSON + if err := json.Unmarshal(data, &mpj); err != nil { + return nil, err + } + + switch ContentType(mpj.Type) { + case ContentTypeText: + var part TextPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeReasoning: + var part ReasoningPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeFile: + var part FilePart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeToolCall: + var part ToolCallPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + case ContentTypeToolResult: + var part ToolResultPart + if err := part.UnmarshalJSON(data); err != nil { + return nil, err + } + return part, nil + default: + return nil, fmt.Errorf("unknown message part type: %s", mpj.Type) + } +} + +// UnmarshalToolResultOutputContent unmarshals JSON into the appropriate ToolResultOutputContent type. +func UnmarshalToolResultOutputContent(data []byte) (ToolResultOutputContent, error) { + var troj toolResultOutputJSON + if err := json.Unmarshal(data, &troj); err != nil { + return nil, err + } + + switch ToolResultContentType(troj.Type) { + case ToolResultContentTypeText: + var content ToolResultOutputContentText + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ToolResultContentTypeError: + var content ToolResultOutputContentError + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + case ToolResultContentTypeMedia: + var content ToolResultOutputContentMedia + if err := content.UnmarshalJSON(data); err != nil { + return nil, err + } + return content, nil + default: + return nil, fmt.Errorf("unknown tool result output content type: %s", troj.Type) + } +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 0000000000000000000000000000000000000000..da9f491407e284dbc988c59595d79aa631dfd524 --- /dev/null +++ b/json_test.go @@ -0,0 +1,647 @@ +package fantasy + +import ( + "encoding/json" + "errors" + "reflect" + "testing" +) + +func TestMessageJSONSerialization(t *testing.T) { + tests := []struct { + name string + message Message + }{ + { + name: "simple text message", + message: Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: "Hello, world!"}, + }, + }, + }, + { + name: "message with multiple text parts", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + TextPart{Text: "First part"}, + TextPart{Text: "Second part"}, + TextPart{Text: "Third part"}, + }, + }, + }, + { + name: "message with reasoning part", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + ReasoningPart{Text: "Let me think about this..."}, + TextPart{Text: "Here's my answer"}, + }, + }, + }, + { + name: "message with file part", + message: Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: "Here's an image:"}, + FilePart{ + Filename: "test.png", + Data: []byte{0x89, 0x50, 0x4E, 0x47}, // PNG header + MediaType: "image/png", + }, + }, + }, + }, + { + name: "message with tool call", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + ToolCallPart{ + ToolCallID: "call_123", + ToolName: "get_weather", + Input: `{"location": "San Francisco"}`, + ProviderExecuted: false, + }, + }, + }, + }, + { + name: "message with tool result - text output", + message: Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_123", + Output: ToolResultOutputContentText{ + Text: "The weather is sunny, 72°F", + }, + }, + }, + }, + }, + { + name: "message with tool result - error output", + message: Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_456", + Output: ToolResultOutputContentError{ + Error: errors.New("API rate limit exceeded"), + }, + }, + }, + }, + }, + { + name: "message with tool result - media output", + message: Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_789", + Output: ToolResultOutputContentMedia{ + Data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", + MediaType: "image/png", + }, + }, + }, + }, + }, + { + name: "complex message with mixed content", + message: Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + TextPart{Text: "I'll analyze this image and call some tools."}, + ReasoningPart{Text: "First, I need to identify the objects..."}, + ToolCallPart{ + ToolCallID: "call_001", + ToolName: "analyze_image", + Input: `{"image_id": "img_123"}`, + ProviderExecuted: false, + }, + ToolCallPart{ + ToolCallID: "call_002", + ToolName: "get_context", + Input: `{"query": "similar images"}`, + ProviderExecuted: true, + }, + }, + }, + }, + { + name: "system message", + message: Message{ + Role: MessageRoleSystem, + Content: []MessagePart{ + TextPart{Text: "You are a helpful assistant."}, + }, + }, + }, + { + name: "empty content", + message: Message{ + Role: MessageRoleUser, + Content: []MessagePart{}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal the message + data, err := json.Marshal(tt.message) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + + // Unmarshal back + var decoded Message + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("failed to unmarshal message: %v", err) + } + + // Compare roles + if decoded.Role != tt.message.Role { + t.Errorf("role mismatch: got %v, want %v", decoded.Role, tt.message.Role) + } + + // Compare content length + if len(decoded.Content) != len(tt.message.Content) { + t.Fatalf("content length mismatch: got %d, want %d", len(decoded.Content), len(tt.message.Content)) + } + + // Compare each content part + for i := range tt.message.Content { + original := tt.message.Content[i] + decodedPart := decoded.Content[i] + + if original.GetType() != decodedPart.GetType() { + t.Errorf("content[%d] type mismatch: got %v, want %v", i, decodedPart.GetType(), original.GetType()) + continue + } + + compareMessagePart(t, i, original, decodedPart) + } + }) + } +} + +func compareMessagePart(t *testing.T, index int, original, decoded MessagePart) { + switch original.GetType() { + case ContentTypeText: + orig := original.(TextPart) + dec := decoded.(TextPart) + if orig.Text != dec.Text { + t.Errorf("content[%d] text mismatch: got %q, want %q", index, dec.Text, orig.Text) + } + + case ContentTypeReasoning: + orig := original.(ReasoningPart) + dec := decoded.(ReasoningPart) + if orig.Text != dec.Text { + t.Errorf("content[%d] reasoning text mismatch: got %q, want %q", index, dec.Text, orig.Text) + } + + case ContentTypeFile: + orig := original.(FilePart) + dec := decoded.(FilePart) + if orig.Filename != dec.Filename { + t.Errorf("content[%d] filename mismatch: got %q, want %q", index, dec.Filename, orig.Filename) + } + if orig.MediaType != dec.MediaType { + t.Errorf("content[%d] media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType) + } + if !reflect.DeepEqual(orig.Data, dec.Data) { + t.Errorf("content[%d] file data mismatch", index) + } + + case ContentTypeToolCall: + orig := original.(ToolCallPart) + dec := decoded.(ToolCallPart) + if orig.ToolCallID != dec.ToolCallID { + t.Errorf("content[%d] tool call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID) + } + if orig.ToolName != dec.ToolName { + t.Errorf("content[%d] tool name mismatch: got %q, want %q", index, dec.ToolName, orig.ToolName) + } + if orig.Input != dec.Input { + t.Errorf("content[%d] tool input mismatch: got %q, want %q", index, dec.Input, orig.Input) + } + if orig.ProviderExecuted != dec.ProviderExecuted { + t.Errorf("content[%d] provider executed mismatch: got %v, want %v", index, dec.ProviderExecuted, orig.ProviderExecuted) + } + + case ContentTypeToolResult: + orig := original.(ToolResultPart) + dec := decoded.(ToolResultPart) + if orig.ToolCallID != dec.ToolCallID { + t.Errorf("content[%d] tool result call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID) + } + compareToolResultOutput(t, index, orig.Output, dec.Output) + } +} + +func compareToolResultOutput(t *testing.T, index int, original, decoded ToolResultOutputContent) { + if original.GetType() != decoded.GetType() { + t.Errorf("content[%d] tool result output type mismatch: got %v, want %v", index, decoded.GetType(), original.GetType()) + return + } + + switch original.GetType() { + case ToolResultContentTypeText: + orig := original.(ToolResultOutputContentText) + dec := decoded.(ToolResultOutputContentText) + if orig.Text != dec.Text { + t.Errorf("content[%d] tool result text mismatch: got %q, want %q", index, dec.Text, orig.Text) + } + + case ToolResultContentTypeError: + orig := original.(ToolResultOutputContentError) + dec := decoded.(ToolResultOutputContentError) + if orig.Error.Error() != dec.Error.Error() { + t.Errorf("content[%d] tool result error mismatch: got %q, want %q", index, dec.Error.Error(), orig.Error.Error()) + } + + case ToolResultContentTypeMedia: + orig := original.(ToolResultOutputContentMedia) + dec := decoded.(ToolResultOutputContentMedia) + if orig.Data != dec.Data { + t.Errorf("content[%d] tool result media data mismatch", index) + } + if orig.MediaType != dec.MediaType { + t.Errorf("content[%d] tool result media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType) + } + } +} + +func TestHelperFunctions(t *testing.T) { + t.Run("NewUserMessage - text only", func(t *testing.T) { + msg := NewUserMessage("Hello") + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded.Role != MessageRoleUser { + t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleUser) + } + + if len(decoded.Content) != 1 { + t.Fatalf("expected 1 content part, got %d", len(decoded.Content)) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "Hello" { + t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Hello") + } + }) + + t.Run("NewUserMessage - with files", func(t *testing.T) { + msg := NewUserMessage("Check this image", + FilePart{ + Filename: "image1.jpg", + Data: []byte{0xFF, 0xD8, 0xFF}, + MediaType: "image/jpeg", + }, + FilePart{ + Filename: "image2.png", + Data: []byte{0x89, 0x50, 0x4E, 0x47}, + MediaType: "image/png", + }, + ) + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(decoded.Content) != 3 { + t.Fatalf("expected 3 content parts, got %d", len(decoded.Content)) + } + + // Check text part + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "Check this image" { + t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Check this image") + } + + // Check first file + file1 := decoded.Content[1].(FilePart) + if file1.Filename != "image1.jpg" { + t.Errorf("file1 name mismatch: got %q, want %q", file1.Filename, "image1.jpg") + } + + // Check second file + file2 := decoded.Content[2].(FilePart) + if file2.Filename != "image2.png" { + t.Errorf("file2 name mismatch: got %q, want %q", file2.Filename, "image2.png") + } + }) + + t.Run("NewSystemMessage - single prompt", func(t *testing.T) { + msg := NewSystemMessage("You are a helpful assistant.") + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded.Role != MessageRoleSystem { + t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleSystem) + } + + if len(decoded.Content) != 1 { + t.Fatalf("expected 1 content part, got %d", len(decoded.Content)) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "You are a helpful assistant." { + t.Errorf("text mismatch: got %q, want %q", textPart.Text, "You are a helpful assistant.") + } + }) + + t.Run("NewSystemMessage - multiple prompts", func(t *testing.T) { + msg := NewSystemMessage("First instruction", "Second instruction", "Third instruction") + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(decoded.Content) != 3 { + t.Fatalf("expected 3 content parts, got %d", len(decoded.Content)) + } + + expected := []string{"First instruction", "Second instruction", "Third instruction"} + for i, exp := range expected { + textPart := decoded.Content[i].(TextPart) + if textPart.Text != exp { + t.Errorf("content[%d] text mismatch: got %q, want %q", i, textPart.Text, exp) + } + } + }) +} + +func TestEdgeCases(t *testing.T) { + t.Run("empty text part", func(t *testing.T) { + msg := Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: ""}, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "" { + t.Errorf("expected empty text, got %q", textPart.Text) + } + }) + + t.Run("nil error in tool result", func(t *testing.T) { + msg := Message{ + Role: MessageRoleTool, + Content: []MessagePart{ + ToolResultPart{ + ToolCallID: "call_123", + Output: ToolResultOutputContentError{ + Error: nil, + }, + }, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + toolResult := decoded.Content[0].(ToolResultPart) + errorOutput := toolResult.Output.(ToolResultOutputContentError) + if errorOutput.Error != nil { + t.Errorf("expected nil error, got %v", errorOutput.Error) + } + }) + + t.Run("empty file data", func(t *testing.T) { + msg := Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + FilePart{ + Filename: "empty.txt", + Data: []byte{}, + MediaType: "text/plain", + }, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + filePart := decoded.Content[0].(FilePart) + if len(filePart.Data) != 0 { + t.Errorf("expected empty data, got %d bytes", len(filePart.Data)) + } + }) + + t.Run("unicode in text", func(t *testing.T) { + msg := Message{ + Role: MessageRoleUser, + Content: []MessagePart{ + TextPart{Text: "Hello 世界! 🌍 Привет"}, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded Message + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + textPart := decoded.Content[0].(TextPart) + if textPart.Text != "Hello 世界! 🌍 Привет" { + t.Errorf("unicode text mismatch: got %q, want %q", textPart.Text, "Hello 世界! 🌍 Привет") + } + }) +} + +func TestInvalidJSONHandling(t *testing.T) { + t.Run("unknown message part type", func(t *testing.T) { + invalidJSON := `{ + "role": "user", + "content": [ + { + "type": "unknown-type", + "data": {} + } + ], + "provider_options": null + }` + + var msg Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + if err == nil { + t.Error("expected error for unknown message part type, got nil") + } + }) + + t.Run("unknown tool result output type", func(t *testing.T) { + invalidJSON := `{ + "role": "tool", + "content": [ + { + "type": "tool-result", + "data": { + "tool_call_id": "call_123", + "output": { + "type": "unknown-output-type", + "data": {} + }, + "provider_options": null + } + } + ], + "provider_options": null + }` + + var msg Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + if err == nil { + t.Error("expected error for unknown tool result output type, got nil") + } + }) + + t.Run("malformed JSON", func(t *testing.T) { + invalidJSON := `{"role": "user", "content": [` + + var msg Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + if err == nil { + t.Error("expected error for malformed JSON, got nil") + } + }) +} + +// Mock provider data for testing provider options +type mockProviderData struct { + Key string `json:"key"` +} + +func (m mockProviderData) Options() {} +func (m mockProviderData) Type() string { return "mock" } +func (m mockProviderData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + mockProviderData + }{ + Type: "mock", + mockProviderData: m, + }) +} + +func (m *mockProviderData) UnmarshalJSON(data []byte) error { + var aux struct { + Type string `json:"type"` + mockProviderData + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *m = aux.mockProviderData + return nil +} + +func TestPromptSerialization(t *testing.T) { + t.Run("serialize prompt (message slice)", func(t *testing.T) { + prompt := Prompt{ + NewSystemMessage("You are helpful"), + NewUserMessage("Hello"), + Message{ + Role: MessageRoleAssistant, + Content: []MessagePart{ + TextPart{Text: "Hi there!"}, + }, + }, + } + + data, err := json.Marshal(prompt) + if err != nil { + t.Fatalf("failed to marshal prompt: %v", err) + } + + var decoded Prompt + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal prompt: %v", err) + } + + if len(decoded) != 3 { + t.Fatalf("expected 3 messages, got %d", len(decoded)) + } + + if decoded[0].Role != MessageRoleSystem { + t.Errorf("message 0 role mismatch: got %v, want %v", decoded[0].Role, MessageRoleSystem) + } + + if decoded[1].Role != MessageRoleUser { + t.Errorf("message 1 role mismatch: got %v, want %v", decoded[1].Role, MessageRoleUser) + } + + if decoded[2].Role != MessageRoleAssistant { + t.Errorf("message 2 role mismatch: got %v, want %v", decoded[2].Role, MessageRoleAssistant) + } + }) +} diff --git a/model_json.go b/model_json.go new file mode 100644 index 0000000000000000000000000000000000000000..90a8c78520977ae8dc797c2c0575f168297917ca --- /dev/null +++ b/model_json.go @@ -0,0 +1,152 @@ +package fantasy + +import ( + "encoding/json" + "fmt" +) + +// UnmarshalJSON implements json.Unmarshaler for Call. +func (c *Call) UnmarshalJSON(data []byte) error { + var aux struct { + Prompt Prompt `json:"prompt"` + MaxOutputTokens *int64 `json:"max_output_tokens"` + Temperature *float64 `json:"temperature"` + TopP *float64 `json:"top_p"` + TopK *int64 `json:"top_k"` + PresencePenalty *float64 `json:"presence_penalty"` + FrequencyPenalty *float64 `json:"frequency_penalty"` + Tools []json.RawMessage `json:"tools"` + ToolChoice *ToolChoice `json:"tool_choice"` + ProviderOptions map[string]json.RawMessage `json:"provider_options"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + c.Prompt = aux.Prompt + c.MaxOutputTokens = aux.MaxOutputTokens + c.Temperature = aux.Temperature + c.TopP = aux.TopP + c.TopK = aux.TopK + c.PresencePenalty = aux.PresencePenalty + c.FrequencyPenalty = aux.FrequencyPenalty + c.ToolChoice = aux.ToolChoice + + // Unmarshal Tools slice + c.Tools = make([]Tool, len(aux.Tools)) + for i, rawTool := range aux.Tools { + tool, err := UnmarshalTool(rawTool) + if err != nil { + return fmt.Errorf("failed to unmarshal tool at index %d: %w", i, err) + } + c.Tools[i] = tool + } + + // Unmarshal ProviderOptions + if len(aux.ProviderOptions) > 0 { + options, err := UnmarshalProviderOptions(aux.ProviderOptions) + if err != nil { + return err + } + c.ProviderOptions = options + } + + return nil +} + +// UnmarshalJSON implements json.Unmarshaler for Response. +func (r *Response) UnmarshalJSON(data []byte) error { + var aux struct { + Content json.RawMessage `json:"content"` + FinishReason FinishReason `json:"finish_reason"` + Usage Usage `json:"usage"` + Warnings []CallWarning `json:"warnings"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + r.FinishReason = aux.FinishReason + r.Usage = aux.Usage + r.Warnings = aux.Warnings + + // Unmarshal ResponseContent (need to know the type definition) + // If ResponseContent is []Content: + var rawContent []json.RawMessage + if err := json.Unmarshal(aux.Content, &rawContent); err != nil { + return err + } + + content := make([]Content, len(rawContent)) + for i, rawItem := range rawContent { + item, err := UnmarshalContent(rawItem) + if err != nil { + return fmt.Errorf("failed to unmarshal content at index %d: %w", i, err) + } + content[i] = item + } + r.Content = content + + // Unmarshal ProviderMetadata + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + r.ProviderMetadata = metadata + } + + return nil +} + +// UnmarshalJSON implements json.Unmarshaler for StreamPart. +func (s *StreamPart) UnmarshalJSON(data []byte) error { + var aux struct { + Type StreamPartType `json:"type"` + ID string `json:"id"` + ToolCallName string `json:"tool_call_name"` + ToolCallInput string `json:"tool_call_input"` + Delta string `json:"delta"` + ProviderExecuted bool `json:"provider_executed"` + Usage Usage `json:"usage"` + FinishReason FinishReason `json:"finish_reason"` + Error error `json:"error"` + Warnings []CallWarning `json:"warnings"` + SourceType SourceType `json:"source_type"` + URL string `json:"url"` + Title string `json:"title"` + ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"` + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + s.Type = aux.Type + s.ID = aux.ID + s.ToolCallName = aux.ToolCallName + s.ToolCallInput = aux.ToolCallInput + s.Delta = aux.Delta + s.ProviderExecuted = aux.ProviderExecuted + s.Usage = aux.Usage + s.FinishReason = aux.FinishReason + s.Error = aux.Error + s.Warnings = aux.Warnings + s.SourceType = aux.SourceType + s.URL = aux.URL + s.Title = aux.Title + + // Unmarshal ProviderMetadata + if len(aux.ProviderMetadata) > 0 { + metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata) + if err != nil { + return err + } + s.ProviderMetadata = metadata + } + + return nil +} diff --git a/provider_registry.go b/provider_registry.go new file mode 100644 index 0000000000000000000000000000000000000000..36c96c447e53222063a00d703dc1722ba8c6433c --- /dev/null +++ b/provider_registry.go @@ -0,0 +1,106 @@ +package fantasy + +import ( + "encoding/json" + "fmt" + "sync" +) + +// providerDataJSON is the serialized wrapper used by the registry. +type providerDataJSON struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +// UnmarshalFunc converts raw JSON into a ProviderOptionsData implementation. +type UnmarshalFunc func([]byte) (ProviderOptionsData, error) + +// providerRegistry uses sync.Map for lock-free reads after initialization. +// All registrations happen in init() functions before concurrent access. +var providerRegistry sync.Map + +// RegisterProviderType registers a provider type ID with its unmarshal function. +// Type IDs must be globally unique (e.g. "openai.options"). +// This should only be called during package initialization (init functions). +func RegisterProviderType(typeID string, unmarshalFn UnmarshalFunc) { + providerRegistry.Store(typeID, unmarshalFn) +} + +// unmarshalProviderData routes a typed payload to the correct constructor. +func unmarshalProviderData(data []byte) (ProviderOptionsData, error) { + var pj providerDataJSON + if err := json.Unmarshal(data, &pj); err != nil { + return nil, err + } + + val, exists := providerRegistry.Load(pj.Type) + if !exists { + return nil, fmt.Errorf("unknown provider data type: %s", pj.Type) + } + + unmarshalFn := val.(UnmarshalFunc) + return unmarshalFn(pj.Data) +} + +// unmarshalProviderDataMap is a helper for unmarshaling maps of provider data. +func unmarshalProviderDataMap(data map[string]json.RawMessage) (map[string]ProviderOptionsData, error) { + result := make(map[string]ProviderOptionsData) + for provider, rawData := range data { + providerData, err := unmarshalProviderData(rawData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal provider data for %s: %w", provider, err) + } + result[provider] = providerData + } + return result, nil +} + +// UnmarshalProviderOptions unmarshals a map of provider options by type. +func UnmarshalProviderOptions(data map[string]json.RawMessage) (ProviderOptions, error) { + return unmarshalProviderDataMap(data) +} + +// UnmarshalProviderMetadata unmarshals a map of provider metadata by type. +func UnmarshalProviderMetadata(data map[string]json.RawMessage) (ProviderMetadata, error) { + return unmarshalProviderDataMap(data) +} + +// MarshalProviderType marshals provider data with a type wrapper using generics. +// To avoid infinite recursion, use the "type plain T" pattern before calling this. +// +// Usage in provider types: +// +// func (o ProviderOptions) MarshalJSON() ([]byte, error) { +// type plain ProviderOptions +// return fantasy.MarshalProviderType(TypeProviderOptions, plain(o)) +// } +func MarshalProviderType[T any](typeID string, data T) ([]byte, error) { + rawData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + return json.Marshal(providerDataJSON{ + Type: typeID, + Data: json.RawMessage(rawData), + }) +} + +// UnmarshalProviderType unmarshals provider data without type wrapper using generics. +// To avoid infinite recursion, unmarshal to a plain type first. +// Note: This receives the inner 'data' field after type routing by the registry. +// +// Usage in provider types: +// +// func (o *ProviderOptions) UnmarshalJSON(data []byte) error { +// type plain ProviderOptions +// var p plain +// if err := fantasy.UnmarshalProviderType(data, &p); err != nil { +// return err +// } +// *o = ProviderOptions(p) +// return nil +// } +func UnmarshalProviderType[T any](data []byte, target *T) error { + return json.Unmarshal(data, target) +} diff --git a/providers/anthropic/provider_options.go b/providers/anthropic/provider_options.go index 905a4bdbead91b8f6622889745b1caf1255e4897..7c426f59a7235cdc656f55fa18ae8ba71a7f5ae3 100644 --- a/providers/anthropic/provider_options.go +++ b/providers/anthropic/provider_options.go @@ -1,7 +1,43 @@ // Package anthropic provides an implementation of the fantasy AI SDK for Anthropic's language models. package anthropic -import "charm.land/fantasy" +import ( + "encoding/json" + + "charm.land/fantasy" +) + +// Global type identifiers for Anthropic-specific provider data. +const ( + TypeProviderOptions = Name + ".options" + TypeReasoningOptionMetadata = Name + ".reasoning_metadata" + TypeProviderCacheControl = Name + ".cache_control_options" +) + +// Register Anthropic provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeReasoningOptionMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ReasoningOptionMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderCacheControl, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderCacheControlOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} // ProviderOptions represents additional options for the Anthropic provider. type ProviderOptions struct { @@ -13,6 +49,23 @@ type ProviderOptions struct { // Options implements the ProviderOptions interface. func (o *ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + return fantasy.MarshalProviderType(TypeProviderOptions, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ProviderOptions(p) + return nil +} + // ThinkingProviderOption represents thinking options for the Anthropic provider. type ThinkingProviderOption struct { BudgetTokens int64 `json:"budget_tokens"` @@ -27,6 +80,23 @@ type ReasoningOptionMetadata struct { // Options implements the ProviderOptions interface. func (*ReasoningOptionMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ReasoningOptionMetadata. +func (m ReasoningOptionMetadata) MarshalJSON() ([]byte, error) { + type plain ReasoningOptionMetadata + return fantasy.MarshalProviderType(TypeReasoningOptionMetadata, plain(m)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ReasoningOptionMetadata. +func (m *ReasoningOptionMetadata) UnmarshalJSON(data []byte) error { + type plain ReasoningOptionMetadata + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *m = ReasoningOptionMetadata(p) + return nil +} + // ProviderCacheControlOptions represents cache control options for the Anthropic provider. type ProviderCacheControlOptions struct { CacheControl CacheControl `json:"cache_control"` @@ -35,6 +105,23 @@ type ProviderCacheControlOptions struct { // Options implements the ProviderOptions interface. func (*ProviderCacheControlOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderCacheControlOptions. +func (o ProviderCacheControlOptions) MarshalJSON() ([]byte, error) { + type plain ProviderCacheControlOptions + return fantasy.MarshalProviderType(TypeProviderCacheControl, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderCacheControlOptions. +func (o *ProviderCacheControlOptions) UnmarshalJSON(data []byte) error { + type plain ProviderCacheControlOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ProviderCacheControlOptions(p) + return nil +} + // CacheControl represents cache control settings for the Anthropic provider. type CacheControl struct { Type string `json:"type"` diff --git a/providers/google/provider_options.go b/providers/google/provider_options.go index c86ecffa998abd29d781fe763d85df25815a8afa..7d45563c030e3d23a842d047969baeb81d70fcd3 100644 --- a/providers/google/provider_options.go +++ b/providers/google/provider_options.go @@ -1,7 +1,35 @@ // Package google provides an implementation of the fantasy AI SDK for Google's language models. package google -import "charm.land/fantasy" +import ( + "encoding/json" + + "charm.land/fantasy" +) + +// Global type identifiers for Google-specific provider data. +const ( + TypeProviderOptions = Name + ".options" + TypeReasoningMetadata = Name + ".reasoning_metadata" +) + +// Register Google provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeReasoningMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ReasoningMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} // ThinkingConfig represents thinking configuration for the Google provider. type ThinkingConfig struct { @@ -17,6 +45,23 @@ type ReasoningMetadata struct { // Options implements the ProviderOptionsData interface for ReasoningMetadata. func (m *ReasoningMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ReasoningMetadata. +func (m ReasoningMetadata) MarshalJSON() ([]byte, error) { + type plain ReasoningMetadata + return fantasy.MarshalProviderType(TypeReasoningMetadata, plain(m)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ReasoningMetadata. +func (m *ReasoningMetadata) UnmarshalJSON(data []byte) error { + type plain ReasoningMetadata + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *m = ReasoningMetadata(p) + return nil +} + // SafetySetting represents safety settings for the Google provider. type SafetySetting struct { // 'HARM_CATEGORY_UNSPECIFIED', @@ -59,6 +104,23 @@ type ProviderOptions struct { // Options implements the ProviderOptionsData interface for ProviderOptions. func (o *ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + return fantasy.MarshalProviderType(TypeProviderOptions, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ProviderOptions(p) + return nil +} + // ParseOptions parses provider options from a map for the Google provider. func ParseOptions(data map[string]any) (*ProviderOptions, error) { var options ProviderOptions diff --git a/providers/openai/provider_options.go b/providers/openai/provider_options.go index 9217c66277dc59da6ad53aecacf9efc976a8f052..adb02fbd135af77239092ba1b25e254329e0efcf 100644 --- a/providers/openai/provider_options.go +++ b/providers/openai/provider_options.go @@ -2,6 +2,8 @@ package openai import ( + "encoding/json" + "charm.land/fantasy" "github.com/openai/openai-go/v2" ) @@ -20,6 +22,38 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) +// Global type identifiers for OpenAI-specific provider data. +const ( + TypeProviderOptions = Name + ".options" + TypeProviderFileOptions = Name + ".file_options" + TypeProviderMetadata = Name + ".metadata" +) + +// Register OpenAI provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderFileOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderFileOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} + // ProviderMetadata represents additional metadata from OpenAI provider. type ProviderMetadata struct { Logprobs []openai.ChatCompletionTokenLogprob `json:"logprobs"` @@ -30,6 +64,23 @@ type ProviderMetadata struct { // Options implements the ProviderOptions interface. func (*ProviderMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderMetadata. +func (m ProviderMetadata) MarshalJSON() ([]byte, error) { + type plain ProviderMetadata + return fantasy.MarshalProviderType(TypeProviderMetadata, plain(m)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderMetadata. +func (m *ProviderMetadata) UnmarshalJSON(data []byte) error { + type plain ProviderMetadata + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *m = ProviderMetadata(p) + return nil +} + // ProviderOptions represents additional options for OpenAI provider. type ProviderOptions struct { LogitBias map[string]int64 `json:"logit_bias"` @@ -52,6 +103,23 @@ type ProviderOptions struct { // Options implements the ProviderOptions interface. func (*ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + return fantasy.MarshalProviderType(TypeProviderOptions, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ProviderOptions(p) + return nil +} + // ProviderFileOptions represents file options for OpenAI provider. type ProviderFileOptions struct { ImageDetail string `json:"image_detail"` @@ -60,6 +128,23 @@ type ProviderFileOptions struct { // Options implements the ProviderOptions interface. func (*ProviderFileOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderFileOptions. +func (o ProviderFileOptions) MarshalJSON() ([]byte, error) { + type plain ProviderFileOptions + return fantasy.MarshalProviderType(TypeProviderFileOptions, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderFileOptions. +func (o *ProviderFileOptions) UnmarshalJSON(data []byte) error { + type plain ProviderFileOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ProviderFileOptions(p) + return nil +} + // ReasoningEffortOption creates a pointer to a ReasoningEffort value. func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort { return &e diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index e81dcb20292e3de43bd1b9df0b885efe6fee4a73..88dba7f42ae5af4aa4304aff64ccb2bb15c86525 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -2,11 +2,36 @@ package openai import ( + "encoding/json" "slices" "charm.land/fantasy" ) +// Global type identifiers for OpenAI Responses API-specific data. +const ( + TypeResponsesProviderOptions = Name + ".responses.options" + TypeResponsesReasoningMetadata = Name + ".responses.reasoning_metadata" +) + +// Register OpenAI Responses API-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeResponsesProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ResponsesProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeResponsesReasoningMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ResponsesReasoningMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} + // ResponsesReasoningMetadata represents reasoning metadata for OpenAI Responses API. type ResponsesReasoningMetadata struct { ItemID string `json:"item_id"` @@ -17,6 +42,23 @@ type ResponsesReasoningMetadata struct { // Options implements the ProviderOptions interface. func (*ResponsesReasoningMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ResponsesReasoningMetadata. +func (m ResponsesReasoningMetadata) MarshalJSON() ([]byte, error) { + type plain ResponsesReasoningMetadata + return fantasy.MarshalProviderType(TypeResponsesReasoningMetadata, plain(m)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ResponsesReasoningMetadata. +func (m *ResponsesReasoningMetadata) UnmarshalJSON(data []byte) error { + type plain ResponsesReasoningMetadata + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *m = ResponsesReasoningMetadata(p) + return nil +} + // IncludeType represents the type of content to include for OpenAI Responses API. type IncludeType string @@ -71,6 +113,26 @@ type ResponsesProviderOptions struct { User *string `json:"user"` } +// Options implements the ProviderOptions interface. +func (*ResponsesProviderOptions) Options() {} + +// MarshalJSON implements custom JSON marshaling with type info for ResponsesProviderOptions. +func (o ResponsesProviderOptions) MarshalJSON() ([]byte, error) { + type plain ResponsesProviderOptions + return fantasy.MarshalProviderType(TypeResponsesProviderOptions, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ResponsesProviderOptions. +func (o *ResponsesProviderOptions) UnmarshalJSON(data []byte) error { + type plain ResponsesProviderOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ResponsesProviderOptions(p) + return nil +} + // responsesReasoningModelIds lists the model IDs that support reasoning for OpenAI Responses API. var responsesReasoningModelIDs = []string{ "o1", @@ -121,9 +183,6 @@ var responsesModelIDs = append([]string{ "gpt-5-chat-latest", }, responsesReasoningModelIDs...) -// Options implements the ProviderOptions interface. -func (*ResponsesProviderOptions) Options() {} - // NewResponsesProviderOptions creates new provider options for OpenAI Responses API. func NewResponsesProviderOptions(opts *ResponsesProviderOptions) fantasy.ProviderOptions { return fantasy.ProviderOptions{ diff --git a/providers/openaicompat/provider_options.go b/providers/openaicompat/provider_options.go index 89dfc61b9a7be1eccb512eb3c682131b8963d299..afb037bf21e51d8698e4b51bc6f85a9ff99f242b 100644 --- a/providers/openaicompat/provider_options.go +++ b/providers/openaicompat/provider_options.go @@ -2,10 +2,28 @@ package openaicompat import ( + "encoding/json" + "charm.land/fantasy" "charm.land/fantasy/providers/openai" ) +// Global type identifiers for OpenAI-compatible provider data. +const ( + TypeProviderOptions = Name + ".options" +) + +// Register OpenAI-compatible provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} + // ProviderOptions represents additional options for the OpenAI-compatible provider. type ProviderOptions struct { User *string `json:"user"` @@ -20,6 +38,23 @@ type ReasoningData struct { // Options implements the ProviderOptions interface. func (*ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + return fantasy.MarshalProviderType(TypeProviderOptions, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ProviderOptions(p) + return nil +} + // NewProviderOptions creates new provider options for the OpenAI-compatible provider. func NewProviderOptions(opts *ProviderOptions) fantasy.ProviderOptions { return fantasy.ProviderOptions{ diff --git a/providers/openrouter/provider_options.go b/providers/openrouter/provider_options.go index 6e8b513cdacb962956385a73ac4493b43fd1ca71..ed2d8f5edd714150b49301527b3dcb1e55184422 100644 --- a/providers/openrouter/provider_options.go +++ b/providers/openrouter/provider_options.go @@ -2,6 +2,8 @@ package openrouter import ( + "encoding/json" + "charm.land/fantasy" ) @@ -17,14 +19,38 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) +// Global type identifiers for OpenRouter-specific provider data. +const ( + TypeProviderOptions = Name + ".options" + TypeProviderMetadata = Name + ".metadata" +) + +// Register OpenRouter provider-specific types with the global registry. +func init() { + fantasy.RegisterProviderType(TypeProviderOptions, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderOptions + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) + fantasy.RegisterProviderType(TypeProviderMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v ProviderMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) +} + // PromptTokensDetails represents details about prompt tokens for OpenRouter. type PromptTokensDetails struct { - CachedTokens int64 + CachedTokens int64 `json:"cached_tokens"` } // CompletionTokensDetails represents details about completion tokens for OpenRouter. type CompletionTokensDetails struct { - ReasoningTokens int64 + ReasoningTokens int64 `json:"reasoning_tokens"` } // CostDetails represents cost details for OpenRouter. @@ -54,6 +80,23 @@ type ProviderMetadata struct { // Options implements the ProviderOptionsData interface for ProviderMetadata. func (*ProviderMetadata) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderMetadata. +func (m ProviderMetadata) MarshalJSON() ([]byte, error) { + type plain ProviderMetadata + return fantasy.MarshalProviderType(TypeProviderMetadata, plain(m)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderMetadata. +func (m *ProviderMetadata) UnmarshalJSON(data []byte) error { + type plain ProviderMetadata + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *m = ProviderMetadata(p) + return nil +} + // ReasoningOptions represents reasoning options for OpenRouter. type ReasoningOptions struct { // Whether reasoning is enabled @@ -110,6 +153,23 @@ type ProviderOptions struct { // Options implements the ProviderOptionsData interface for ProviderOptions. func (*ProviderOptions) Options() {} +// MarshalJSON implements custom JSON marshaling with type info for ProviderOptions. +func (o ProviderOptions) MarshalJSON() ([]byte, error) { + type plain ProviderOptions + return fantasy.MarshalProviderType(TypeProviderOptions, plain(o)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for ProviderOptions. +func (o *ProviderOptions) UnmarshalJSON(data []byte) error { + type plain ProviderOptions + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *o = ProviderOptions(p) + return nil +} + // ReasoningDetail represents a reasoning detail for OpenRouter. type ReasoningDetail struct { ID string `json:"id,omitempty"` diff --git a/providertests/provider_registry_test.go b/providertests/provider_registry_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e798e489430949357d05c28f1573e9ac89315ea9 --- /dev/null +++ b/providertests/provider_registry_test.go @@ -0,0 +1,421 @@ +package providertests + +import ( + "encoding/json" + "testing" + + "charm.land/fantasy" + "charm.land/fantasy/providers/anthropic" + "charm.land/fantasy/providers/google" + "charm.land/fantasy/providers/openai" + "charm.land/fantasy/providers/openaicompat" + "charm.land/fantasy/providers/openrouter" + "github.com/stretchr/testify/require" +) + +func TestProviderRegistry_Serialization_OpenAIOptions(t *testing.T) { + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hi"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + openai.Name: &openai.ProviderOptions{User: fantasy.Opt("tester")}, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var raw struct { + ProviderOptions map[string]map[string]any `json:"provider_options"` + } + require.NoError(t, json.Unmarshal(data, &raw)) + + po, ok := raw.ProviderOptions[openai.Name] + require.True(t, ok) + require.Equal(t, openai.TypeProviderOptions, po["type"]) // no magic strings + // ensure inner data has the field we set + inner, ok := po["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, "tester", inner["user"]) + + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + got, ok := decoded.ProviderOptions[openai.Name] + require.True(t, ok) + opt, ok := got.(*openai.ProviderOptions) + require.True(t, ok) + require.NotNil(t, opt.User) + require.Equal(t, "tester", *opt.User) +} + +func TestProviderRegistry_Serialization_OpenAIResponses(t *testing.T) { + // Use ResponsesProviderOptions in provider options + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hello"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + openai.Name: &openai.ResponsesProviderOptions{ + PromptCacheKey: fantasy.Opt("cache-key-1"), + ParallelToolCalls: fantasy.Opt(true), + }, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + // JSON should include the typed wrapper with constant TypeResponsesProviderOptions + var raw struct { + ProviderOptions map[string]map[string]any `json:"provider_options"` + } + require.NoError(t, json.Unmarshal(data, &raw)) + + po := raw.ProviderOptions[openai.Name] + require.Equal(t, openai.TypeResponsesProviderOptions, po["type"]) // no magic strings + inner, ok := po["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, "cache-key-1", inner["prompt_cache_key"]) + require.Equal(t, true, inner["parallel_tool_calls"]) + + // Unmarshal back and assert concrete type + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + got := decoded.ProviderOptions[openai.Name] + reqOpts, ok := got.(*openai.ResponsesProviderOptions) + require.True(t, ok) + require.NotNil(t, reqOpts.PromptCacheKey) + require.Equal(t, "cache-key-1", *reqOpts.PromptCacheKey) + require.NotNil(t, reqOpts.ParallelToolCalls) + require.Equal(t, true, *reqOpts.ParallelToolCalls) +} + +func TestProviderRegistry_Serialization_OpenAIResponsesReasoningMetadata(t *testing.T) { + resp := fantasy.Response{ + Content: []fantasy.Content{ + fantasy.TextContent{ + Text: "", + ProviderMetadata: fantasy.ProviderMetadata{ + openai.Name: &openai.ResponsesReasoningMetadata{ + ItemID: "item-123", + Summary: []string{"part1", "part2"}, + }, + }, + }, + }, + } + + data, err := json.Marshal(resp) + require.NoError(t, err) + + // Ensure the provider metadata is wrapped with type using constant + var raw struct { + Content []struct { + Type string `json:"type"` + Data map[string]any `json:"data"` + } `json:"content"` + } + require.NoError(t, json.Unmarshal(data, &raw)) + require.Greater(t, len(raw.Content), 0) + tc := raw.Content[0] + pm, ok := tc.Data["provider_metadata"].(map[string]any) + require.True(t, ok) + om, ok := pm[openai.Name].(map[string]any) + require.True(t, ok) + require.Equal(t, openai.TypeResponsesReasoningMetadata, om["type"]) // no magic strings + inner, ok := om["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, "item-123", inner["item_id"]) + + // Unmarshal back + var decoded fantasy.Response + require.NoError(t, json.Unmarshal(data, &decoded)) + pmDecoded := decoded.Content[0].(fantasy.TextContent).ProviderMetadata + val, ok := pmDecoded[openai.Name] + require.True(t, ok) + meta, ok := val.(*openai.ResponsesReasoningMetadata) + require.True(t, ok) + require.Equal(t, "item-123", meta.ItemID) + require.Equal(t, []string{"part1", "part2"}, meta.Summary) +} + +func TestProviderRegistry_Serialization_AnthropicOptions(t *testing.T) { + sendReasoning := true + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "test message"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + anthropic.Name: &anthropic.ProviderOptions{ + SendReasoning: &sendReasoning, + }, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + got, ok := decoded.ProviderOptions[anthropic.Name] + require.True(t, ok) + opt, ok := got.(*anthropic.ProviderOptions) + require.True(t, ok) + require.NotNil(t, opt.SendReasoning) + require.Equal(t, true, *opt.SendReasoning) +} + +func TestProviderRegistry_Serialization_GoogleOptions(t *testing.T) { + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "test message"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + google.Name: &google.ProviderOptions{ + CachedContent: "cached-123", + Threshold: "BLOCK_ONLY_HIGH", + }, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + got, ok := decoded.ProviderOptions[google.Name] + require.True(t, ok) + opt, ok := got.(*google.ProviderOptions) + require.True(t, ok) + require.Equal(t, "cached-123", opt.CachedContent) + require.Equal(t, "BLOCK_ONLY_HIGH", opt.Threshold) +} + +func TestProviderRegistry_Serialization_OpenRouterOptions(t *testing.T) { + includeUsage := true + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "test message"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + openrouter.Name: &openrouter.ProviderOptions{ + IncludeUsage: &includeUsage, + User: fantasy.Opt("test-user"), + }, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + got, ok := decoded.ProviderOptions[openrouter.Name] + require.True(t, ok) + opt, ok := got.(*openrouter.ProviderOptions) + require.True(t, ok) + require.NotNil(t, opt.IncludeUsage) + require.Equal(t, true, *opt.IncludeUsage) + require.NotNil(t, opt.User) + require.Equal(t, "test-user", *opt.User) +} + +func TestProviderRegistry_Serialization_OpenAICompatOptions(t *testing.T) { + effort := openai.ReasoningEffortHigh + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "test message"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + openaicompat.Name: &openaicompat.ProviderOptions{ + User: fantasy.Opt("test-user"), + ReasoningEffort: &effort, + }, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + got, ok := decoded.ProviderOptions[openaicompat.Name] + require.True(t, ok) + opt, ok := got.(*openaicompat.ProviderOptions) + require.True(t, ok) + require.NotNil(t, opt.User) + require.Equal(t, "test-user", *opt.User) + require.NotNil(t, opt.ReasoningEffort) + require.Equal(t, openai.ReasoningEffortHigh, *opt.ReasoningEffort) +} + +func TestProviderRegistry_MultiProvider(t *testing.T) { + // Test with multiple providers in one message + sendReasoning := true + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "test"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + openai.Name: &openai.ProviderOptions{User: fantasy.Opt("user1")}, + anthropic.Name: &anthropic.ProviderOptions{ + SendReasoning: &sendReasoning, + }, + }, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded fantasy.Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + // Check OpenAI options + openaiOpt, ok := decoded.ProviderOptions[openai.Name] + require.True(t, ok) + openaiData, ok := openaiOpt.(*openai.ProviderOptions) + require.True(t, ok) + require.Equal(t, "user1", *openaiData.User) + + // Check Anthropic options + anthropicOpt, ok := decoded.ProviderOptions[anthropic.Name] + require.True(t, ok) + anthropicData, ok := anthropicOpt.(*anthropic.ProviderOptions) + require.True(t, ok) + require.Equal(t, true, *anthropicData.SendReasoning) +} + +func TestProviderRegistry_ErrorHandling(t *testing.T) { + t.Run("unknown provider type", func(t *testing.T) { + invalidJSON := `{ + "role": "user", + "content": [{"type": "text", "data": {"text": "hi"}}], + "provider_options": { + "unknown": { + "type": "unknown.provider.type", + "data": {} + } + } + }` + + var msg fantasy.Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown provider data type") + }) + + t.Run("malformed provider data", func(t *testing.T) { + invalidJSON := `{ + "role": "user", + "content": [{"type": "text", "data": {"text": "hi"}}], + "provider_options": { + "openai": "not-an-object" + } + }` + + var msg fantasy.Message + err := json.Unmarshal([]byte(invalidJSON), &msg) + require.Error(t, err) + }) +} + +func TestProviderRegistry_AllTypesRegistered(t *testing.T) { + // Verify all expected provider types are registered + // We test that unmarshaling with proper type IDs doesn't fail with "unknown provider data type" + tests := []struct { + name string + providerName string + data fantasy.ProviderOptionsData + }{ + {"OpenAI Options", openai.Name, &openai.ProviderOptions{}}, + {"OpenAI File Options", openai.Name, &openai.ProviderFileOptions{}}, + {"OpenAI Metadata", openai.Name, &openai.ProviderMetadata{}}, + {"OpenAI Responses Options", openai.Name, &openai.ResponsesProviderOptions{}}, + {"Anthropic Options", anthropic.Name, &anthropic.ProviderOptions{}}, + {"Google Options", google.Name, &google.ProviderOptions{}}, + {"OpenRouter Options", openrouter.Name, &openrouter.ProviderOptions{}}, + {"OpenAICompat Options", openaicompat.Name, &openaicompat.ProviderOptions{}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create a message with the provider options + msg := fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "test"}, + }, + ProviderOptions: fantasy.ProviderOptions{ + tc.providerName: tc.data, + }, + } + + // Marshal and unmarshal + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded fantasy.Message + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + // Verify the provider options exist + _, ok := decoded.ProviderOptions[tc.providerName] + require.True(t, ok, "Provider options should be present after round-trip") + }) + } + + // Test metadata types separately as they go in different field + metadataTests := []struct { + name string + providerName string + data fantasy.ProviderOptionsData + }{ + {"OpenAI Responses Reasoning Metadata", openai.Name, &openai.ResponsesReasoningMetadata{}}, + {"Anthropic Reasoning Metadata", anthropic.Name, &anthropic.ReasoningOptionMetadata{}}, + {"Google Reasoning Metadata", google.Name, &google.ReasoningMetadata{}}, + {"OpenRouter Metadata", openrouter.Name, &openrouter.ProviderMetadata{}}, + } + + for _, tc := range metadataTests { + t.Run(tc.name, func(t *testing.T) { + // Create a response with provider metadata + resp := fantasy.Response{ + Content: []fantasy.Content{ + fantasy.TextContent{ + Text: "test", + ProviderMetadata: fantasy.ProviderMetadata{ + tc.providerName: tc.data, + }, + }, + }, + } + + // Marshal and unmarshal + data, err := json.Marshal(resp) + require.NoError(t, err) + + var decoded fantasy.Response + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + // Verify the provider metadata exists + textContent, ok := decoded.Content[0].(fantasy.TextContent) + require.True(t, ok) + _, ok = textContent.ProviderMetadata[tc.providerName] + require.True(t, ok, "Provider metadata should be present after round-trip") + }) + } +} diff --git a/tool.go b/tool.go index 9dbff68584b0f7bebb64d58df6929d42ff40801c..9731739ceaf8fcf5dc666f3f0e11a8ab22d5013f 100644 --- a/tool.go +++ b/tool.go @@ -14,10 +14,10 @@ type Schema = schema.Schema // ToolInfo represents tool metadata, matching the existing pattern. type ToolInfo struct { - Name string - Description string - Parameters map[string]any - Required []string + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]any `json:"parameters"` + Required []string `json:"required"` } // ToolCall represents a tool invocation, matching the existing pattern.