1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType identifies the kind of a ProviderEvent (see ProviderEvent.Type),
// the value type delivered on the channel returned by StreamResponse.
type EventType string

// maxRetries is a retry limit constant.
// NOTE(review): not referenced in this file; presumably consumed by the
// per-provider client implementations — confirm.
const maxRetries = 8

// Event kinds, grouped by the part of the response they describe.
const (
	// Plain content lifecycle.
	EventContentStart EventType = "content_start"
	EventContentDelta EventType = "content_delta"
	EventContentStop  EventType = "content_stop"

	// Tool-use lifecycle.
	EventToolUseStart EventType = "tool_use_start"
	EventToolUseDelta EventType = "tool_use_delta"
	EventToolUseStop  EventType = "tool_use_stop"

	// Thinking deltas.
	EventThinkingDelta EventType = "thinking_delta"

	// Completion, error, and warning events.
	EventComplete EventType = "complete"
	EventError    EventType = "error"
	EventWarning  EventType = "warning"
)
29
// TokenUsage carries the token counts attached to a ProviderResponse.
type TokenUsage struct {
	// Input and output token counts.
	InputTokens, OutputTokens int64
	// Counts attributed to cache creation and cache reads, as filled in by
	// the provider clients (semantics per provider — confirm per client).
	CacheCreationTokens, CacheReadTokens int64
}
36
37type ProviderResponse struct {
38 Content string
39 ToolCalls []message.ToolCall
40 Usage TokenUsage
41 FinishReason message.FinishReason
42}
43
44type ProviderEvent struct {
45 Type EventType
46
47 Content string
48 Thinking string
49 Response *ProviderResponse
50 ToolCall *message.ToolCall
51 Error error
52}
53type Provider interface {
54 SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
55
56 StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
57
58 Model() config.Model
59}
60
61type providerClientOptions struct {
62 baseURL string
63 apiKey string
64 modelType config.ModelType
65 model func(config.ModelType) config.Model
66 disableCache bool
67 systemMessage string
68 maxTokens int64
69 extraHeaders map[string]string
70 extraParams map[string]string
71}
72
73type ProviderClientOption func(*providerClientOptions)
74
75type ProviderClient interface {
76 send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
77 stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
78
79 Model() config.Model
80}
81
82type baseProvider[C ProviderClient] struct {
83 options providerClientOptions
84 client C
85}
86
87func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
88 for _, msg := range messages {
89 // The message has no content
90 if len(msg.Parts) == 0 {
91 continue
92 }
93 cleaned = append(cleaned, msg)
94 }
95 return
96}
97
98func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
99 messages = p.cleanMessages(messages)
100 return p.client.send(ctx, messages, tools)
101}
102
103func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
104 messages = p.cleanMessages(messages)
105 return p.client.stream(ctx, messages, tools)
106}
107
108func (p *baseProvider[C]) Model() config.Model {
109 return p.client.Model()
110}
111
112func WithModel(model config.ModelType) ProviderClientOption {
113 return func(options *providerClientOptions) {
114 options.modelType = model
115 }
116}
117
118func WithDisableCache(disableCache bool) ProviderClientOption {
119 return func(options *providerClientOptions) {
120 options.disableCache = disableCache
121 }
122}
123
124func WithSystemMessage(systemMessage string) ProviderClientOption {
125 return func(options *providerClientOptions) {
126 options.systemMessage = systemMessage
127 }
128}
129
130func WithMaxTokens(maxTokens int64) ProviderClientOption {
131 return func(options *providerClientOptions) {
132 options.maxTokens = maxTokens
133 }
134}
135
136func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
137 clientOptions := providerClientOptions{
138 baseURL: cfg.BaseURL,
139 apiKey: cfg.APIKey,
140 extraHeaders: cfg.ExtraHeaders,
141 model: func(tp config.ModelType) config.Model {
142 return config.GetModel(tp)
143 },
144 }
145 for _, o := range opts {
146 o(&clientOptions)
147 }
148 switch cfg.ProviderType {
149 case provider.TypeAnthropic:
150 return &baseProvider[AnthropicClient]{
151 options: clientOptions,
152 client: newAnthropicClient(clientOptions, false),
153 }, nil
154 case provider.TypeOpenAI:
155 return &baseProvider[OpenAIClient]{
156 options: clientOptions,
157 client: newOpenAIClient(clientOptions),
158 }, nil
159 case provider.TypeGemini:
160 return &baseProvider[GeminiClient]{
161 options: clientOptions,
162 client: newGeminiClient(clientOptions),
163 }, nil
164 case provider.TypeBedrock:
165 return &baseProvider[BedrockClient]{
166 options: clientOptions,
167 client: newBedrockClient(clientOptions),
168 }, nil
169 case provider.TypeAzure:
170 return &baseProvider[AzureClient]{
171 options: clientOptions,
172 client: newAzureClient(clientOptions),
173 }, nil
174 case provider.TypeVertexAI:
175 return &baseProvider[VertexAIClient]{
176 options: clientOptions,
177 client: newVertexAIClient(clientOptions),
178 }, nil
179 case provider.TypeXAI:
180 clientOptions.baseURL = "https://api.x.ai/v1"
181 return &baseProvider[OpenAIClient]{
182 options: clientOptions,
183 client: newOpenAIClient(clientOptions),
184 }, nil
185 }
186 return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
187}