1package fantasy
2
3import (
4 "encoding/json"
5 "fmt"
6)
7
8// UnmarshalJSON implements json.Unmarshaler for Call.
9func (c *Call) UnmarshalJSON(data []byte) error {
10 var aux struct {
11 Prompt Prompt `json:"prompt"`
12 MaxOutputTokens *int64 `json:"max_output_tokens"`
13 Temperature *float64 `json:"temperature"`
14 TopP *float64 `json:"top_p"`
15 TopK *int64 `json:"top_k"`
16 PresencePenalty *float64 `json:"presence_penalty"`
17 FrequencyPenalty *float64 `json:"frequency_penalty"`
18 Tools []json.RawMessage `json:"tools"`
19 ToolChoice *ToolChoice `json:"tool_choice"`
20 ProviderOptions map[string]json.RawMessage `json:"provider_options"`
21 }
22
23 if err := json.Unmarshal(data, &aux); err != nil {
24 return err
25 }
26
27 c.Prompt = aux.Prompt
28 c.MaxOutputTokens = aux.MaxOutputTokens
29 c.Temperature = aux.Temperature
30 c.TopP = aux.TopP
31 c.TopK = aux.TopK
32 c.PresencePenalty = aux.PresencePenalty
33 c.FrequencyPenalty = aux.FrequencyPenalty
34 c.ToolChoice = aux.ToolChoice
35
36 // Unmarshal Tools slice
37 c.Tools = make([]Tool, len(aux.Tools))
38 for i, rawTool := range aux.Tools {
39 tool, err := UnmarshalTool(rawTool)
40 if err != nil {
41 return fmt.Errorf("failed to unmarshal tool at index %d: %w", i, err)
42 }
43 c.Tools[i] = tool
44 }
45
46 // Unmarshal ProviderOptions
47 if len(aux.ProviderOptions) > 0 {
48 options, err := UnmarshalProviderOptions(aux.ProviderOptions)
49 if err != nil {
50 return err
51 }
52 c.ProviderOptions = options
53 }
54
55 return nil
56}
57
58// UnmarshalJSON implements json.Unmarshaler for Response.
59func (r *Response) UnmarshalJSON(data []byte) error {
60 var aux struct {
61 Content json.RawMessage `json:"content"`
62 FinishReason FinishReason `json:"finish_reason"`
63 Usage Usage `json:"usage"`
64 Warnings []CallWarning `json:"warnings"`
65 ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
66 }
67
68 if err := json.Unmarshal(data, &aux); err != nil {
69 return err
70 }
71
72 r.FinishReason = aux.FinishReason
73 r.Usage = aux.Usage
74 r.Warnings = aux.Warnings
75
76 // Unmarshal ResponseContent (need to know the type definition)
77 // If ResponseContent is []Content:
78 var rawContent []json.RawMessage
79 if err := json.Unmarshal(aux.Content, &rawContent); err != nil {
80 return err
81 }
82
83 content := make([]Content, len(rawContent))
84 for i, rawItem := range rawContent {
85 item, err := UnmarshalContent(rawItem)
86 if err != nil {
87 return fmt.Errorf("failed to unmarshal content at index %d: %w", i, err)
88 }
89 content[i] = item
90 }
91 r.Content = content
92
93 // Unmarshal ProviderMetadata
94 if len(aux.ProviderMetadata) > 0 {
95 metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
96 if err != nil {
97 return err
98 }
99 r.ProviderMetadata = metadata
100 }
101
102 return nil
103}
104
105// MarshalJSON implements json.Marshaler for StreamPart.
106func (s StreamPart) MarshalJSON() ([]byte, error) {
107 type alias StreamPart
108 aux := struct {
109 alias
110 Error string `json:"error,omitempty"`
111 }{
112 alias: (alias)(s),
113 }
114
115 // Marshal error to string
116 if s.Error != nil {
117 aux.Error = s.Error.Error()
118 }
119
120 // Clear the original Error field to avoid duplicate marshaling
121 aux.alias.Error = nil
122
123 return json.Marshal(aux)
124}
125
126// UnmarshalJSON implements json.Unmarshaler for StreamPart.
127func (s *StreamPart) UnmarshalJSON(data []byte) error {
128 type alias StreamPart
129 aux := struct {
130 *alias
131 Error string `json:"error"`
132 ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
133 }{
134 alias: (*alias)(s),
135 }
136
137 if err := json.Unmarshal(data, &aux); err != nil {
138 return err
139 }
140
141 // Unmarshal error string back to error type
142 if aux.Error != "" {
143 s.Error = fmt.Errorf("%s", aux.Error)
144 }
145
146 // Unmarshal ProviderMetadata
147 if len(aux.ProviderMetadata) > 0 {
148 metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
149 if err != nil {
150 return err
151 }
152 s.ProviderMetadata = metadata
153 }
154
155 return nil
156}