1package fantasy
2
3import (
4 "context"
5 "fmt"
6 "iter"
7)
8
// Usage holds the token counters reported for a single model call.
//
// Every counter is an int64 and is serialized under the snake_case key given
// in its JSON tag.
type Usage struct {
	// Input, output and total token counters.
	InputTokens  int64 `json:"input_tokens"`
	OutputTokens int64 `json:"output_tokens"`
	TotalTokens  int64 `json:"total_tokens"`

	// Reasoning token counter.
	ReasoningTokens int64 `json:"reasoning_tokens"`

	// Cache creation and cache read token counters.
	CacheCreationTokens int64 `json:"cache_creation_tokens"`
	CacheReadTokens     int64 `json:"cache_read_tokens"`
}

// String renders all six counters on a single line, in declaration order,
// in the form
// "Usage{Input: N, Output: N, Total: N, Reasoning: N, CacheCreation: N, CacheRead: N}".
func (u Usage) String() string {
	const layout = "Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}"
	return fmt.Sprintf(
		layout,
		u.InputTokens, u.OutputTokens, u.TotalTokens,
		u.ReasoningTokens,
		u.CacheCreationTokens, u.CacheReadTokens,
	)
}
29
30// ResponseContent represents the content of a model response.
31type ResponseContent []Content
32
33// Text returns the text content of the response.
34func (r ResponseContent) Text() string {
35 for _, c := range r {
36 if c.GetType() == ContentTypeText {
37 return c.(TextContent).Text
38 }
39 }
40 return ""
41}
42
43// Reasoning returns all reasoning content parts.
44func (r ResponseContent) Reasoning() []ReasoningContent {
45 var reasoning []ReasoningContent
46 for _, c := range r {
47 if c.GetType() == ContentTypeReasoning {
48 if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
49 reasoning = append(reasoning, reasoningContent)
50 }
51 }
52 }
53 return reasoning
54}
55
56// ReasoningText returns all reasoning content as a concatenated string.
57func (r ResponseContent) ReasoningText() string {
58 var text string
59 for _, reasoning := range r.Reasoning() {
60 text += reasoning.Text
61 }
62 return text
63}
64
65// Files returns all file content parts.
66func (r ResponseContent) Files() []FileContent {
67 var files []FileContent
68 for _, c := range r {
69 if c.GetType() == ContentTypeFile {
70 if fileContent, ok := AsContentType[FileContent](c); ok {
71 files = append(files, fileContent)
72 }
73 }
74 }
75 return files
76}
77
78// Sources returns all source content parts.
79func (r ResponseContent) Sources() []SourceContent {
80 var sources []SourceContent
81 for _, c := range r {
82 if c.GetType() == ContentTypeSource {
83 if sourceContent, ok := AsContentType[SourceContent](c); ok {
84 sources = append(sources, sourceContent)
85 }
86 }
87 }
88 return sources
89}
90
91// ToolCalls returns all tool call content parts.
92func (r ResponseContent) ToolCalls() []ToolCallContent {
93 var toolCalls []ToolCallContent
94 for _, c := range r {
95 if c.GetType() == ContentTypeToolCall {
96 if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
97 toolCalls = append(toolCalls, toolCallContent)
98 }
99 }
100 }
101 return toolCalls
102}
103
104// ToolResults returns all tool result content parts.
105func (r ResponseContent) ToolResults() []ToolResultContent {
106 var toolResults []ToolResultContent
107 for _, c := range r {
108 if c.GetType() == ContentTypeToolResult {
109 if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
110 toolResults = append(toolResults, toolResultContent)
111 }
112 }
113 }
114 return toolResults
115}
116
117// Response represents a response from a language model.
118type Response struct {
119 Content ResponseContent `json:"content"`
120 FinishReason FinishReason `json:"finish_reason"`
121 Usage Usage `json:"usage"`
122 Warnings []CallWarning `json:"warnings"`
123
124 // for provider specific response metadata, the key is the provider id
125 ProviderMetadata ProviderMetadata `json:"provider_metadata"`
126}
127
// StreamPartType identifies the kind of a StreamPart. It is a plain string
// type, so every value below is carried as its literal string.
type StreamPartType string

// Known stream part types, grouped by family.
const (
	// StreamPartTypeWarnings is the "warnings" part type.
	StreamPartTypeWarnings StreamPartType = "warnings"

	// StreamPartTypeTextStart is the "text_start" part type.
	StreamPartTypeTextStart StreamPartType = "text_start"
	// StreamPartTypeTextDelta is the "text_delta" part type.
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	// StreamPartTypeTextEnd is the "text_end" part type.
	StreamPartTypeTextEnd StreamPartType = "text_end"

	// StreamPartTypeReasoningStart is the "reasoning_start" part type.
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	// StreamPartTypeReasoningDelta is the "reasoning_delta" part type.
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	// StreamPartTypeReasoningEnd is the "reasoning_end" part type.
	StreamPartTypeReasoningEnd StreamPartType = "reasoning_end"

	// StreamPartTypeToolInputStart is the "tool_input_start" part type.
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	// StreamPartTypeToolInputDelta is the "tool_input_delta" part type.
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	// StreamPartTypeToolInputEnd is the "tool_input_end" part type.
	StreamPartTypeToolInputEnd StreamPartType = "tool_input_end"

	// StreamPartTypeToolCall is the "tool_call" part type.
	StreamPartTypeToolCall StreamPartType = "tool_call"
	// StreamPartTypeToolResult is the "tool_result" part type.
	StreamPartTypeToolResult StreamPartType = "tool_result"

	// StreamPartTypeSource is the "source" part type.
	StreamPartTypeSource StreamPartType = "source"
	// StreamPartTypeFinish is the "finish" part type.
	StreamPartTypeFinish StreamPartType = "finish"
	// StreamPartTypeError is the "error" part type.
	StreamPartTypeError StreamPartType = "error"
)
164
165// StreamPart represents a part of a streaming response.
166type StreamPart struct {
167 Type StreamPartType `json:"type"`
168 ID string `json:"id"`
169 ToolCallName string `json:"tool_call_name"`
170 ToolCallInput string `json:"tool_call_input"`
171 Delta string `json:"delta"`
172 ProviderExecuted bool `json:"provider_executed"`
173 Usage Usage `json:"usage"`
174 FinishReason FinishReason `json:"finish_reason"`
175 Error error `json:"error"`
176 Warnings []CallWarning `json:"warnings"`
177
178 // Source-related fields
179 SourceType SourceType `json:"source_type"`
180 URL string `json:"url"`
181 Title string `json:"title"`
182
183 ProviderMetadata ProviderMetadata `json:"provider_metadata"`
184}
185
186// StreamResponse represents a streaming response sequence.
187type StreamResponse = iter.Seq[StreamPart]
188
// ToolChoice is the tool choice preference for a model call. It is a plain
// string: either one of the predefined modes below or the name of a specific
// tool (see SpecificToolChoice).
type ToolChoice string

// Predefined tool choice modes.
const (
	// ToolChoiceNone indicates that no tools should be used.
	ToolChoiceNone ToolChoice = "none"
	// ToolChoiceAuto indicates that tools should be used automatically.
	ToolChoiceAuto ToolChoice = "auto"
	// ToolChoiceRequired indicates that tools are required.
	ToolChoiceRequired ToolChoice = "required"
)

// SpecificToolChoice returns a tool choice that names one specific tool.
//
// The name is converted to a ToolChoice unchanged, so a tool called "none",
// "auto" or "required" yields a value equal to the matching predefined mode.
func SpecificToolChoice(name string) ToolChoice {
	choice := ToolChoice(name)
	return choice
}
205
206// Call represents a call to a language model.
207type Call struct {
208 Prompt Prompt `json:"prompt"`
209 MaxOutputTokens *int64 `json:"max_output_tokens"`
210 Temperature *float64 `json:"temperature"`
211 TopP *float64 `json:"top_p"`
212 TopK *int64 `json:"top_k"`
213 PresencePenalty *float64 `json:"presence_penalty"`
214 FrequencyPenalty *float64 `json:"frequency_penalty"`
215 Tools []Tool `json:"tools"`
216 ToolChoice *ToolChoice `json:"tool_choice"`
217
218 // for provider specific options, the key is the provider id
219 ProviderOptions ProviderOptions `json:"provider_options"`
220}
221
// CallWarningType is the category of a CallWarning. It is a plain string
// type, so every value below is carried as its literal string.
type CallWarningType string

// Known call warning categories.
const (
	// CallWarningTypeUnsupportedSetting marks a warning about an unsupported
	// setting.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool marks a warning about an unsupported
	// tool.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther marks any other warning.
	CallWarningTypeOther CallWarningType = "other"
)
234
235// CallWarning represents a warning from the model provider for this call.
236// The call will proceed, but e.g. some settings might not be supported,
237// which can lead to suboptimal results.
238type CallWarning struct {
239 Type CallWarningType `json:"type"`
240 Setting string `json:"setting"`
241 Tool Tool `json:"tool"`
242 Details string `json:"details"`
243 Message string `json:"message"`
244}
245
246// LanguageModel represents a language model that can generate responses and stream responses.
247type LanguageModel interface {
248 Generate(context.Context, Call) (*Response, error)
249 Stream(context.Context, Call) (StreamResponse, error)
250
251 Provider() string
252 Model() string
253}