1package fantasy
2
import (
	"encoding/json"
	"errors"
	"fmt"
)
7
8// UnmarshalJSON implements json.Unmarshaler for Call.
9func (c *Call) UnmarshalJSON(data []byte) error {
10 var aux struct {
11 Prompt Prompt `json:"prompt"`
12 MaxOutputTokens *int64 `json:"max_output_tokens"`
13 Temperature *float64 `json:"temperature"`
14 TopP *float64 `json:"top_p"`
15 TopK *int64 `json:"top_k"`
16 PresencePenalty *float64 `json:"presence_penalty"`
17 FrequencyPenalty *float64 `json:"frequency_penalty"`
18 Tools []json.RawMessage `json:"tools"`
19 ToolChoice *ToolChoice `json:"tool_choice"`
20 ProviderOptions map[string]json.RawMessage `json:"provider_options"`
21 }
22
23 if err := json.Unmarshal(data, &aux); err != nil {
24 return err
25 }
26
27 c.Prompt = aux.Prompt
28 c.MaxOutputTokens = aux.MaxOutputTokens
29 c.Temperature = aux.Temperature
30 c.TopP = aux.TopP
31 c.TopK = aux.TopK
32 c.PresencePenalty = aux.PresencePenalty
33 c.FrequencyPenalty = aux.FrequencyPenalty
34 c.ToolChoice = aux.ToolChoice
35
36 // Unmarshal Tools slice
37 c.Tools = make([]Tool, len(aux.Tools))
38 for i, rawTool := range aux.Tools {
39 tool, err := UnmarshalTool(rawTool)
40 if err != nil {
41 return fmt.Errorf("failed to unmarshal tool at index %d: %w", i, err)
42 }
43 c.Tools[i] = tool
44 }
45
46 // Unmarshal ProviderOptions
47 if len(aux.ProviderOptions) > 0 {
48 options, err := UnmarshalProviderOptions(aux.ProviderOptions)
49 if err != nil {
50 return err
51 }
52 c.ProviderOptions = options
53 }
54
55 return nil
56}
57
58// UnmarshalJSON implements json.Unmarshaler for Response.
59func (r *Response) UnmarshalJSON(data []byte) error {
60 var aux struct {
61 Content json.RawMessage `json:"content"`
62 FinishReason FinishReason `json:"finish_reason"`
63 Usage Usage `json:"usage"`
64 Warnings []CallWarning `json:"warnings"`
65 ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
66 }
67
68 if err := json.Unmarshal(data, &aux); err != nil {
69 return err
70 }
71
72 r.FinishReason = aux.FinishReason
73 r.Usage = aux.Usage
74 r.Warnings = aux.Warnings
75
76 // Unmarshal ResponseContent (need to know the type definition)
77 // If ResponseContent is []Content:
78 var rawContent []json.RawMessage
79 if err := json.Unmarshal(aux.Content, &rawContent); err != nil {
80 return err
81 }
82
83 content := make([]Content, len(rawContent))
84 for i, rawItem := range rawContent {
85 item, err := UnmarshalContent(rawItem)
86 if err != nil {
87 return fmt.Errorf("failed to unmarshal content at index %d: %w", i, err)
88 }
89 content[i] = item
90 }
91 r.Content = content
92
93 // Unmarshal ProviderMetadata
94 if len(aux.ProviderMetadata) > 0 {
95 metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
96 if err != nil {
97 return err
98 }
99 r.ProviderMetadata = metadata
100 }
101
102 return nil
103}
104
105// UnmarshalJSON implements json.Unmarshaler for StreamPart.
106func (s *StreamPart) UnmarshalJSON(data []byte) error {
107 var aux struct {
108 Type StreamPartType `json:"type"`
109 ID string `json:"id"`
110 ToolCallName string `json:"tool_call_name"`
111 ToolCallInput string `json:"tool_call_input"`
112 Delta string `json:"delta"`
113 ProviderExecuted bool `json:"provider_executed"`
114 Usage Usage `json:"usage"`
115 FinishReason FinishReason `json:"finish_reason"`
116 Error error `json:"error"`
117 Warnings []CallWarning `json:"warnings"`
118 SourceType SourceType `json:"source_type"`
119 URL string `json:"url"`
120 Title string `json:"title"`
121 ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
122 }
123
124 if err := json.Unmarshal(data, &aux); err != nil {
125 return err
126 }
127
128 s.Type = aux.Type
129 s.ID = aux.ID
130 s.ToolCallName = aux.ToolCallName
131 s.ToolCallInput = aux.ToolCallInput
132 s.Delta = aux.Delta
133 s.ProviderExecuted = aux.ProviderExecuted
134 s.Usage = aux.Usage
135 s.FinishReason = aux.FinishReason
136 s.Error = aux.Error
137 s.Warnings = aux.Warnings
138 s.SourceType = aux.SourceType
139 s.URL = aux.URL
140 s.Title = aux.Title
141
142 // Unmarshal ProviderMetadata
143 if len(aux.ProviderMetadata) > 0 {
144 metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
145 if err != nil {
146 return err
147 }
148 s.ProviderMetadata = metadata
149 }
150
151 return nil
152}