1package fantasy
2
3import (
4 "context"
5 "fmt"
6 "iter"
7 "strings"
8)
9
// Usage represents token usage statistics for a model call.
//
// All counters are plain int64 values that are stored as given; nothing in
// this type derives one counter from another.
type Usage struct {
	// InputTokens is the input token count.
	InputTokens int64 `json:"input_tokens"`
	// OutputTokens is the output token count.
	OutputTokens int64 `json:"output_tokens"`
	// TotalTokens is the total token count.
	TotalTokens int64 `json:"total_tokens"`
	// ReasoningTokens is the reasoning token count.
	ReasoningTokens int64 `json:"reasoning_tokens"`
	// CacheCreationTokens is the cache-creation token count.
	CacheCreationTokens int64 `json:"cache_creation_tokens"`
	// CacheReadTokens is the cache-read token count.
	CacheReadTokens int64 `json:"cache_read_tokens"`
}

// String implements fmt.Stringer, rendering every counter on a single line in
// the form "Usage{Input: N, Output: N, Total: N, Reasoning: N,
// CacheCreation: N, CacheRead: N}".
func (u Usage) String() string {
	const layout = "Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}"

	// Listed in the same order as the verbs in layout.
	counters := []any{
		u.InputTokens,
		u.OutputTokens,
		u.TotalTokens,
		u.ReasoningTokens,
		u.CacheCreationTokens,
		u.CacheReadTokens,
	}
	return fmt.Sprintf(layout, counters...)
}
30
// ResponseContent represents the content of a model response as a slice of
// Content parts. Its methods filter the slice by content type.
type ResponseContent []Content
33
34// Text returns the text content of the response.
35func (r ResponseContent) Text() string {
36 for _, c := range r {
37 if c.GetType() == ContentTypeText {
38 if textContent, ok := AsContentType[TextContent](c); ok {
39 return textContent.Text
40 }
41 }
42 }
43 return ""
44}
45
46// Reasoning returns all reasoning content parts.
47func (r ResponseContent) Reasoning() []ReasoningContent {
48 var reasoning []ReasoningContent
49 for _, c := range r {
50 if c.GetType() == ContentTypeReasoning {
51 if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
52 reasoning = append(reasoning, reasoningContent)
53 }
54 }
55 }
56 return reasoning
57}
58
59// ReasoningText returns all reasoning content as a concatenated string.
60func (r ResponseContent) ReasoningText() string {
61 var builder strings.Builder
62 for _, reasoning := range r.Reasoning() {
63 builder.WriteString(reasoning.Text)
64 }
65 return builder.String()
66}
67
68// Files returns all file content parts.
69func (r ResponseContent) Files() []FileContent {
70 var files []FileContent
71 for _, c := range r {
72 if c.GetType() == ContentTypeFile {
73 if fileContent, ok := AsContentType[FileContent](c); ok {
74 files = append(files, fileContent)
75 }
76 }
77 }
78 return files
79}
80
81// Sources returns all source content parts.
82func (r ResponseContent) Sources() []SourceContent {
83 var sources []SourceContent
84 for _, c := range r {
85 if c.GetType() == ContentTypeSource {
86 if sourceContent, ok := AsContentType[SourceContent](c); ok {
87 sources = append(sources, sourceContent)
88 }
89 }
90 }
91 return sources
92}
93
94// ToolCalls returns all tool call content parts.
95func (r ResponseContent) ToolCalls() []ToolCallContent {
96 var toolCalls []ToolCallContent
97 for _, c := range r {
98 if c.GetType() == ContentTypeToolCall {
99 if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
100 toolCalls = append(toolCalls, toolCallContent)
101 }
102 }
103 }
104 return toolCalls
105}
106
107// ToolResults returns all tool result content parts.
108func (r ResponseContent) ToolResults() []ToolResultContent {
109 var toolResults []ToolResultContent
110 for _, c := range r {
111 if c.GetType() == ContentTypeToolResult {
112 if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
113 toolResults = append(toolResults, toolResultContent)
114 }
115 }
116 }
117 return toolResults
118}
119
// Response represents a response from a language model.
type Response struct {
	// Content holds the content parts of the response.
	Content ResponseContent `json:"content"`
	// FinishReason is the finish reason recorded for the response.
	FinishReason FinishReason `json:"finish_reason"`
	// Usage holds the token usage statistics for the call.
	Usage Usage `json:"usage"`
	// Warnings lists the call warnings attached to the response.
	Warnings []CallWarning `json:"warnings"`

	// ProviderMetadata holds provider specific response metadata; the key is
	// the provider id.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
130
// StreamPartType represents the type of a stream part. It is a string type;
// the recognized values are the StreamPartType* constants below.
type StreamPartType string

const (
	// StreamPartTypeWarnings represents warnings stream part type.
	StreamPartTypeWarnings StreamPartType = "warnings"
	// StreamPartTypeTextStart represents text start stream part type.
	StreamPartTypeTextStart StreamPartType = "text_start"
	// StreamPartTypeTextDelta represents text delta stream part type.
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	// StreamPartTypeTextEnd represents text end stream part type.
	StreamPartTypeTextEnd StreamPartType = "text_end"

	// StreamPartTypeReasoningStart represents reasoning start stream part type.
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	// StreamPartTypeReasoningDelta represents reasoning delta stream part type.
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	// StreamPartTypeReasoningEnd represents reasoning end stream part type.
	StreamPartTypeReasoningEnd StreamPartType = "reasoning_end"
	// StreamPartTypeToolInputStart represents tool input start stream part type.
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	// StreamPartTypeToolInputDelta represents tool input delta stream part type.
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	// StreamPartTypeToolInputEnd represents tool input end stream part type.
	StreamPartTypeToolInputEnd StreamPartType = "tool_input_end"
	// StreamPartTypeToolCall represents tool call stream part type.
	StreamPartTypeToolCall StreamPartType = "tool_call"
	// StreamPartTypeToolResult represents tool result stream part type.
	StreamPartTypeToolResult StreamPartType = "tool_result"
	// StreamPartTypeSource represents source stream part type.
	StreamPartTypeSource StreamPartType = "source"
	// StreamPartTypeFinish represents finish stream part type.
	StreamPartTypeFinish StreamPartType = "finish"
	// StreamPartTypeError represents error stream part type.
	StreamPartTypeError StreamPartType = "error"
)
167
// StreamPart represents a part of a streaming response.
//
// It is one flat struct shared by every StreamPartType; Type says which kind
// of part this is. Which of the other fields are populated for a given Type
// is decided by the code that produces the stream and is not visible in this
// declaration.
type StreamPart struct {
	// Type identifies the kind of stream part.
	Type StreamPartType `json:"type"`
	// ID is an identifier for the part.
	// NOTE(review): presumably ties together the start/delta/end parts of one
	// text, reasoning or tool-input block — confirm against the producers.
	ID string `json:"id"`
	// ToolCallName is the tool call name.
	ToolCallName string `json:"tool_call_name"`
	// ToolCallInput is the tool call input, carried as a string.
	ToolCallInput string `json:"tool_call_input"`
	// Delta is the delta payload, carried as a string.
	// NOTE(review): presumably the incremental chunk for *_delta parts — confirm.
	Delta string `json:"delta"`
	// ProviderExecuted is a provider-executed flag.
	// NOTE(review): presumably marks tool calls run by the provider rather
	// than by the caller — confirm.
	ProviderExecuted bool `json:"provider_executed"`
	// Usage holds token usage statistics.
	Usage Usage `json:"usage"`
	// FinishReason is the finish reason.
	FinishReason FinishReason `json:"finish_reason"`
	// Error carries an error value.
	// NOTE(review): error is an interface type, so encoding/json cannot
	// decode a non-null value into this field and encodes it from the
	// concrete type's exported fields only — confirm whether StreamPart is
	// ever marshaled or unmarshaled.
	Error error `json:"error"`
	// Warnings lists call warnings.
	Warnings []CallWarning `json:"warnings"`

	// Source-related fields
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata holds provider metadata for the part.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
188
// StreamResponse represents a streaming response sequence. It is an alias
// for iter.Seq[StreamPart], so it can be consumed with a range-over-func loop
// that yields one StreamPart per iteration.
type StreamResponse = iter.Seq[StreamPart]
191
// ToolChoice represents the tool choice preference for a model call.
//
// It is a string type. The constants below are the predefined choices;
// SpecificToolChoice builds a value from a tool name.
type ToolChoice string

const (
	// ToolChoiceNone indicates no tools should be used.
	ToolChoiceNone ToolChoice = "none"
	// ToolChoiceAuto indicates tools should be used automatically.
	ToolChoiceAuto ToolChoice = "auto"
	// ToolChoiceRequired indicates tools are required.
	ToolChoiceRequired ToolChoice = "required"
)
203
// SpecificToolChoice creates a tool choice for a specific tool name.
//
// The name is converted to a ToolChoice as-is, with no validation. Because
// the predefined choices are the strings "none", "auto" and "required", a
// tool with one of those names yields a value equal to the corresponding
// constant.
func SpecificToolChoice(name string) ToolChoice {
	return ToolChoice(name)
}
208
// Call represents a call to a language model.
//
// The numeric settings and ToolChoice are pointers, so a nil pointer can be
// told apart from an explicit zero value.
// NOTE(review): presumably nil means "not set, use the provider default" —
// confirm against the provider implementations.
type Call struct {
	// Prompt is the prompt for the call.
	Prompt Prompt `json:"prompt"`
	// MaxOutputTokens is the optional max-output-tokens setting.
	MaxOutputTokens *int64 `json:"max_output_tokens"`
	// Temperature is the optional temperature setting.
	Temperature *float64 `json:"temperature"`
	// TopP is the optional top-p setting.
	TopP *float64 `json:"top_p"`
	// TopK is the optional top-k setting.
	TopK *int64 `json:"top_k"`
	// PresencePenalty is the optional presence penalty setting.
	PresencePenalty *float64 `json:"presence_penalty"`
	// FrequencyPenalty is the optional frequency penalty setting.
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// Tools lists the tools passed with the call.
	Tools []Tool `json:"tools"`
	// ToolChoice is the optional tool choice preference.
	ToolChoice *ToolChoice `json:"tool_choice"`

	// ProviderOptions holds provider specific options; the key is the
	// provider id.
	ProviderOptions ProviderOptions `json:"provider_options"`
}
224
// CallWarningType represents the type of call warning. It is a string type;
// the recognized values are the CallWarningType* constants below.
type CallWarningType string

const (
	// CallWarningTypeUnsupportedSetting indicates an unsupported setting.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool indicates an unsupported tool.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther indicates other warnings.
	CallWarningTypeOther CallWarningType = "other"
)
236
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
type CallWarning struct {
	// Type is the kind of warning.
	Type CallWarningType `json:"type"`
	// Setting is a setting name.
	// NOTE(review): presumably filled for CallWarningTypeUnsupportedSetting — confirm.
	Setting string `json:"setting"`
	// Tool is a tool value.
	// NOTE(review): presumably filled for CallWarningTypeUnsupportedTool — confirm.
	Tool Tool `json:"tool"`
	// Details is free-form detail text.
	Details string `json:"details"`
	// Message is the warning message text.
	Message string `json:"message"`
}
247
// LanguageModel represents a language model that can generate responses and stream responses.
type LanguageModel interface {
	// Generate takes a Call and returns a *Response or an error.
	Generate(context.Context, Call) (*Response, error)
	// Stream takes a Call and returns a StreamResponse (a sequence of
	// StreamPart values) or an error.
	Stream(context.Context, Call) (StreamResponse, error)

	// GenerateObject takes an ObjectCall and returns an *ObjectResponse or an
	// error.
	GenerateObject(context.Context, ObjectCall) (*ObjectResponse, error)
	// StreamObject takes an ObjectCall and returns an ObjectStreamResponse or
	// an error.
	StreamObject(context.Context, ObjectCall) (ObjectStreamResponse, error)

	// Provider returns a string naming the provider.
	// NOTE(review): presumably the provider id used as the key in
	// ProviderOptions and ProviderMetadata — confirm.
	Provider() string
	// Model returns a string naming the model.
	Model() string
}