model_json.go

  1package fantasy
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6)
  7
  8// UnmarshalJSON implements json.Unmarshaler for Call.
  9func (c *Call) UnmarshalJSON(data []byte) error {
 10	var aux struct {
 11		Prompt           Prompt                     `json:"prompt"`
 12		MaxOutputTokens  *int64                     `json:"max_output_tokens"`
 13		Temperature      *float64                   `json:"temperature"`
 14		TopP             *float64                   `json:"top_p"`
 15		TopK             *int64                     `json:"top_k"`
 16		PresencePenalty  *float64                   `json:"presence_penalty"`
 17		FrequencyPenalty *float64                   `json:"frequency_penalty"`
 18		Tools            []json.RawMessage          `json:"tools"`
 19		ToolChoice       *ToolChoice                `json:"tool_choice"`
 20		ProviderOptions  map[string]json.RawMessage `json:"provider_options"`
 21	}
 22
 23	if err := json.Unmarshal(data, &aux); err != nil {
 24		return err
 25	}
 26
 27	c.Prompt = aux.Prompt
 28	c.MaxOutputTokens = aux.MaxOutputTokens
 29	c.Temperature = aux.Temperature
 30	c.TopP = aux.TopP
 31	c.TopK = aux.TopK
 32	c.PresencePenalty = aux.PresencePenalty
 33	c.FrequencyPenalty = aux.FrequencyPenalty
 34	c.ToolChoice = aux.ToolChoice
 35
 36	// Unmarshal Tools slice
 37	c.Tools = make([]Tool, len(aux.Tools))
 38	for i, rawTool := range aux.Tools {
 39		tool, err := UnmarshalTool(rawTool)
 40		if err != nil {
 41			return fmt.Errorf("failed to unmarshal tool at index %d: %w", i, err)
 42		}
 43		c.Tools[i] = tool
 44	}
 45
 46	// Unmarshal ProviderOptions
 47	if len(aux.ProviderOptions) > 0 {
 48		options, err := UnmarshalProviderOptions(aux.ProviderOptions)
 49		if err != nil {
 50			return err
 51		}
 52		c.ProviderOptions = options
 53	}
 54
 55	return nil
 56}
 57
 58// UnmarshalJSON implements json.Unmarshaler for Response.
 59func (r *Response) UnmarshalJSON(data []byte) error {
 60	var aux struct {
 61		Content          json.RawMessage            `json:"content"`
 62		FinishReason     FinishReason               `json:"finish_reason"`
 63		Usage            Usage                      `json:"usage"`
 64		Warnings         []CallWarning              `json:"warnings"`
 65		ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
 66	}
 67
 68	if err := json.Unmarshal(data, &aux); err != nil {
 69		return err
 70	}
 71
 72	r.FinishReason = aux.FinishReason
 73	r.Usage = aux.Usage
 74	r.Warnings = aux.Warnings
 75
 76	// Unmarshal ResponseContent (need to know the type definition)
 77	// If ResponseContent is []Content:
 78	var rawContent []json.RawMessage
 79	if err := json.Unmarshal(aux.Content, &rawContent); err != nil {
 80		return err
 81	}
 82
 83	content := make([]Content, len(rawContent))
 84	for i, rawItem := range rawContent {
 85		item, err := UnmarshalContent(rawItem)
 86		if err != nil {
 87			return fmt.Errorf("failed to unmarshal content at index %d: %w", i, err)
 88		}
 89		content[i] = item
 90	}
 91	r.Content = content
 92
 93	// Unmarshal ProviderMetadata
 94	if len(aux.ProviderMetadata) > 0 {
 95		metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
 96		if err != nil {
 97			return err
 98		}
 99		r.ProviderMetadata = metadata
100	}
101
102	return nil
103}
104
105// UnmarshalJSON implements json.Unmarshaler for StreamPart.
106func (s *StreamPart) UnmarshalJSON(data []byte) error {
107	var aux struct {
108		Type             StreamPartType             `json:"type"`
109		ID               string                     `json:"id"`
110		ToolCallName     string                     `json:"tool_call_name"`
111		ToolCallInput    string                     `json:"tool_call_input"`
112		Delta            string                     `json:"delta"`
113		ProviderExecuted bool                       `json:"provider_executed"`
114		Usage            Usage                      `json:"usage"`
115		FinishReason     FinishReason               `json:"finish_reason"`
116		Error            error                      `json:"error"`
117		Warnings         []CallWarning              `json:"warnings"`
118		SourceType       SourceType                 `json:"source_type"`
119		URL              string                     `json:"url"`
120		Title            string                     `json:"title"`
121		ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
122	}
123
124	if err := json.Unmarshal(data, &aux); err != nil {
125		return err
126	}
127
128	s.Type = aux.Type
129	s.ID = aux.ID
130	s.ToolCallName = aux.ToolCallName
131	s.ToolCallInput = aux.ToolCallInput
132	s.Delta = aux.Delta
133	s.ProviderExecuted = aux.ProviderExecuted
134	s.Usage = aux.Usage
135	s.FinishReason = aux.FinishReason
136	s.Error = aux.Error
137	s.Warnings = aux.Warnings
138	s.SourceType = aux.SourceType
139	s.URL = aux.URL
140	s.Title = aux.Title
141
142	// Unmarshal ProviderMetadata
143	if len(aux.ProviderMetadata) > 0 {
144		metadata, err := UnmarshalProviderMetadata(aux.ProviderMetadata)
145		if err != nil {
146			return err
147		}
148		s.ProviderMetadata = metadata
149	}
150
151	return nil
152}