1package fantasy
  2
  3import (
  4	"context"
  5	"fmt"
  6	"iter"
  7)
  8
// Usage represents token usage statistics for a model call.
//
// Every counter is an int64 token count and is serialized under the
// snake_case JSON key shown in its tag. NOTE(review): how TotalTokens relates
// to the other counters, and whether reasoning and cache tokens are also
// included in InputTokens/OutputTokens, is not defined here and may differ
// between providers — confirm before summing fields.
type Usage struct {
	InputTokens         int64 `json:"input_tokens"`
	OutputTokens        int64 `json:"output_tokens"`
	TotalTokens         int64 `json:"total_tokens"`
	ReasoningTokens     int64 `json:"reasoning_tokens"`
	CacheCreationTokens int64 `json:"cache_creation_tokens"`
	CacheReadTokens     int64 `json:"cache_read_tokens"`
}
 18
 19func (u Usage) String() string {
 20	return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
 21		u.InputTokens,
 22		u.OutputTokens,
 23		u.TotalTokens,
 24		u.ReasoningTokens,
 25		u.CacheCreationTokens,
 26		u.CacheReadTokens,
 27	)
 28}
 29
// ResponseContent represents the content of a model response: the ordered
// list of content parts the model produced. Its accessor methods select parts
// of a particular content type, preserving their original order.
type ResponseContent []Content
 32
 33// Text returns the text content of the response.
 34func (r ResponseContent) Text() string {
 35	for _, c := range r {
 36		if c.GetType() == ContentTypeText {
 37			return c.(TextContent).Text
 38		}
 39	}
 40	return ""
 41}
 42
 43// Reasoning returns all reasoning content parts.
 44func (r ResponseContent) Reasoning() []ReasoningContent {
 45	var reasoning []ReasoningContent
 46	for _, c := range r {
 47		if c.GetType() == ContentTypeReasoning {
 48			if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
 49				reasoning = append(reasoning, reasoningContent)
 50			}
 51		}
 52	}
 53	return reasoning
 54}
 55
 56// ReasoningText returns all reasoning content as a concatenated string.
 57func (r ResponseContent) ReasoningText() string {
 58	var text string
 59	for _, reasoning := range r.Reasoning() {
 60		text += reasoning.Text
 61	}
 62	return text
 63}
 64
 65// Files returns all file content parts.
 66func (r ResponseContent) Files() []FileContent {
 67	var files []FileContent
 68	for _, c := range r {
 69		if c.GetType() == ContentTypeFile {
 70			if fileContent, ok := AsContentType[FileContent](c); ok {
 71				files = append(files, fileContent)
 72			}
 73		}
 74	}
 75	return files
 76}
 77
 78// Sources returns all source content parts.
 79func (r ResponseContent) Sources() []SourceContent {
 80	var sources []SourceContent
 81	for _, c := range r {
 82		if c.GetType() == ContentTypeSource {
 83			if sourceContent, ok := AsContentType[SourceContent](c); ok {
 84				sources = append(sources, sourceContent)
 85			}
 86		}
 87	}
 88	return sources
 89}
 90
 91// ToolCalls returns all tool call content parts.
 92func (r ResponseContent) ToolCalls() []ToolCallContent {
 93	var toolCalls []ToolCallContent
 94	for _, c := range r {
 95		if c.GetType() == ContentTypeToolCall {
 96			if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
 97				toolCalls = append(toolCalls, toolCallContent)
 98			}
 99		}
100	}
101	return toolCalls
102}
103
104// ToolResults returns all tool result content parts.
105func (r ResponseContent) ToolResults() []ToolResultContent {
106	var toolResults []ToolResultContent
107	for _, c := range r {
108		if c.GetType() == ContentTypeToolResult {
109			if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
110				toolResults = append(toolResults, toolResultContent)
111			}
112		}
113	}
114	return toolResults
115}
116
// Response represents a response from a language model, as returned by
// LanguageModel.Generate: the ordered content parts, the reason generation
// finished, token usage, and any warnings the provider raised for the call.
type Response struct {
	Content      ResponseContent `json:"content"`
	FinishReason FinishReason    `json:"finish_reason"`
	Usage        Usage           `json:"usage"`
	Warnings     []CallWarning   `json:"warnings"`

	// ProviderMetadata holds provider-specific response metadata, keyed by
	// provider id.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
127
// StreamPartType represents the type of a stream part: it tells a consumer
// which kind of event a StreamPart carries. The values are the string tags
// below (e.g. "text_delta"), which is also how StreamPart.Type appears in JSON.
type StreamPartType string

const (
	// StreamPartTypeWarnings represents warnings stream part type.
	StreamPartTypeWarnings StreamPartType = "warnings"
	// StreamPartTypeTextStart represents text start stream part type.
	StreamPartTypeTextStart StreamPartType = "text_start"
	// StreamPartTypeTextDelta represents text delta stream part type.
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	// StreamPartTypeTextEnd represents text end stream part type.
	StreamPartTypeTextEnd StreamPartType = "text_end"

	// StreamPartTypeReasoningStart represents reasoning start stream part type.
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	// StreamPartTypeReasoningDelta represents reasoning delta stream part type.
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	// StreamPartTypeReasoningEnd represents reasoning end stream part type.
	StreamPartTypeReasoningEnd StreamPartType = "reasoning_end"
	// StreamPartTypeToolInputStart represents tool input start stream part type.
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	// StreamPartTypeToolInputDelta represents tool input delta stream part type.
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	// StreamPartTypeToolInputEnd represents tool input end stream part type.
	StreamPartTypeToolInputEnd StreamPartType = "tool_input_end"
	// StreamPartTypeToolCall represents tool call stream part type.
	StreamPartTypeToolCall StreamPartType = "tool_call"
	// StreamPartTypeToolResult represents tool result stream part type.
	StreamPartTypeToolResult StreamPartType = "tool_result"
	// StreamPartTypeSource represents source stream part type.
	StreamPartTypeSource StreamPartType = "source"
	// StreamPartTypeFinish represents finish stream part type.
	StreamPartTypeFinish StreamPartType = "finish"
	// StreamPartTypeError represents error stream part type.
	StreamPartTypeError StreamPartType = "error"
)
164
// StreamPart represents a part of a streaming response: one event yielded by
// a StreamResponse. Type says which kind of event it is. NOTE(review):
// presumably only the fields relevant to that Type are populated (e.g. Delta
// for *_delta parts, Usage/FinishReason for finish parts) — confirm against
// the provider implementations.
//
// NOTE(review): Error is an interface-typed field. encoding/json encodes most
// error values as {} (they typically have no exported fields) and cannot
// decode a non-null JSON value back into an error, so this field does not
// round-trip through JSON.
type StreamPart struct {
	Type             StreamPartType `json:"type"`
	ID               string         `json:"id"`
	ToolCallName     string         `json:"tool_call_name"`
	ToolCallInput    string         `json:"tool_call_input"`
	Delta            string         `json:"delta"`
	ProviderExecuted bool           `json:"provider_executed"`
	Usage            Usage          `json:"usage"`
	FinishReason     FinishReason   `json:"finish_reason"`
	Error            error          `json:"error"`
	Warnings         []CallWarning  `json:"warnings"`

	// Source-related fields, presumably set on StreamPartTypeSource parts.
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
185
// StreamResponse represents a streaming response sequence. It is a type alias
// (not a new type) for iter.Seq[StreamPart], so it can be consumed directly
// with a range-over-func loop: for part := range stream { ... }.
type StreamResponse = iter.Seq[StreamPart]
188
// ToolChoice represents the tool choice preference for a model call. Besides
// the predefined ToolChoiceNone, ToolChoiceAuto and ToolChoiceRequired values,
// a ToolChoice may hold the name of a specific tool (see SpecificToolChoice).
type ToolChoice string

const (
	// ToolChoiceNone indicates no tools should be used.
	ToolChoiceNone ToolChoice = "none"
	// ToolChoiceAuto indicates tools should be used automatically.
	ToolChoiceAuto ToolChoice = "auto"
	// ToolChoiceRequired indicates tools are required.
	ToolChoiceRequired ToolChoice = "required"
)
200
201// SpecificToolChoice creates a tool choice for a specific tool name.
202func SpecificToolChoice(name string) ToolChoice {
203	return ToolChoice(name)
204}
205
// Call represents a call to a language model: the prompt, optional sampling
// parameters, the tools offered to the model, an optional tool choice, and
// provider-specific options.
//
// The pointer-typed fields are optional. NOTE(review): nil presumably means
// "not set; use the provider default" — confirm how each provider treats nil.
type Call struct {
	Prompt           Prompt      `json:"prompt"`
	MaxOutputTokens  *int64      `json:"max_output_tokens"`
	Temperature      *float64    `json:"temperature"`
	TopP             *float64    `json:"top_p"`
	TopK             *int64      `json:"top_k"`
	PresencePenalty  *float64    `json:"presence_penalty"`
	FrequencyPenalty *float64    `json:"frequency_penalty"`
	Tools            []Tool      `json:"tools"`
	ToolChoice       *ToolChoice `json:"tool_choice"`

	// ProviderOptions holds provider-specific options, keyed by provider id.
	ProviderOptions ProviderOptions `json:"provider_options"`
}
221
// CallWarningType represents the type of call warning. It classifies a
// CallWarning; its values are the kebab-case strings below.
type CallWarningType string

const (
	// CallWarningTypeUnsupportedSetting indicates an unsupported setting.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool indicates an unsupported tool.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther indicates other warnings.
	CallWarningTypeOther CallWarningType = "other"
)
234
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
//
// Type classifies the warning. NOTE(review): Setting and Tool presumably
// identify the offending setting or tool for the unsupported-setting and
// unsupported-tool types, with Details/Message carrying free-form text —
// confirm which fields each provider fills.
type CallWarning struct {
	Type    CallWarningType `json:"type"`
	Setting string          `json:"setting"`
	Tool    Tool            `json:"tool"`
	Details string          `json:"details"`
	Message string          `json:"message"`
}
245
// LanguageModel represents a language model that can generate responses and stream responses.
type LanguageModel interface {
	// Generate performs the call and returns the complete Response at once.
	Generate(context.Context, Call) (*Response, error)
	// Stream performs the call and returns its output incrementally as a
	// StreamResponse (an iter.Seq of StreamPart values).
	Stream(context.Context, Call) (StreamResponse, error)

	// Provider returns an identifier for the provider serving this model.
	// NOTE(review): presumably the same provider id used as the key in
	// ProviderOptions/ProviderMetadata — confirm.
	Provider() string
	// Model returns the identifier of the underlying model.
	Model() string
}