1package ai
2
3import (
4 "context"
5 "fmt"
6 "iter"
7)
8
// Usage holds the token counters for a single model call.
type Usage struct {
	// InputTokens is the number of prompt tokens.
	InputTokens int64 `json:"input_tokens"`
	// OutputTokens is the number of generated tokens.
	OutputTokens int64 `json:"output_tokens"`
	// TotalTokens is the overall token count. NOTE(review): whether this is
	// always InputTokens+OutputTokens depends on the provider — confirm.
	TotalTokens int64 `json:"total_tokens"`
	// ReasoningTokens counts tokens spent on reasoning output.
	ReasoningTokens int64 `json:"reasoning_tokens"`
	// CacheCreationTokens counts tokens written to a prompt cache.
	CacheCreationTokens int64 `json:"cache_creation_tokens"`
	// CacheReadTokens counts tokens served from a prompt cache.
	CacheReadTokens int64 `json:"cache_read_tokens"`
}

// String implements fmt.Stringer, rendering all six counters on one line
// in declaration order.
func (u Usage) String() string {
	const layout = "Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}"
	return fmt.Sprintf(layout,
		u.InputTokens, u.OutputTokens, u.TotalTokens,
		u.ReasoningTokens, u.CacheCreationTokens, u.CacheReadTokens)
}
28
29type ResponseContent []Content
30
31func (r ResponseContent) Text() string {
32 for _, c := range r {
33 if c.GetType() == ContentTypeText {
34 return c.(TextContent).Text
35 }
36 }
37 return ""
38}
39
// Response is the result returned by LanguageModel.Generate.
type Response struct {
	// Content holds the parts produced by the model, in order.
	Content ResponseContent `json:"content"`
	// FinishReason is the finish reason reported for this call.
	FinishReason FinishReason `json:"finish_reason"`
	// Usage holds the token counts for this call.
	Usage Usage `json:"usage"`
	// Warnings lists non-fatal warnings the provider raised for this call.
	Warnings []CallWarning `json:"warnings"`

	// for provider specific response metadata, the key is the provider id
	// NOTE(review): StreamPart.ProviderMetadata is typed ProviderOptions
	// while this field is map[string]map[string]any — confirm whether the
	// two are meant to be the same type.
	ProviderMetadata map[string]map[string]any `json:"provider_metadata"`
}
49
// StreamPartType identifies the kind of a StreamPart.
type StreamPartType string

// Warnings part (see StreamPart.Warnings).
const StreamPartTypeWarnings StreamPartType = "warnings"

// Text output framing: start, delta, end.
const (
	StreamPartTypeTextStart StreamPartType = "text_start"
	StreamPartTypeTextDelta StreamPartType = "text_delta"
	StreamPartTypeTextEnd   StreamPartType = "text_end"
)

// Reasoning output framing: start, delta, end.
const (
	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
	StreamPartTypeReasoningEnd   StreamPartType = "reasoning_end"
)

// Tool input framing: start, delta, end.
const (
	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
	StreamPartTypeToolInputEnd   StreamPartType = "tool_input_end"
)

// Tool call and tool result parts.
const (
	StreamPartTypeToolCall   StreamPartType = "tool_call"
	StreamPartTypeToolResult StreamPartType = "tool_result"
)

// Source part (see StreamPart.SourceType, URL and Title).
const StreamPartTypeSource StreamPartType = "source"

// Finish and error parts.
const (
	StreamPartTypeFinish StreamPartType = "finish"
	StreamPartTypeError  StreamPartType = "error"
)
70
// StreamPart is a single event yielded by a StreamResponse. Type says what
// kind of event it is. NOTE(review): presumably only the fields relevant to
// that Type are populated — confirm per provider.
type StreamPart struct {
	// Type is the kind of this part.
	Type StreamPartType `json:"type"`
	// ID identifies the item this part belongs to. Presumably it links the
	// start/delta/end parts of one text, reasoning or tool-input item —
	// confirm.
	ID string `json:"id"`
	// ToolCallName is the name of the tool, for tool-related parts.
	ToolCallName string `json:"tool_call_name"`
	// ToolCallInput is the tool input as a string (assumed to be raw
	// JSON — confirm).
	ToolCallInput string `json:"tool_call_input"`
	// Delta is the incremental payload for *_delta parts (assumed).
	Delta string `json:"delta"`
	// ProviderExecuted presumably marks tool calls executed by the provider
	// itself rather than by the caller — confirm.
	ProviderExecuted bool `json:"provider_executed"`
	// Usage holds token counts (presumably set on finish parts).
	Usage Usage `json:"usage"`
	// FinishReason is the finish reason (presumably set on finish parts).
	FinishReason FinishReason `json:"finish_reason"`
	// Error carries the failure for error parts.
	// NOTE(review): encoding/json marshals the dynamic error value. Common
	// error types (e.g. errors.New) have no exported fields, so they encode
	// as {} and the message is lost. Unmarshalling into an interface-typed
	// field also fails for non-null input. Consider a custom
	// (Un)MarshalJSON.
	Error error `json:"error"`
	// Warnings carries provider warnings for warnings parts.
	Warnings []CallWarning `json:"warnings"`

	// Source-related fields
	// (populated for StreamPartTypeSource parts, presumably).
	SourceType SourceType `json:"source_type"`
	URL        string     `json:"url"`
	Title      string     `json:"title"`

	// ProviderMetadata holds provider-specific metadata for this part.
	// NOTE(review): Response.ProviderMetadata uses map[string]map[string]any
	// while this field uses ProviderOptions — confirm whether they should
	// match.
	ProviderMetadata ProviderOptions `json:"provider_metadata"`
}

// StreamResponse is the iterator returned by LanguageModel.Stream. It is an
// iter.Seq, so callers consume it with `for part := range resp`.
type StreamResponse = iter.Seq[StreamPart]
91
// ToolChoice tells the model how it may use the tools in Call.Tools.
// Besides the two predefined modes below, a value built with
// SpecificToolChoice carries a tool name. NOTE(review): presumably providers
// treat such a value as "call this specific tool" — confirm.
type ToolChoice string

// ToolChoiceNone is the "none" tool-choice mode.
const ToolChoiceNone ToolChoice = "none"

// ToolChoiceAuto is the "auto" tool-choice mode.
const ToolChoiceAuto ToolChoice = "auto"

// SpecificToolChoice returns a ToolChoice whose value is the tool name,
// unchanged.
//
// Because the name is used verbatim, a tool literally named "none" or "auto"
// produces a value equal to ToolChoiceNone or ToolChoiceAuto.
func SpecificToolChoice(name string) ToolChoice {
	choice := ToolChoice(name)
	return choice
}
102
// Call is the input to LanguageModel.Generate and LanguageModel.Stream: the
// prompt plus optional generation settings.
//
// The pointer-typed settings are optional. NOTE(review): nil presumably
// means "not set; use the provider's default" — confirm against providers.
type Call struct {
	// Prompt is the prompt sent to the model.
	Prompt Prompt `json:"prompt"`
	// MaxOutputTokens caps the number of generated tokens (optional).
	MaxOutputTokens *int64 `json:"max_output_tokens"`
	// Temperature is the sampling temperature (optional).
	Temperature *float64 `json:"temperature"`
	// TopP is the nucleus-sampling probability mass (optional).
	TopP *float64 `json:"top_p"`
	// TopK limits sampling to the K most likely tokens (optional).
	TopK *int64 `json:"top_k"`
	// PresencePenalty is the presence penalty (optional).
	PresencePenalty *float64 `json:"presence_penalty"`
	// FrequencyPenalty is the frequency penalty (optional).
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// Tools lists the tools made available to the model.
	Tools []Tool `json:"tools"`
	// ToolChoice selects how tools may be used (optional).
	ToolChoice *ToolChoice `json:"tool_choice"`
	// Headers holds extra headers, presumably added to the provider's HTTP
	// request — confirm.
	Headers map[string]string `json:"headers"`

	// for provider specific options, the key is the provider id
	ProviderOptions ProviderOptions `json:"provider_options"`
}
118
// CallWarningType classifies a CallWarning.
type CallWarningType string

// Known CallWarningType values.
const (
	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting" // a requested setting is not supported
	CallWarningTypeUnsupportedTool    CallWarningType = "unsupported-tool"    // a requested tool is not supported
	CallWarningTypeOther              CallWarningType = "other"               // any other warning
)
130
// CallWarning represents a warning from the model provider for this call.
// The call will proceed, but e.g. some settings might not be supported,
// which can lead to suboptimal results.
type CallWarning struct {
	// Type classifies the warning.
	Type CallWarningType `json:"type"`
	// Setting names the offending setting (presumably set for
	// CallWarningTypeUnsupportedSetting — confirm).
	Setting string `json:"setting"`
	// Tool is the offending tool (presumably set for
	// CallWarningTypeUnsupportedTool — confirm).
	Tool Tool `json:"tool"`
	// Details gives additional detail about the warning.
	Details string `json:"details"`
	// Message is a human-readable description (presumably used with
	// CallWarningTypeOther — confirm).
	Message string `json:"message"`
}
141
142type LanguageModel interface {
143 Generate(context.Context, Call) (*Response, error)
144 Stream(context.Context, Call) (StreamResponse, error)
145
146 Provider() string
147 Model() string
148}