1package fantasy
2
import (
	"context"
	"fmt"
	"iter"
	"strings"
)
8
// Usage holds the token counts reported for a model call. It is attached to
// both Response and StreamPart.
type Usage struct {
	// InputTokens is the number of input (prompt) tokens.
	InputTokens int64 `json:"input_tokens"`
	// OutputTokens is the number of generated tokens.
	OutputTokens int64 `json:"output_tokens"`
	// TotalTokens is the total token count.
	// NOTE(review): whether this is Input+Output or a provider-reported figure
	// is not visible here — confirm.
	TotalTokens int64 `json:"total_tokens"`
	// ReasoningTokens counts reasoning tokens.
	// NOTE(review): whether these are also included in OutputTokens is
	// presumably provider-dependent — confirm.
	ReasoningTokens int64 `json:"reasoning_tokens"`
	// CacheCreationTokens and CacheReadTokens presumably count prompt-cache
	// writes and reads, when the provider reports them (inferred from names).
	CacheCreationTokens int64 `json:"cache_creation_tokens"`
	CacheReadTokens     int64 `json:"cache_read_tokens"`
}
17
18func (u Usage) String() string {
19 return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
20 u.InputTokens,
21 u.OutputTokens,
22 u.TotalTokens,
23 u.ReasoningTokens,
24 u.CacheCreationTokens,
25 u.CacheReadTokens,
26 )
27}
28
29type ResponseContent []Content
30
31func (r ResponseContent) Text() string {
32 for _, c := range r {
33 if c.GetType() == ContentTypeText {
34 return c.(TextContent).Text
35 }
36 }
37 return ""
38}
39
40// Reasoning returns all reasoning content parts.
41func (r ResponseContent) Reasoning() []ReasoningContent {
42 var reasoning []ReasoningContent
43 for _, c := range r {
44 if c.GetType() == ContentTypeReasoning {
45 if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
46 reasoning = append(reasoning, reasoningContent)
47 }
48 }
49 }
50 return reasoning
51}
52
53// ReasoningText returns all reasoning content as a concatenated string.
54func (r ResponseContent) ReasoningText() string {
55 var text string
56 for _, reasoning := range r.Reasoning() {
57 text += reasoning.Text
58 }
59 return text
60}
61
62// Files returns all file content parts.
63func (r ResponseContent) Files() []FileContent {
64 var files []FileContent
65 for _, c := range r {
66 if c.GetType() == ContentTypeFile {
67 if fileContent, ok := AsContentType[FileContent](c); ok {
68 files = append(files, fileContent)
69 }
70 }
71 }
72 return files
73}
74
75// Sources returns all source content parts.
76func (r ResponseContent) Sources() []SourceContent {
77 var sources []SourceContent
78 for _, c := range r {
79 if c.GetType() == ContentTypeSource {
80 if sourceContent, ok := AsContentType[SourceContent](c); ok {
81 sources = append(sources, sourceContent)
82 }
83 }
84 }
85 return sources
86}
87
88// ToolCalls returns all tool call content parts.
89func (r ResponseContent) ToolCalls() []ToolCallContent {
90 var toolCalls []ToolCallContent
91 for _, c := range r {
92 if c.GetType() == ContentTypeToolCall {
93 if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
94 toolCalls = append(toolCalls, toolCallContent)
95 }
96 }
97 }
98 return toolCalls
99}
100
101// ToolResults returns all tool result content parts.
102func (r ResponseContent) ToolResults() []ToolResultContent {
103 var toolResults []ToolResultContent
104 for _, c := range r {
105 if c.GetType() == ContentTypeToolResult {
106 if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
107 toolResults = append(toolResults, toolResultContent)
108 }
109 }
110 }
111 return toolResults
112}
113
// Response is the complete result of a LanguageModel.Generate call.
type Response struct {
	// Content holds the response parts in order.
	Content ResponseContent `json:"content"`
	// FinishReason is the reported reason generation stopped.
	FinishReason FinishReason `json:"finish_reason"`
	// Usage holds the token counts for this call.
	Usage Usage `json:"usage"`
	// Warnings lists provider warnings for this call (e.g. unsupported
	// settings); the call still produced a result.
	Warnings []CallWarning `json:"warnings"`

	// ProviderMetadata holds provider-specific response metadata, keyed by
	// provider id.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
123
// StreamPartType identifies the kind of event carried by a StreamPart (its
// Type field).
type StreamPartType string

// Stream part kinds.
//
// NOTE(review): the grouping below is inferred from the names. Presumably a
// *_start part, zero or more *_delta parts and a *_end part bracket one text,
// reasoning or tool-input segment, correlated by StreamPart.ID — confirm
// against the provider implementations.
const (
	// Presumably carries StreamPart.Warnings.
	StreamPartTypeWarnings StreamPartType = "warnings"

	// Text segment events.
	StreamPartTypeTextStart StreamPartType = "text_start"
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	StreamPartTypeTextEnd   StreamPartType = "text_end"

	// Reasoning segment events.
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	StreamPartTypeReasoningEnd   StreamPartType = "reasoning_end"
	// Incremental tool-call argument events.
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	StreamPartTypeToolInputEnd   StreamPartType = "tool_input_end"
	// Tool call / tool result events.
	StreamPartTypeToolCall   StreamPartType = "tool_call"
	StreamPartTypeToolResult StreamPartType = "tool_result"
	// Source citation event (see the source fields on StreamPart).
	StreamPartTypeSource StreamPartType = "source"
	// Presumably the final part, carrying FinishReason and Usage.
	StreamPartTypeFinish StreamPartType = "finish"
	// Presumably carries StreamPart.Error.
	StreamPartTypeError StreamPartType = "error"
)
144
// StreamPart is a single event yielded by a StreamResponse. Type selects the
// kind of event.
//
// NOTE(review): presumably only the fields relevant to Type are populated;
// the field-to-type mapping noted below is inferred from names — confirm
// against the provider implementations.
type StreamPart struct {
	Type StreamPartType `json:"type"`
	// ID presumably correlates the start/delta/end parts of one segment.
	ID string `json:"id"`
	// Tool call name and (possibly partial) input, presumably for tool parts.
	ToolCallName  string `json:"tool_call_name"`
	ToolCallInput string `json:"tool_call_input"`
	// Delta presumably carries the incremental payload of *_delta parts.
	Delta string `json:"delta"`
	// ProviderExecuted presumably marks tool calls run by the provider
	// itself rather than by the caller — confirm.
	ProviderExecuted bool         `json:"provider_executed"`
	Usage            Usage        `json:"usage"`
	FinishReason     FinishReason `json:"finish_reason"`
	// Error presumably accompanies StreamPartTypeError parts.
	// NOTE(review): encoding/json encodes an interface by its dynamic value;
	// most error values have no exported fields and marshal as {}, and
	// unmarshaling into an error-typed field fails. Consider a string field
	// or custom (Un)MarshalJSON if StreamPart must round-trip through JSON.
	Error    error         `json:"error"`
	Warnings []CallWarning `json:"warnings"`

	// Source-related fields
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata holds provider-specific metadata; presumably keyed by
	// provider id as on Response — confirm.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}

// StreamResponse is the sequence returned by LanguageModel.Stream; ranging
// over it yields StreamParts one at a time.
type StreamResponse = iter.Seq[StreamPart]
165
// ToolChoice controls how the model may use the tools supplied in Call.Tools.
// SpecificToolChoice stores a tool name directly as a ToolChoice, so values
// other than the constants below presumably name one specific tool.
type ToolChoice string

// Predefined tool choices.
// NOTE(review): the meanings below are the conventional ones; confirm how
// each provider maps them.
const (
	// ToolChoiceNone: the model should not call any tool.
	ToolChoiceNone ToolChoice = "none"
	// ToolChoiceAuto: the model decides whether to call a tool.
	ToolChoiceAuto ToolChoice = "auto"
	// ToolChoiceRequired: the model must call at least one tool.
	ToolChoiceRequired ToolChoice = "required"
)
173
174func SpecificToolChoice(name string) ToolChoice {
175 return ToolChoice(name)
176}
177
// Call describes a single request to a LanguageModel; it is the argument to
// both Generate and Stream.
//
// The numeric settings and ToolChoice are pointers, so nil is distinguishable
// from an explicit zero value. NOTE(review): presumably nil means "not set"
// and the provider default applies — confirm per provider.
type Call struct {
	Prompt           Prompt   `json:"prompt"`
	MaxOutputTokens  *int64   `json:"max_output_tokens"`
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// Tools lists the tools the model may call.
	Tools      []Tool      `json:"tools"`
	ToolChoice *ToolChoice `json:"tool_choice"`

	// ProviderOptions holds provider-specific options, keyed by provider id.
	ProviderOptions ProviderOptions `json:"provider_options"`
}
192
// CallWarningType represents the type of call warning. Its string value is
// what is stored in CallWarning.Type.
type CallWarningType string

const (
	// CallWarningTypeUnsupportedSetting indicates an unsupported setting.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool indicates an unsupported tool.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther indicates any other kind of warning.
	CallWarningTypeOther CallWarningType = "other"
)
204
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
type CallWarning struct {
	// Type classifies the warning.
	Type CallWarningType `json:"type"`
	// Setting presumably names the offending setting when Type is
	// CallWarningTypeUnsupportedSetting.
	Setting string `json:"setting"`
	// Tool presumably is the offending tool when Type is
	// CallWarningTypeUnsupportedTool.
	Tool Tool `json:"tool"`
	// Details is an optional human-readable explanation.
	Details string `json:"details"`
	// Message is a free-form message, presumably used with CallWarningTypeOther.
	Message string `json:"message"`
}
215
// LanguageModel is a model that can answer a Call either in one shot
// (Generate) or incrementally (Stream).
type LanguageModel interface {
	// Generate performs the call and returns the complete Response.
	Generate(context.Context, Call) (*Response, error)
	// Stream performs the call and returns an iterator of StreamParts.
	// NOTE(review): whether failures after streaming starts surface as
	// StreamPartTypeError parts or only via the returned error is up to the
	// implementation — confirm.
	Stream(context.Context, Call) (StreamResponse, error)

	// Provider returns the provider identifier. NOTE(review): presumably the
	// same id used as the key in ProviderOptions and ProviderMetadata.
	Provider() string
	// Model returns the model identifier.
	Model() string
}