1package fantasy
2
import (
	"context"
	"fmt"
	"iter"
	"strings"
)
8
// Usage represents token usage statistics for a model call.
//
// Every counter is a plain int64 copied from the provider. Nothing in this
// declaration derives one counter from another; for example, TotalTokens is
// not computed from InputTokens+OutputTokens here.
// NOTE(review): confirm that providers fill TotalTokens consistently.
type Usage struct {
	InputTokens  int64 `json:"input_tokens"`
	OutputTokens int64 `json:"output_tokens"`
	TotalTokens  int64 `json:"total_tokens"`
	// ReasoningTokens presumably counts tokens spent on reasoning output.
	// It is provider-reported; confirm per provider.
	ReasoningTokens int64 `json:"reasoning_tokens"`
	// The cache counters are presumably prompt-cache writes and reads as
	// reported by providers that support prompt caching. TODO confirm.
	CacheCreationTokens int64 `json:"cache_creation_tokens"`
	CacheReadTokens     int64 `json:"cache_read_tokens"`
}
18
19func (u Usage) String() string {
20 return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
21 u.InputTokens,
22 u.OutputTokens,
23 u.TotalTokens,
24 u.ReasoningTokens,
25 u.CacheCreationTokens,
26 u.CacheReadTokens,
27 )
28}
29
// ResponseContent represents the content of a model response as an ordered
// slice of Content parts. The accessor methods declared on this type select
// parts by comparing each part's GetType() value, and they preserve slice
// order.
type ResponseContent []Content
32
33// Text returns the text content of the response.
34func (r ResponseContent) Text() string {
35 for _, c := range r {
36 if c.GetType() == ContentTypeText {
37 if textContent, ok := AsContentType[TextContent](c); ok {
38 return textContent.Text
39 }
40 }
41 }
42 return ""
43}
44
45// Reasoning returns all reasoning content parts.
46func (r ResponseContent) Reasoning() []ReasoningContent {
47 var reasoning []ReasoningContent
48 for _, c := range r {
49 if c.GetType() == ContentTypeReasoning {
50 if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
51 reasoning = append(reasoning, reasoningContent)
52 }
53 }
54 }
55 return reasoning
56}
57
58// ReasoningText returns all reasoning content as a concatenated string.
59func (r ResponseContent) ReasoningText() string {
60 var text string
61 for _, reasoning := range r.Reasoning() {
62 text += reasoning.Text
63 }
64 return text
65}
66
67// Files returns all file content parts.
68func (r ResponseContent) Files() []FileContent {
69 var files []FileContent
70 for _, c := range r {
71 if c.GetType() == ContentTypeFile {
72 if fileContent, ok := AsContentType[FileContent](c); ok {
73 files = append(files, fileContent)
74 }
75 }
76 }
77 return files
78}
79
80// Sources returns all source content parts.
81func (r ResponseContent) Sources() []SourceContent {
82 var sources []SourceContent
83 for _, c := range r {
84 if c.GetType() == ContentTypeSource {
85 if sourceContent, ok := AsContentType[SourceContent](c); ok {
86 sources = append(sources, sourceContent)
87 }
88 }
89 }
90 return sources
91}
92
93// ToolCalls returns all tool call content parts.
94func (r ResponseContent) ToolCalls() []ToolCallContent {
95 var toolCalls []ToolCallContent
96 for _, c := range r {
97 if c.GetType() == ContentTypeToolCall {
98 if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
99 toolCalls = append(toolCalls, toolCallContent)
100 }
101 }
102 }
103 return toolCalls
104}
105
106// ToolResults returns all tool result content parts.
107func (r ResponseContent) ToolResults() []ToolResultContent {
108 var toolResults []ToolResultContent
109 for _, c := range r {
110 if c.GetType() == ContentTypeToolResult {
111 if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
112 toolResults = append(toolResults, toolResultContent)
113 }
114 }
115 }
116 return toolResults
117}
118
// Response represents a complete, non-streaming response from a language
// model, as returned by LanguageModel.Generate.
type Response struct {
	// Content holds the response parts in slice order.
	Content      ResponseContent `json:"content"`
	FinishReason FinishReason    `json:"finish_reason"`
	Usage        Usage           `json:"usage"`
	// Warnings lists issues the provider reported for this call. The call
	// still proceeded despite them (see CallWarning).
	Warnings []CallWarning `json:"warnings"`

	// ProviderMetadata carries provider-specific response metadata; the key
	// is the provider id.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
129
// StreamPartType represents the type of a stream part. It is stored in
// StreamPart.Type and serialized under the JSON key "type", using the string
// values of the StreamPartType* constants.
type StreamPartType string
132
// Stream part types. Text, reasoning, and tool input each have a
// Start/Delta/End triple of constants for incrementally streamed content.
// NOTE(review): parts belonging to the same triple are presumably correlated
// through StreamPart.ID; confirm in the provider implementations.
const (
	// StreamPartTypeWarnings represents warnings stream part type.
	StreamPartTypeWarnings StreamPartType = "warnings"
	// StreamPartTypeTextStart represents text start stream part type.
	StreamPartTypeTextStart StreamPartType = "text_start"
	// StreamPartTypeTextDelta represents text delta stream part type.
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	// StreamPartTypeTextEnd represents text end stream part type.
	StreamPartTypeTextEnd StreamPartType = "text_end"

	// StreamPartTypeReasoningStart represents reasoning start stream part type.
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	// StreamPartTypeReasoningDelta represents reasoning delta stream part type.
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	// StreamPartTypeReasoningEnd represents reasoning end stream part type.
	StreamPartTypeReasoningEnd StreamPartType = "reasoning_end"
	// StreamPartTypeToolInputStart represents tool input start stream part type.
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	// StreamPartTypeToolInputDelta represents tool input delta stream part type.
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	// StreamPartTypeToolInputEnd represents tool input end stream part type.
	StreamPartTypeToolInputEnd StreamPartType = "tool_input_end"
	// StreamPartTypeToolCall represents tool call stream part type.
	StreamPartTypeToolCall StreamPartType = "tool_call"
	// StreamPartTypeToolResult represents tool result stream part type.
	StreamPartTypeToolResult StreamPartType = "tool_result"
	// StreamPartTypeSource represents source stream part type.
	StreamPartTypeSource StreamPartType = "source"
	// StreamPartTypeFinish represents finish stream part type.
	StreamPartTypeFinish StreamPartType = "finish"
	// StreamPartTypeError represents error stream part type.
	StreamPartTypeError StreamPartType = "error"
)
166
// StreamPart represents a part of a streaming response.
//
// It is a flat union: Type selects which of the other fields are meaningful
// for a given part. NOTE(review): the mapping from Type to populated fields
// is defined by the provider implementations, not here. Presumably Delta is
// set for *_delta parts, Usage/FinishReason for finish parts, and Error for
// error parts; confirm.
type StreamPart struct {
	Type             StreamPartType `json:"type"`
	ID               string         `json:"id"`
	ToolCallName     string         `json:"tool_call_name"`
	ToolCallInput    string         `json:"tool_call_input"`
	Delta            string         `json:"delta"`
	ProviderExecuted bool           `json:"provider_executed"`
	Usage            Usage          `json:"usage"`
	FinishReason     FinishReason   `json:"finish_reason"`
	// Error is an interface value. NOTE(review): encoding/json marshals the
	// error's dynamic value, not its Error() string. Errors from errors.New
	// or fmt.Errorf therefore encode as {} and cannot be unmarshalled back
	// into this field. Confirm whether custom (un)marshalling exists
	// elsewhere.
	Error    error         `json:"error"`
	Warnings []CallWarning `json:"warnings"`

	// Source-related fields
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata carries provider-specific metadata for this part.
	// Presumably it is keyed by provider id, as in Response; TODO confirm.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
187
// StreamResponse represents a streaming response sequence. It is an alias
// for iter.Seq[StreamPart], so callers consume it with a range-over-func
// loop:
//
//	for part := range stream {
//		// handle part
//	}
//
// Breaking out of the loop makes the sequence's yield function return
// false, which signals the producer to stop.
type StreamResponse = iter.Seq[StreamPart]
190
// ToolChoice represents the tool choice preference for a model call. Its
// value is either one of the ToolChoiceNone/Auto/Required constants or a tool
// name wrapped by SpecificToolChoice. Providers presumably interpret the
// latter as "use this specific tool"; confirm per provider.
type ToolChoice string
193
// Predefined tool choices. SpecificToolChoice converts a tool name directly
// into a ToolChoice, so a tool literally named "none", "auto" or "required"
// produces a value equal to one of these constants.
const (
	// ToolChoiceNone indicates no tools should be used.
	ToolChoiceNone ToolChoice = "none"
	// ToolChoiceAuto indicates tools should be used automatically.
	ToolChoiceAuto ToolChoice = "auto"
	// ToolChoiceRequired indicates tools are required.
	ToolChoiceRequired ToolChoice = "required"
)
202
203// SpecificToolChoice creates a tool choice for a specific tool name.
204func SpecificToolChoice(name string) ToolChoice {
205 return ToolChoice(name)
206}
207
// Call represents a call to a language model. It holds the prompt, optional
// sampling settings, tools, and provider-specific options. It is the input
// to LanguageModel.Generate and LanguageModel.Stream.
//
// The sampling settings and ToolChoice are pointers. NOTE(review): a nil
// pointer presumably means "not set, use the provider default" rather than
// zero; confirm in the provider implementations.
type Call struct {
	Prompt           Prompt     `json:"prompt"`
	MaxOutputTokens  *int64     `json:"max_output_tokens"`
	Temperature      *float64   `json:"temperature"`
	TopP             *float64   `json:"top_p"`
	TopK             *int64     `json:"top_k"`
	PresencePenalty  *float64   `json:"presence_penalty"`
	FrequencyPenalty *float64   `json:"frequency_penalty"`
	Tools            []Tool     `json:"tools"`
	ToolChoice       *ToolChoice `json:"tool_choice"`

	// for provider specific options, the key is the provider id
	ProviderOptions ProviderOptions `json:"provider_options"`
}
223
// CallWarningType represents the type of call warning. It is stored in
// CallWarning.Type using the string values of the CallWarningType*
// constants.
type CallWarningType string
226
// Call warning types. Unlike the StreamPartType values, these string values
// are hyphen-separated.
const (
	// CallWarningTypeUnsupportedSetting indicates an unsupported setting.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool indicates an unsupported tool.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther indicates other warnings.
	CallWarningTypeOther CallWarningType = "other"
)
235
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
//
// Type selects which fields are meaningful. NOTE(review): presumably Setting
// is set for unsupported-setting, Tool for unsupported-tool, and Message for
// other; confirm in the providers.
type CallWarning struct {
	Type    CallWarningType `json:"type"`
	Setting string          `json:"setting"`
	Tool    Tool            `json:"tool"`
	Details string          `json:"details"`
	Message string          `json:"message"`
}
246
// LanguageModel represents a language model that can generate responses and
// stream responses. Generate and Stream take a Call. GenerateObject and
// StreamObject take an ObjectCall and are presumably the structured-output
// variants; see ObjectCall.
type LanguageModel interface {
	// Generate performs a non-streaming call and returns the full Response.
	Generate(context.Context, Call) (*Response, error)
	// Stream starts a streaming call; the returned StreamResponse yields
	// StreamParts when ranged over.
	Stream(context.Context, Call) (StreamResponse, error)

	GenerateObject(context.Context, ObjectCall) (*ObjectResponse, error)
	StreamObject(context.Context, ObjectCall) (ObjectStreamResponse, error)

	// Provider returns the provider identifier. NOTE(review): presumably the
	// same id used as the key in ProviderOptions and ProviderMetadata;
	// confirm.
	Provider() string
	// Model returns the model identifier.
	Model() string
}