model.go

  1package ai
  2
  3import (
  4	"context"
  5	"fmt"
  6	"iter"
  7)
  8
  9type Usage struct {
 10	InputTokens         int64 `json:"input_tokens"`
 11	OutputTokens        int64 `json:"output_tokens"`
 12	TotalTokens         int64 `json:"total_tokens"`
 13	ReasoningTokens     int64 `json:"reasoning_tokens"`
 14	CacheCreationTokens int64 `json:"cache_creation_tokens"`
 15	CacheReadTokens     int64 `json:"cache_read_tokens"`
 16}
 17
 18func (u Usage) String() string {
 19	return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
 20		u.InputTokens,
 21		u.OutputTokens,
 22		u.TotalTokens,
 23		u.ReasoningTokens,
 24		u.CacheCreationTokens,
 25		u.CacheReadTokens,
 26	)
 27}
 28
 29type ResponseContent []Content
 30
 31func (r ResponseContent) Text() string {
 32	for _, c := range r {
 33		if c.GetType() == ContentTypeText {
 34			return c.(TextContent).Text
 35		}
 36	}
 37	return ""
 38}
 39
// Response is the complete result returned by LanguageModel.Generate.
type Response struct {
	// Content holds the content parts of the response.
	Content ResponseContent `json:"content"`
	// FinishReason is the finish reason reported for the call; its values are
	// defined by the FinishReason type (declared elsewhere).
	FinishReason FinishReason `json:"finish_reason"`
	// Usage holds the token counts for the call.
	Usage Usage `json:"usage"`
	// Warnings lists the provider warnings for this call; see CallWarning.
	Warnings []CallWarning `json:"warnings"`

	// ProviderMetadata holds provider-specific response metadata, keyed by
	// provider id.
	//
	// NOTE(review): StreamPart.ProviderMetadata is typed ProviderOptions rather
	// than this raw nested map; confirm whether the two should share a type.
	ProviderMetadata map[string]map[string]any `json:"provider_metadata"`
}
 49
// StreamPartType identifies what a StreamPart carries; it is the type of
// StreamPart.Type.
type StreamPartType string

// Known StreamPartType values.
//
// The Start/Delta/End triples (text, reasoning, tool input) presumably bracket
// one incrementally streamed segment. StreamPart.Delta would hold each chunk,
// and StreamPart.ID would tie the parts of a segment together. Source parts
// presumably use StreamPart's source-related fields (SourceType, URL, Title).
// The remaining kinds appear to pair with the like-named StreamPart fields:
// Warnings, ToolCallName/ToolCallInput, FinishReason/Usage, and Error.
// NOTE(review): these pairings are inferred from names — confirm against the
// provider implementations.
const (
	StreamPartTypeWarnings  StreamPartType = "warnings"
	StreamPartTypeTextStart StreamPartType = "text_start"
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	StreamPartTypeTextEnd   StreamPartType = "text_end"

	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	StreamPartTypeReasoningEnd   StreamPartType = "reasoning_end"
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	StreamPartTypeToolInputEnd   StreamPartType = "tool_input_end"
	StreamPartTypeToolCall       StreamPartType = "tool_call"
	StreamPartTypeToolResult     StreamPartType = "tool_result"
	StreamPartTypeSource         StreamPartType = "source"
	StreamPartTypeFinish         StreamPartType = "finish"
	StreamPartTypeError          StreamPartType = "error"
)
 70
// StreamPart is one event yielded by a StreamResponse. Type identifies the
// kind of event. The remaining fields are presumably populated only when
// relevant to that kind (NOTE(review): confirm against the providers).
type StreamPart struct {
	// Type identifies the kind of event (see the StreamPartType constants).
	Type StreamPartType `json:"type"`
	// ID presumably ties together the Start/Delta/End parts of one text,
	// reasoning, or tool-input segment, and identifies tool calls — confirm.
	ID string `json:"id"`
	// ToolCallName is presumably the tool name on tool-related parts.
	ToolCallName string `json:"tool_call_name"`
	// ToolCallInput is the tool input as a raw string (presumably JSON —
	// confirm).
	ToolCallInput string `json:"tool_call_input"`
	// Delta presumably holds the incremental chunk of a *Delta part.
	Delta string `json:"delta"`
	// ProviderExecuted presumably marks tool calls/results the provider
	// executed itself rather than the caller — confirm.
	ProviderExecuted bool `json:"provider_executed"`
	// Usage holds token counts, presumably set on the finish part.
	Usage Usage `json:"usage"`
	// FinishReason is presumably set on the finish part.
	FinishReason FinishReason `json:"finish_reason"`
	// Error is presumably set on error parts.
	//
	// NOTE(review): encoding/json encodes an interface by its dynamic value.
	// Errors with only unexported fields (e.g. from errors.New) therefore
	// marshal as {} and lose their message. Unmarshaling a JSON object into
	// this field fails, because error is a non-empty interface. Consider a
	// custom MarshalJSON/UnmarshalJSON or a message string field.
	Error error `json:"error"`
	// Warnings is presumably set on warnings parts; see CallWarning.
	Warnings []CallWarning `json:"warnings"`

	// Source-related fields (presumably used by StreamPartTypeSource parts).
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata holds provider-specific metadata for this part,
	// presumably keyed by provider id like Call.ProviderOptions — confirm.
	ProviderMetadata ProviderOptions `json:"provider_metadata"`
}
// StreamResponse is the sequence of StreamParts returned by
// LanguageModel.Stream. It is an iter.Seq, so callers consume it with
// `for part := range resp`; breaking out of the loop stops iteration.
type StreamResponse = iter.Seq[StreamPart]
 91
 92type ToolChoice string
 93
 94const (
 95	ToolChoiceNone ToolChoice = "none"
 96	ToolChoiceAuto ToolChoice = "auto"
 97)
 98
 99func SpecificToolChoice(name string) ToolChoice {
100	return ToolChoice(name)
101}
102
// Call holds the input and settings for one Generate or Stream request on a
// LanguageModel. The pointer-typed settings are optional: nil leaves a setting
// unspecified (presumably falling back to the provider's default — confirm).
type Call struct {
	// Prompt is the model input.
	Prompt Prompt `json:"prompt"`
	// MaxOutputTokens is the requested limit on generated tokens.
	MaxOutputTokens *int64 `json:"max_output_tokens"`
	// Temperature, TopP, TopK, PresencePenalty and FrequencyPenalty are
	// sampling settings. A provider that does not support one presumably
	// reports it via a CallWarningTypeUnsupportedSetting warning — confirm.
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// Tools lists the tools made available to the model.
	Tools []Tool `json:"tools"`
	// ToolChoice selects a predefined mode or a specific tool; nil presumably
	// leaves the choice to the provider default.
	ToolChoice *ToolChoice `json:"tool_choice"`
	// Headers holds extra request headers (presumably HTTP headers sent to
	// the provider — confirm).
	Headers map[string]string `json:"headers"`

	// for provider specific options, the key is the provider id
	ProviderOptions ProviderOptions `json:"provider_options"`
}
118
// CallWarningType classifies a CallWarning.
type CallWarningType string

const (
	// CallWarningTypeUnsupportedSetting indicates an unsupported setting;
	// CallWarning.Setting presumably names it.
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
	// CallWarningTypeUnsupportedTool indicates an unsupported tool;
	// CallWarning.Tool presumably holds it.
	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
	// CallWarningTypeOther indicates any other warning; CallWarning.Message
	// presumably describes it.
	CallWarningTypeOther CallWarningType = "other"
)
130
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
type CallWarning struct {
	// Type classifies the warning.
	Type CallWarningType `json:"type"`
	// Setting presumably names the offending setting for
	// CallWarningTypeUnsupportedSetting warnings.
	Setting string `json:"setting"`
	// Tool is presumably the offending tool for
	// CallWarningTypeUnsupportedTool warnings.
	Tool Tool `json:"tool"`
	// Details presumably adds optional free-form context.
	Details string `json:"details"`
	// Message presumably describes CallWarningTypeOther warnings.
	Message string `json:"message"`
}
141
// LanguageModel is a model that can be called either for a complete result
// (Generate) or for an incremental stream of parts (Stream).
type LanguageModel interface {
	// Generate returns the complete Response for the Call.
	Generate(context.Context, Call) (*Response, error)
	// Stream returns the Call's output as a StreamResponse of parts.
	// NOTE(review): confirm whether failures after streaming begins arrive as
	// StreamPartTypeError parts rather than through the returned error.
	Stream(context.Context, Call) (StreamResponse, error)

	// Provider returns the provider identifier — presumably the same id used
	// as the key in Call.ProviderOptions and Response.ProviderMetadata; confirm.
	Provider() string
	// Model returns the model identifier.
	Model() string
}