model.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"fmt"
  6	"iter"
  7)
  8
  9// Usage represents token usage statistics for a model call.
 10type Usage struct {
 11	InputTokens         int64 `json:"input_tokens"`
 12	OutputTokens        int64 `json:"output_tokens"`
 13	TotalTokens         int64 `json:"total_tokens"`
 14	ReasoningTokens     int64 `json:"reasoning_tokens"`
 15	CacheCreationTokens int64 `json:"cache_creation_tokens"`
 16	CacheReadTokens     int64 `json:"cache_read_tokens"`
 17}
 18
 19func (u Usage) String() string {
 20	return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
 21		u.InputTokens,
 22		u.OutputTokens,
 23		u.TotalTokens,
 24		u.ReasoningTokens,
 25		u.CacheCreationTokens,
 26		u.CacheReadTokens,
 27	)
 28}
 29
// ResponseContent represents the content of a model response as a slice of
// Content parts. Its methods select parts by content type.
type ResponseContent []Content
 32
 33// Text returns the text content of the response.
 34func (r ResponseContent) Text() string {
 35	for _, c := range r {
 36		if c.GetType() == ContentTypeText {
 37			if textContent, ok := AsContentType[TextContent](c); ok {
 38				return textContent.Text
 39			}
 40		}
 41	}
 42	return ""
 43}
 44
 45// Reasoning returns all reasoning content parts.
 46func (r ResponseContent) Reasoning() []ReasoningContent {
 47	var reasoning []ReasoningContent
 48	for _, c := range r {
 49		if c.GetType() == ContentTypeReasoning {
 50			if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
 51				reasoning = append(reasoning, reasoningContent)
 52			}
 53		}
 54	}
 55	return reasoning
 56}
 57
 58// ReasoningText returns all reasoning content as a concatenated string.
 59func (r ResponseContent) ReasoningText() string {
 60	var text string
 61	for _, reasoning := range r.Reasoning() {
 62		text += reasoning.Text
 63	}
 64	return text
 65}
 66
 67// Files returns all file content parts.
 68func (r ResponseContent) Files() []FileContent {
 69	var files []FileContent
 70	for _, c := range r {
 71		if c.GetType() == ContentTypeFile {
 72			if fileContent, ok := AsContentType[FileContent](c); ok {
 73				files = append(files, fileContent)
 74			}
 75		}
 76	}
 77	return files
 78}
 79
 80// Sources returns all source content parts.
 81func (r ResponseContent) Sources() []SourceContent {
 82	var sources []SourceContent
 83	for _, c := range r {
 84		if c.GetType() == ContentTypeSource {
 85			if sourceContent, ok := AsContentType[SourceContent](c); ok {
 86				sources = append(sources, sourceContent)
 87			}
 88		}
 89	}
 90	return sources
 91}
 92
 93// ToolCalls returns all tool call content parts.
 94func (r ResponseContent) ToolCalls() []ToolCallContent {
 95	var toolCalls []ToolCallContent
 96	for _, c := range r {
 97		if c.GetType() == ContentTypeToolCall {
 98			if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
 99				toolCalls = append(toolCalls, toolCallContent)
100			}
101		}
102	}
103	return toolCalls
104}
105
106// ToolResults returns all tool result content parts.
107func (r ResponseContent) ToolResults() []ToolResultContent {
108	var toolResults []ToolResultContent
109	for _, c := range r {
110		if c.GetType() == ContentTypeToolResult {
111			if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
112				toolResults = append(toolResults, toolResultContent)
113			}
114		}
115	}
116	return toolResults
117}
118
// Response represents a response from a language model. It is the value
// returned by LanguageModel.Generate.
type Response struct {
	// Content holds the content parts of the response.
	Content      ResponseContent `json:"content"`
	// FinishReason is the finish reason reported for the call.
	FinishReason FinishReason    `json:"finish_reason"`
	// Usage holds the token usage statistics for the call.
	Usage        Usage           `json:"usage"`
	// Warnings lists the warnings attached to the call, if any.
	Warnings     []CallWarning   `json:"warnings"`

	// ProviderMetadata holds provider-specific response metadata; the key is
	// the provider id.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
129
// StreamPartType represents the type of a stream part. It is a string so
// that it serializes as the literal values listed below.
type StreamPartType string

// Known stream part types. Text, reasoning and tool input each have a
// start/delta/end triple; the remaining values are single markers.
const (
	// StreamPartTypeWarnings is the stream part type for warnings.
	StreamPartTypeWarnings StreamPartType = "warnings"
	// StreamPartTypeTextStart is the stream part type for the start of text.
	StreamPartTypeTextStart StreamPartType = "text_start"
	// StreamPartTypeTextDelta is the stream part type for a text delta.
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	// StreamPartTypeTextEnd is the stream part type for the end of text.
	StreamPartTypeTextEnd StreamPartType = "text_end"

	// StreamPartTypeReasoningStart is the stream part type for the start of reasoning.
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	// StreamPartTypeReasoningDelta is the stream part type for a reasoning delta.
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	// StreamPartTypeReasoningEnd is the stream part type for the end of reasoning.
	StreamPartTypeReasoningEnd StreamPartType = "reasoning_end"
	// StreamPartTypeToolInputStart is the stream part type for the start of tool input.
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	// StreamPartTypeToolInputDelta is the stream part type for a tool input delta.
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	// StreamPartTypeToolInputEnd is the stream part type for the end of tool input.
	StreamPartTypeToolInputEnd StreamPartType = "tool_input_end"
	// StreamPartTypeToolCall is the stream part type for a tool call.
	StreamPartTypeToolCall StreamPartType = "tool_call"
	// StreamPartTypeToolResult is the stream part type for a tool result.
	StreamPartTypeToolResult StreamPartType = "tool_result"
	// StreamPartTypeSource is the stream part type for a source.
	StreamPartTypeSource StreamPartType = "source"
	// StreamPartTypeFinish is the stream part type for the finish marker.
	StreamPartTypeFinish StreamPartType = "finish"
	// StreamPartTypeError is the stream part type for an error.
	StreamPartTypeError StreamPartType = "error"
)
166
// StreamPart represents a part of a streaming response.
//
// It is a single flat struct shared by every StreamPartType; Type says which
// kind of part this is. Which of the other fields are populated for a given
// Type is not visible in this file — presumably only those relevant to that
// type; confirm against the provider implementations.
type StreamPart struct {
	// Type identifies the kind of stream part.
	Type             StreamPartType `json:"type"`
	// ID is an identifier for the part.
	ID               string         `json:"id"`
	// ToolCallName and ToolCallInput describe a tool call.
	ToolCallName     string         `json:"tool_call_name"`
	ToolCallInput    string         `json:"tool_call_input"`
	// Delta is a delta payload (presumably the incremental text of *_delta
	// parts — TODO confirm).
	Delta            string         `json:"delta"`
	ProviderExecuted bool           `json:"provider_executed"`
	// Usage and FinishReason carry call statistics and the finish reason.
	Usage            Usage          `json:"usage"`
	FinishReason     FinishReason   `json:"finish_reason"`
	// NOTE(review): Error is an interface value with a JSON tag. encoding/json
	// encodes the dynamic value and cannot decode a non-null value back into
	// an error interface — confirm this field is not expected to round-trip.
	Error            error          `json:"error"`
	// Warnings lists call warnings.
	Warnings         []CallWarning  `json:"warnings"`

	// Source-related fields
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata holds provider-specific metadata for this part.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
187
// StreamResponse represents a streaming response sequence. It is an alias
// for iter.Seq[StreamPart], i.e. an iterator function yielding StreamPart
// values one at a time.
type StreamResponse = iter.Seq[StreamPart]
190
// ToolChoice represents the tool choice preference for a model call.
//
// A value is either one of the reserved constants below or an arbitrary tool
// name produced by SpecificToolChoice.
type ToolChoice string

// Reserved tool choice values.
const (
	// ToolChoiceNone indicates no tools should be used.
	ToolChoiceNone ToolChoice = "none"
	// ToolChoiceAuto indicates tools should be used automatically.
	ToolChoiceAuto ToolChoice = "auto"
	// ToolChoiceRequired indicates tools are required.
	ToolChoiceRequired ToolChoice = "required"
)

// SpecificToolChoice creates a tool choice for a specific tool name.
//
// The name is converted to a ToolChoice unchanged, so a tool literally named
// "none", "auto" or "required" yields a value equal to the corresponding
// reserved constant.
func SpecificToolChoice(name string) ToolChoice {
	choice := ToolChoice(name)
	return choice
}
207
// Call represents a call to a language model: the prompt plus generation
// settings, tools and provider-specific options.
//
// The numeric settings and ToolChoice are pointers, so nil is distinct from
// an explicit zero value (presumably nil means "not set, use the provider
// default" — TODO confirm against the provider implementations).
type Call struct {
	// Prompt is the prompt sent to the model.
	Prompt           Prompt      `json:"prompt"`
	// Optional generation settings.
	MaxOutputTokens  *int64      `json:"max_output_tokens"`
	Temperature      *float64    `json:"temperature"`
	TopP             *float64    `json:"top_p"`
	TopK             *int64      `json:"top_k"`
	PresencePenalty  *float64    `json:"presence_penalty"`
	FrequencyPenalty *float64    `json:"frequency_penalty"`
	// Tools lists the tools supplied with the call; ToolChoice is the tool
	// choice preference for it.
	Tools            []Tool      `json:"tools"`
	ToolChoice       *ToolChoice `json:"tool_choice"`

	// ProviderOptions holds provider-specific options; the key is the
	// provider id.
	ProviderOptions ProviderOptions `json:"provider_options"`
}
223
// CallWarningType represents the type of call warning. It is a string so
// that it serializes as the literal values listed below.
type CallWarningType string

// Known call warning types.
const (
	// CallWarningTypeUnsupportedSetting indicates an unsupported setting.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool indicates an unsupported tool.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther indicates any other kind of warning.
	CallWarningTypeOther CallWarningType = "other"
)
235
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
//
// Which of Setting, Tool, Details and Message are populated presumably
// depends on Type (e.g. Setting for unsupported-setting, Tool for
// unsupported-tool) — TODO confirm against the provider implementations.
type CallWarning struct {
	// Type is the kind of warning.
	Type    CallWarningType `json:"type"`
	// Setting names a setting associated with the warning.
	Setting string          `json:"setting"`
	// Tool is a tool associated with the warning.
	Tool    Tool            `json:"tool"`
	// Details and Message carry free-form text about the warning.
	Details string          `json:"details"`
	Message string          `json:"message"`
}
246
// LanguageModel represents a language model that can generate responses and
// stream responses, both for regular calls and for object calls.
type LanguageModel interface {
	// Generate performs the Call and returns the complete Response, or an
	// error.
	Generate(context.Context, Call) (*Response, error)
	// Stream performs the Call and returns a StreamResponse sequence of
	// StreamPart values, or an error.
	Stream(context.Context, Call) (StreamResponse, error)

	// GenerateObject and StreamObject are the ObjectCall counterparts of
	// Generate and Stream.
	GenerateObject(context.Context, ObjectCall) (*ObjectResponse, error)
	StreamObject(context.Context, ObjectCall) (ObjectStreamResponse, error)

	// Provider and Model return string identifiers (presumably the provider
	// id and the model id — TODO confirm).
	Provider() string
	Model() string
}