model.go

  1package fantasy
  2
import (
	"context"
	"fmt"
	"iter"
	"strings"
)
  8
  9type Usage struct {
 10	InputTokens         int64 `json:"input_tokens"`
 11	OutputTokens        int64 `json:"output_tokens"`
 12	TotalTokens         int64 `json:"total_tokens"`
 13	ReasoningTokens     int64 `json:"reasoning_tokens"`
 14	CacheCreationTokens int64 `json:"cache_creation_tokens"`
 15	CacheReadTokens     int64 `json:"cache_read_tokens"`
 16}
 17
 18func (u Usage) String() string {
 19	return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
 20		u.InputTokens,
 21		u.OutputTokens,
 22		u.TotalTokens,
 23		u.ReasoningTokens,
 24		u.CacheCreationTokens,
 25		u.CacheReadTokens,
 26	)
 27}
 28
 29type ResponseContent []Content
 30
 31func (r ResponseContent) Text() string {
 32	for _, c := range r {
 33		if c.GetType() == ContentTypeText {
 34			return c.(TextContent).Text
 35		}
 36	}
 37	return ""
 38}
 39
 40// Reasoning returns all reasoning content parts.
 41func (r ResponseContent) Reasoning() []ReasoningContent {
 42	var reasoning []ReasoningContent
 43	for _, c := range r {
 44		if c.GetType() == ContentTypeReasoning {
 45			if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
 46				reasoning = append(reasoning, reasoningContent)
 47			}
 48		}
 49	}
 50	return reasoning
 51}
 52
 53// ReasoningText returns all reasoning content as a concatenated string.
 54func (r ResponseContent) ReasoningText() string {
 55	var text string
 56	for _, reasoning := range r.Reasoning() {
 57		text += reasoning.Text
 58	}
 59	return text
 60}
 61
 62// Files returns all file content parts.
 63func (r ResponseContent) Files() []FileContent {
 64	var files []FileContent
 65	for _, c := range r {
 66		if c.GetType() == ContentTypeFile {
 67			if fileContent, ok := AsContentType[FileContent](c); ok {
 68				files = append(files, fileContent)
 69			}
 70		}
 71	}
 72	return files
 73}
 74
 75// Sources returns all source content parts.
 76func (r ResponseContent) Sources() []SourceContent {
 77	var sources []SourceContent
 78	for _, c := range r {
 79		if c.GetType() == ContentTypeSource {
 80			if sourceContent, ok := AsContentType[SourceContent](c); ok {
 81				sources = append(sources, sourceContent)
 82			}
 83		}
 84	}
 85	return sources
 86}
 87
 88// ToolCalls returns all tool call content parts.
 89func (r ResponseContent) ToolCalls() []ToolCallContent {
 90	var toolCalls []ToolCallContent
 91	for _, c := range r {
 92		if c.GetType() == ContentTypeToolCall {
 93			if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
 94				toolCalls = append(toolCalls, toolCallContent)
 95			}
 96		}
 97	}
 98	return toolCalls
 99}
100
101// ToolResults returns all tool result content parts.
102func (r ResponseContent) ToolResults() []ToolResultContent {
103	var toolResults []ToolResultContent
104	for _, c := range r {
105		if c.GetType() == ContentTypeToolResult {
106			if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
107				toolResults = append(toolResults, toolResultContent)
108			}
109		}
110	}
111	return toolResults
112}
113
// Response is the complete result returned by LanguageModel.Generate.
type Response struct {
	// Content holds the response parts; use the ResponseContent accessors
	// (Text, ToolCalls, Files, ...) to pick out parts of a given type.
	Content ResponseContent `json:"content"`
	// FinishReason is the finish reason reported for this call.
	FinishReason FinishReason `json:"finish_reason"`
	// Usage holds the token counts for this call.
	Usage Usage `json:"usage"`
	// Warnings lists non-fatal issues reported for this call (see CallWarning).
	Warnings []CallWarning `json:"warnings"`

	// ProviderMetadata holds provider-specific response metadata, keyed by
	// provider id.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
123
// StreamPartType identifies the kind of event carried by a StreamPart.
type StreamPartType string

// Stream part types. Judging by the names, text, reasoning and tool input
// presumably stream as a *_start part, zero or more *_delta parts and a *_end
// part correlated by StreamPart.ID — confirm against the provider
// implementations.
const (
	StreamPartTypeWarnings  StreamPartType = "warnings"
	StreamPartTypeTextStart StreamPartType = "text_start"
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	StreamPartTypeTextEnd   StreamPartType = "text_end"

	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	StreamPartTypeReasoningEnd   StreamPartType = "reasoning_end"
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	StreamPartTypeToolInputEnd   StreamPartType = "tool_input_end"
	StreamPartTypeToolCall       StreamPartType = "tool_call"
	StreamPartTypeToolResult     StreamPartType = "tool_result"
	StreamPartTypeSource         StreamPartType = "source"
	StreamPartTypeFinish         StreamPartType = "finish"
	StreamPartTypeError          StreamPartType = "error"
)
144
// StreamPart is one event yielded by a StreamResponse. Type identifies the
// event. The remaining fields are presumably populated only for the part
// types they relate to (e.g. Delta for *_delta parts, Usage and FinishReason
// for finish parts) — confirm against the providers.
type StreamPart struct {
	Type StreamPartType `json:"type"`
	// ID presumably correlates the start/delta/end parts of a single text,
	// reasoning or tool-input item.
	ID               string       `json:"id"`
	ToolCallName     string       `json:"tool_call_name"`
	ToolCallInput    string       `json:"tool_call_input"`
	Delta            string       `json:"delta"`
	ProviderExecuted bool         `json:"provider_executed"`
	Usage            Usage        `json:"usage"`
	FinishReason     FinishReason `json:"finish_reason"`
	// NOTE(review): Error is an interface-typed field. encoding/json cannot
	// unmarshal into a non-empty interface such as error, and many error
	// values marshal as "{}" because they have no exported fields. This field
	// therefore does not round-trip through JSON; confirm whether that matters.
	Error    error         `json:"error"`
	Warnings []CallWarning `json:"warnings"`

	// Source-related fields, presumably set on StreamPartTypeSource parts.
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata holds provider-specific metadata for this part
	// (presumably keyed by provider id, like Response.ProviderMetadata).
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
// StreamResponse is the iterator returned by LanguageModel.Stream; ranging over
// it yields StreamPart values one at a time. It is a type alias, so any
// iter.Seq[StreamPart] can be used where a StreamResponse is expected.
type StreamResponse = iter.Seq[StreamPart]
165
// ToolChoice is the tool-selection setting sent with a Call (Call.ToolChoice).
// Apart from the predefined values below, any other value is the name of a
// single tool, as produced by SpecificToolChoice.
type ToolChoice string

// Predefined tool choices.
const (
	ToolChoiceNone     ToolChoice = "none"
	ToolChoiceAuto     ToolChoice = "auto"
	ToolChoiceRequired ToolChoice = "required"
)

// SpecificToolChoice returns a ToolChoice that selects the tool called name.
//
// NOTE(review): the result is just the tool name, so a tool named "none",
// "auto" or "required" is indistinguishable from the predefined choices.
func SpecificToolChoice(name string) ToolChoice {
	choice := ToolChoice(name)
	return choice
}
177
// Call holds the inputs for one LanguageModel.Generate or LanguageModel.Stream
// invocation.
//
// The optional generation settings are pointers. nil presumably means "not
// set", so the provider default applies — confirm against the providers.
// Per the CallWarning docs, a setting a provider does not support presumably
// yields a warning rather than failing the call.
type Call struct {
	Prompt           Prompt      `json:"prompt"`
	MaxOutputTokens  *int64      `json:"max_output_tokens"`
	Temperature      *float64    `json:"temperature"`
	TopP             *float64    `json:"top_p"`
	TopK             *int64      `json:"top_k"`
	PresencePenalty  *float64    `json:"presence_penalty"`
	FrequencyPenalty *float64    `json:"frequency_penalty"`
	Tools            []Tool      `json:"tools"`
	ToolChoice       *ToolChoice `json:"tool_choice"`

	// ProviderOptions holds provider-specific options, keyed by provider id.
	ProviderOptions ProviderOptions `json:"provider_options"`
}
192
// CallWarningType classifies a CallWarning. Note that its values are
// kebab-case, unlike the snake_case StreamPartType values.
type CallWarningType string

const (
	// CallWarningTypeUnsupportedSetting indicates an unsupported setting.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool indicates an unsupported tool.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther indicates other warnings.
	CallWarningTypeOther CallWarningType = "other"
)

// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
//
// Which fields are filled in presumably depends on Type: Setting for
// unsupported-setting warnings, Tool for unsupported-tool warnings, and
// Message for other warnings. Confirm against the providers.
type CallWarning struct {
	Type    CallWarningType `json:"type"`
	Setting string          `json:"setting"`
	Tool    Tool            `json:"tool"`
	Details string          `json:"details"`
	Message string          `json:"message"`
}
215
// LanguageModel is the set of methods a model implementation must provide.
type LanguageModel interface {
	// Generate performs the call described by Call and returns the complete
	// Response.
	Generate(context.Context, Call) (*Response, error)
	// Stream performs the call and returns a StreamResponse iterator of
	// StreamPart events instead of a single Response.
	Stream(context.Context, Call) (StreamResponse, error)

	// Provider returns an identifier for the provider; presumably the same
	// provider id used as the key in ProviderOptions/ProviderMetadata
	// (confirm).
	Provider() string
	// Model returns an identifier for the underlying model (presumably its
	// name as known to the provider).
	Model() string
}