1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType identifies the kind of a ProviderEvent (see ProviderEvent.Type).
type EventType string
14
// maxRetries is a retry limit shared by the package.
// NOTE(review): it is not referenced in this part of the file; presumably the
// provider clients use it to bound request retries — confirm against callers.
const maxRetries = 8
16
17const (
18 EventContentStart EventType = "content_start"
19 EventToolUseStart EventType = "tool_use_start"
20 EventToolUseDelta EventType = "tool_use_delta"
21 EventToolUseStop EventType = "tool_use_stop"
22 EventContentDelta EventType = "content_delta"
23 EventThinkingDelta EventType = "thinking_delta"
24 EventContentStop EventType = "content_stop"
25 EventComplete EventType = "complete"
26 EventError EventType = "error"
27 EventWarning EventType = "warning"
28)
29
// TokenUsage groups the token counters attached to a ProviderResponse.
// NOTE(review): the per-field meanings below are inferred from the field
// names; confirm against the client implementations that fill them in.
type TokenUsage struct {
	// InputTokens is the input (prompt) token count.
	InputTokens int64
	// OutputTokens is the output (completion) token count.
	OutputTokens int64
	// CacheCreationTokens is the count of tokens attributed to cache creation.
	CacheCreationTokens int64
	// CacheReadTokens is the count of tokens attributed to cache reads.
	CacheReadTokens int64
}
36
37type ProviderResponse struct {
38 Content string
39 ToolCalls []message.ToolCall
40 Usage TokenUsage
41 FinishReason message.FinishReason
42}
43
44type ProviderEvent struct {
45 Type EventType
46
47 Content string
48 Thinking string
49 Response *ProviderResponse
50 ToolCall *message.ToolCall
51 Error error
52}
53type Provider interface {
54 SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
55
56 StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
57
58 Model() config.Model
59}
60
61type providerClientOptions struct {
62 baseURL string
63 apiKey string
64 modelType config.ModelType
65 model func(config.ModelType) config.Model
66 disableCache bool
67 systemMessage string
68 extraHeaders map[string]string
69 extraParams map[string]string
70}
71
72type ProviderClientOption func(*providerClientOptions)
73
74type ProviderClient interface {
75 send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
76 stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
77
78 Model() config.Model
79}
80
81type baseProvider[C ProviderClient] struct {
82 options providerClientOptions
83 client C
84}
85
86func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
87 for _, msg := range messages {
88 // The message has no content
89 if len(msg.Parts) == 0 {
90 continue
91 }
92 cleaned = append(cleaned, msg)
93 }
94 return
95}
96
97func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
98 messages = p.cleanMessages(messages)
99 return p.client.send(ctx, messages, tools)
100}
101
102func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
103 messages = p.cleanMessages(messages)
104 return p.client.stream(ctx, messages, tools)
105}
106
107func (p *baseProvider[C]) Model() config.Model {
108 return p.client.Model()
109}
110
111func WithModel(model config.ModelType) ProviderClientOption {
112 return func(options *providerClientOptions) {
113 options.modelType = model
114 }
115}
116
117func WithDisableCache(disableCache bool) ProviderClientOption {
118 return func(options *providerClientOptions) {
119 options.disableCache = disableCache
120 }
121}
122
123func WithSystemMessage(systemMessage string) ProviderClientOption {
124 return func(options *providerClientOptions) {
125 options.systemMessage = systemMessage
126 }
127}
128
129func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
130 clientOptions := providerClientOptions{
131 baseURL: cfg.BaseURL,
132 apiKey: cfg.APIKey,
133 extraHeaders: cfg.ExtraHeaders,
134 model: func(tp config.ModelType) config.Model {
135 return config.GetModel(tp)
136 },
137 }
138 for _, o := range opts {
139 o(&clientOptions)
140 }
141 switch cfg.ProviderType {
142 case provider.TypeAnthropic:
143 return &baseProvider[AnthropicClient]{
144 options: clientOptions,
145 client: newAnthropicClient(clientOptions, false),
146 }, nil
147 case provider.TypeOpenAI:
148 return &baseProvider[OpenAIClient]{
149 options: clientOptions,
150 client: newOpenAIClient(clientOptions),
151 }, nil
152 case provider.TypeGemini:
153 return &baseProvider[GeminiClient]{
154 options: clientOptions,
155 client: newGeminiClient(clientOptions),
156 }, nil
157 case provider.TypeBedrock:
158 return &baseProvider[BedrockClient]{
159 options: clientOptions,
160 client: newBedrockClient(clientOptions),
161 }, nil
162 case provider.TypeAzure:
163 return &baseProvider[AzureClient]{
164 options: clientOptions,
165 client: newAzureClient(clientOptions),
166 }, nil
167 case provider.TypeVertexAI:
168 return &baseProvider[VertexAIClient]{
169 options: clientOptions,
170 client: newVertexAIClient(clientOptions),
171 }, nil
172 case provider.TypeXAI:
173 clientOptions.baseURL = "https://api.x.ai/v1"
174 return &baseProvider[OpenAIClient]{
175 options: clientOptions,
176 client: newOpenAIClient(clientOptions),
177 }, nil
178 }
179 return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
180}