model_json.go

  1package fantasy
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6)
  7
  8// UnmarshalJSON implements json.Unmarshaler for Call.
  9func (c *Call) UnmarshalJSON(data []byte) error {
 10	var aux struct {
 11		Prompt           Prompt                     `json:"prompt"`
 12		MaxOutputTokens  *int64                     `json:"max_output_tokens"`
 13		Temperature      *float64                   `json:"temperature"`
 14		TopP             *float64                   `json:"top_p"`
 15		TopK             *int64                     `json:"top_k"`
 16		PresencePenalty  *float64                   `json:"presence_penalty"`
 17		FrequencyPenalty *float64                   `json:"frequency_penalty"`
 18		Tools            []json.RawMessage          `json:"tools"`
 19		ToolChoice       *ToolChoice                `json:"tool_choice"`
 20		ProviderOptions  map[string]json.RawMessage `json:"provider_options"`
 21	}
 22
 23	if err := json.Unmarshal(data, &aux); err != nil {
 24		return err
 25	}
 26
 27	c.Prompt = aux.Prompt
 28	c.MaxOutputTokens = aux.MaxOutputTokens
 29	c.Temperature = aux.Temperature
 30	c.TopP = aux.TopP
 31	c.TopK = aux.TopK
 32	c.PresencePenalty = aux.PresencePenalty
 33	c.FrequencyPenalty = aux.FrequencyPenalty
 34	c.ToolChoice = aux.ToolChoice
 35
 36	// Unmarshal Tools slice
 37	c.Tools = make([]Tool, len(aux.Tools))
 38	for i, rawTool := range aux.Tools {
 39		tool, err := UnmarshalTool(rawTool)
 40		if err != nil {
 41			return fmt.Errorf("failed to unmarshal tool at index %d: %w", i, err)
 42		}
 43		c.Tools[i] = tool
 44	}
 45
 46	// Unmarshal ProviderOptions
 47	if len(aux.ProviderOptions) > 0 {
 48		options, err := UnmarshalProviderOptions(aux.ProviderOptions)
 49		if err != nil {
 50			return err
 51		}
 52		c.ProviderOptions = options
 53	}
 54
 55	return nil
 56}
 57
 58// UnmarshalJSON implements json.Unmarshaler for Response.
 59func (r *Response) UnmarshalJSON(data []byte) error {
 60	var aux struct {
 61		Content          json.RawMessage            `json:"content"`
 62		FinishReason     FinishReason               `json:"finish_reason"`
 63		Usage            Usage                      `json:"usage"`
 64		Warnings         []CallWarning              `json:"warnings"`
 65		ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
 66	}
 67
 68	if err := json.Unmarshal(data, &aux); err != nil {
 69		return err
 70	}
 71
 72	r.FinishReason = aux.FinishReason
 73	r.Usage = aux.Usage
 74	r.Warnings = aux.Warnings
 75
 76	// Unmarshal ResponseContent (need to know the type definition)
 77	// If ResponseContent is []Content:
 78	var rawContent []json.RawMessage
 79	if err := json.Unmarshal(aux.Content, &rawContent); err != nil {
 80		return err
 81	}
 82
 83	content := make([]Content, len(rawContent))
 84	for i, rawItem := range rawContent {
 85		item, err := UnmarshalContent(rawItem)
 86		if err != nil {
 87			return fmt.Errorf("failed to unmarshal content at index %d: %w", i, err)
 88		}
 89		content[i] = item
 90	}
 91	r.Content = content
 92
 93	// Unmarshal ProviderMetadata
 94	if len(aux.ProviderMetadata) > 0 {
 95		metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
 96		if err != nil {
 97			return err
 98		}
 99		r.ProviderMetadata = metadata
100	}
101
102	return nil
103}
104
105// MarshalJSON implements json.Marshaler for StreamPart.
106func (s StreamPart) MarshalJSON() ([]byte, error) {
107	type alias StreamPart
108	aux := struct {
109		alias
110		Error string `json:"error,omitempty"`
111	}{
112		alias: (alias)(s),
113	}
114
115	// Marshal error to string
116	if s.Error != nil {
117		aux.Error = s.Error.Error()
118	}
119
120	// Clear the original Error field to avoid duplicate marshaling
121	aux.alias.Error = nil
122
123	return json.Marshal(aux)
124}
125
126// UnmarshalJSON implements json.Unmarshaler for StreamPart.
127func (s *StreamPart) UnmarshalJSON(data []byte) error {
128	type alias StreamPart
129	aux := struct {
130		*alias
131		Error            string                     `json:"error"`
132		ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
133	}{
134		alias: (*alias)(s),
135	}
136
137	if err := json.Unmarshal(data, &aux); err != nil {
138		return err
139	}
140
141	// Unmarshal error string back to error type
142	if aux.Error != "" {
143		s.Error = fmt.Errorf("%s", aux.Error)
144	}
145
146	// Unmarshal ProviderMetadata
147	if len(aux.ProviderMetadata) > 0 {
148		metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
149		if err != nil {
150			return err
151		}
152		s.ProviderMetadata = metadata
153	}
154
155	return nil
156}