model.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"fmt"
  6	"iter"
  7	"strings"
  8)
  9
 10// Usage represents token usage statistics for a model call.
 11type Usage struct {
 12	InputTokens         int64 `json:"input_tokens"`
 13	OutputTokens        int64 `json:"output_tokens"`
 14	TotalTokens         int64 `json:"total_tokens"`
 15	ReasoningTokens     int64 `json:"reasoning_tokens"`
 16	CacheCreationTokens int64 `json:"cache_creation_tokens"`
 17	CacheReadTokens     int64 `json:"cache_read_tokens"`
 18}
 19
 20func (u Usage) String() string {
 21	return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
 22		u.InputTokens,
 23		u.OutputTokens,
 24		u.TotalTokens,
 25		u.ReasoningTokens,
 26		u.CacheCreationTokens,
 27		u.CacheReadTokens,
 28	)
 29}
 30
// ResponseContent is the ordered list of content parts in a model response.
// Its helper methods select parts by their content type.
type ResponseContent []Content
 33
 34// Text returns the text content of the response.
 35func (r ResponseContent) Text() string {
 36	for _, c := range r {
 37		if c.GetType() == ContentTypeText {
 38			if textContent, ok := AsContentType[TextContent](c); ok {
 39				return textContent.Text
 40			}
 41		}
 42	}
 43	return ""
 44}
 45
 46// Reasoning returns all reasoning content parts.
 47func (r ResponseContent) Reasoning() []ReasoningContent {
 48	var reasoning []ReasoningContent
 49	for _, c := range r {
 50		if c.GetType() == ContentTypeReasoning {
 51			if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
 52				reasoning = append(reasoning, reasoningContent)
 53			}
 54		}
 55	}
 56	return reasoning
 57}
 58
 59// ReasoningText returns all reasoning content as a concatenated string.
 60func (r ResponseContent) ReasoningText() string {
 61	var builder strings.Builder
 62	for _, reasoning := range r.Reasoning() {
 63		builder.WriteString(reasoning.Text)
 64	}
 65	return builder.String()
 66}
 67
 68// Files returns all file content parts.
 69func (r ResponseContent) Files() []FileContent {
 70	var files []FileContent
 71	for _, c := range r {
 72		if c.GetType() == ContentTypeFile {
 73			if fileContent, ok := AsContentType[FileContent](c); ok {
 74				files = append(files, fileContent)
 75			}
 76		}
 77	}
 78	return files
 79}
 80
 81// Sources returns all source content parts.
 82func (r ResponseContent) Sources() []SourceContent {
 83	var sources []SourceContent
 84	for _, c := range r {
 85		if c.GetType() == ContentTypeSource {
 86			if sourceContent, ok := AsContentType[SourceContent](c); ok {
 87				sources = append(sources, sourceContent)
 88			}
 89		}
 90	}
 91	return sources
 92}
 93
 94// ToolCalls returns all tool call content parts.
 95func (r ResponseContent) ToolCalls() []ToolCallContent {
 96	var toolCalls []ToolCallContent
 97	for _, c := range r {
 98		if c.GetType() == ContentTypeToolCall {
 99			if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
100				toolCalls = append(toolCalls, toolCallContent)
101			}
102		}
103	}
104	return toolCalls
105}
106
107// ToolResults returns all tool result content parts.
108func (r ResponseContent) ToolResults() []ToolResultContent {
109	var toolResults []ToolResultContent
110	for _, c := range r {
111		if c.GetType() == ContentTypeToolResult {
112			if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
113				toolResults = append(toolResults, toolResultContent)
114			}
115		}
116	}
117	return toolResults
118}
119
// Response is the complete result of a non-streaming language model call.
type Response struct {
	// Content holds the response parts in the order they were produced.
	Content      ResponseContent `json:"content"`
	FinishReason FinishReason    `json:"finish_reason"`
	Usage        Usage           `json:"usage"`
	// Warnings lists non-fatal issues the provider reported for this call.
	Warnings     []CallWarning   `json:"warnings"`

	// ProviderMetadata holds provider-specific response metadata, keyed by
	// provider id.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
130
// StreamPartType labels the kind of event carried by a StreamPart.
type StreamPartType string

// Stream part types. The naming suggests that text, reasoning and tool input
// are each streamed as a start/delta/end sequence, while the remaining types
// are standalone events.
const (
	// StreamPartTypeWarnings labels a part that carries call warnings.
	StreamPartTypeWarnings StreamPartType = "warnings"

	// StreamPartTypeTextStart labels the opening of a text segment.
	StreamPartTypeTextStart StreamPartType = "text_start"
	// StreamPartTypeTextDelta labels an incremental chunk of text.
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	// StreamPartTypeTextEnd labels the closing of a text segment.
	StreamPartTypeTextEnd StreamPartType = "text_end"

	// StreamPartTypeReasoningStart labels the opening of a reasoning segment.
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	// StreamPartTypeReasoningDelta labels an incremental chunk of reasoning.
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	// StreamPartTypeReasoningEnd labels the closing of a reasoning segment.
	StreamPartTypeReasoningEnd StreamPartType = "reasoning_end"

	// StreamPartTypeToolInputStart labels the opening of a tool-input segment.
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	// StreamPartTypeToolInputDelta labels an incremental chunk of tool input.
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	// StreamPartTypeToolInputEnd labels the closing of a tool-input segment.
	StreamPartTypeToolInputEnd StreamPartType = "tool_input_end"

	// StreamPartTypeToolCall labels a tool call part.
	StreamPartTypeToolCall StreamPartType = "tool_call"
	// StreamPartTypeToolResult labels a tool result part.
	StreamPartTypeToolResult StreamPartType = "tool_result"
	// StreamPartTypeSource labels a source (citation) part.
	StreamPartTypeSource StreamPartType = "source"
	// StreamPartTypeFinish labels the finish part of a stream.
	StreamPartTypeFinish StreamPartType = "finish"
	// StreamPartTypeError labels a part that carries an error.
	StreamPartTypeError StreamPartType = "error"
)
167
// StreamPart is a single event in a streaming response. Type selects which
// of the remaining fields are meaningful for a given part.
// NOTE(review): which fields are populated for each StreamPartType is decided
// by the providers that emit parts; confirm against them.
type StreamPart struct {
	Type             StreamPartType `json:"type"`
	ID               string         `json:"id"`
	ToolCallName     string         `json:"tool_call_name"`
	ToolCallInput    string         `json:"tool_call_input"`
	Delta            string         `json:"delta"`
	ProviderExecuted bool           `json:"provider_executed"`
	Usage            Usage          `json:"usage"`
	FinishReason     FinishReason   `json:"finish_reason"`
	// NOTE(review): Error is an interface; encoding/json marshals most error
	// values (e.g. those from errors.New) as "{}" and cannot unmarshal back into
	// an error, so this field does not round-trip through JSON.
	Error            error          `json:"error"`
	Warnings         []CallWarning  `json:"warnings"`

	// Source-related fields
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata holds provider-specific metadata for this part;
	// presumably keyed by provider id as in Response — confirm.
	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
188
// StreamResponse is a range-over-func iterator (iter.Seq) that yields the
// StreamParts of a streaming response. Because it is a type alias, any
// iter.Seq[StreamPart] can be used directly.
type StreamResponse = iter.Seq[StreamPart]
191
// ToolChoice expresses how a model call may use the tools it was given.
// Besides the predefined modes below, SpecificToolChoice reuses this type to
// carry a tool name. Presumably any other value therefore names a single tool.
type ToolChoice string

// Predefined tool choice modes.
const (
	// ToolChoiceNone asks the model not to use any tool.
	ToolChoiceNone ToolChoice = "none"
	// ToolChoiceAuto lets the model decide whether to use tools.
	ToolChoiceAuto ToolChoice = "auto"
	// ToolChoiceRequired requires the model to use a tool.
	ToolChoiceRequired ToolChoice = "required"
)

// SpecificToolChoice returns a ToolChoice that holds the tool name verbatim.
// NOTE(review): a tool literally named "none", "auto" or "required" is
// indistinguishable from the corresponding predefined mode.
func SpecificToolChoice(name string) ToolChoice {
	choice := ToolChoice(name)
	return choice
}
208
// Call describes a single request to a language model: the prompt, sampling
// settings, available tools and provider-specific options.
// NOTE(review): the pointer-typed settings presumably use nil to mean "not
// set, use the provider default" — confirm with the provider implementations.
type Call struct {
	Prompt           Prompt      `json:"prompt"`
	MaxOutputTokens  *int64      `json:"max_output_tokens"`
	Temperature      *float64    `json:"temperature"`
	TopP             *float64    `json:"top_p"`
	TopK             *int64      `json:"top_k"`
	PresencePenalty  *float64    `json:"presence_penalty"`
	FrequencyPenalty *float64    `json:"frequency_penalty"`
	Tools            []Tool      `json:"tools"`
	ToolChoice       *ToolChoice `json:"tool_choice"`

	// ProviderOptions holds provider-specific options, keyed by provider id.
	ProviderOptions ProviderOptions `json:"provider_options"`
}
224
// CallWarningType classifies a CallWarning.
type CallWarningType string

// Known call warning kinds. These values are hyphenated, unlike the
// snake_case StreamPartType values.
const (
	// CallWarningTypeUnsupportedSetting flags a call setting the provider does not support.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool flags a tool the provider does not support.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther covers any warning without a more specific kind.
	CallWarningTypeOther CallWarningType = "other"
)
236
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
// NOTE(review): Setting and Tool presumably carry the offending setting name
// or tool for the unsupported-setting / unsupported-tool types respectively;
// nothing here enforces that pairing.
type CallWarning struct {
	Type    CallWarningType `json:"type"`
	Setting string          `json:"setting"`
	Tool    Tool            `json:"tool"`
	Details string          `json:"details"`
	Message string          `json:"message"`
}
247
248// LanguageModel represents a language model that can generate responses and stream responses.
249type LanguageModel interface {
250	Generate(context.Context, Call) (*Response, error)
251	Stream(context.Context, Call) (StreamResponse, error)
252
253	GenerateObject(context.Context, ObjectCall) (*ObjectResponse, error)
254	StreamObject(context.Context, ObjectCall) (ObjectStreamResponse, error)
255
256	Provider() string
257	Model() string
258}