provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net/http"
  7	"time"
  8
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12	"github.com/charmbracelet/crush/internal/resolver"
 13)
 14
// EventType identifies what a ProviderEvent carries. Its values are the
// string constants declared below.
type EventType string

// maxRetries is a package-level retry limit.
// NOTE(review): it is not referenced in this file; presumably the concrete
// providers use it to cap request retries — confirm.
const maxRetries = 8

// Event types reported in ProviderEvent.Type on the channel returned by
// Provider.Stream. Judging by their names they cover text content
// (start/delta/stop), tool calls (start/delta/stop), reasoning output
// (thinking and signature deltas), completion, errors and warnings.
// NOTE(review): which ProviderEvent fields accompany each type is decided by
// the concrete providers and is not visible here.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
 32
// TokenUsage records the token counts for a provider request: prompt input,
// generated output, and tokens written to / read from a prompt cache.
// NOTE(review): the exact meaning of the cache counters follows each
// vendor's API — confirm per provider.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}

// ProviderResponse is the result of a provider call. It holds the generated
// text, the tool calls the model asked for, token usage and the reason
// generation finished. Provider.Send returns it, and ProviderEvent.Response
// carries it in streamed events.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}

// ProviderEvent is a single item delivered on the channel returned by
// Provider.Stream. Type says what kind of event it is; the other fields hold
// the payload.
type ProviderEvent struct {
	Type EventType

	// Payload fields; presumably only those relevant to Type are set.
	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
 57
// Config describes a single LLM provider: how to reach it, how to
// authenticate, and which models it offers. It is the input to NewProvider
// and to Config.TestConnection.
type Config struct {
	// The provider's id; it is included in error messages.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint. The value is passed through a resolver
	// before use. When it is empty, TestConnection uses the vendor's default
	// endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai" or "anthropic". An empty value is meant
	// to default to openai.
	Type catwalk.Type `json:"type,omitempty"`
	// The provider's API key. The value is passed through a resolver before
	// use, presumably so it can be a reference rather than the literal key —
	// see the resolver package.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Custom system prompt prefix.
	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`

	// Extra headers to send with each request to the provider. newBaseProvider
	// passes each header value through the resolver.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra fields for request bodies; newBaseProvider passes them along
	// without resolving them.
	// NOTE(review): presumably merged into each request payload by the
	// concrete provider — confirm.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider. It is excluded from JSON
	// (tagged "-"), so it cannot be set from a config file.
	ExtraParams map[string]string `json:"-"`

	// The models this provider offers; baseProvider.Model looks them up by ID.
	Models []catwalk.Model `json:"models,omitempty"`
}
 86
// Provider is the interface implemented by every LLM backend that
// NewProvider can return.
type Provider interface {
	// Send issues a request for model with the given conversation and
	// available tools, and returns the complete response.
	Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// Stream issues the same kind of request as Send. It delivers the result
	// incrementally as ProviderEvents on the returned channel.
	// NOTE(review): when the channel is closed, and how errors are signalled
	// (EventError vs. ProviderEvent.Error), is up to the implementation —
	// confirm.
	Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the configured model with the given ID. baseProvider's
	// implementation returns nil when there is no such model.
	Model(modelID string) *catwalk.Model

	// SetDebug enables or disables debug mode.
	SetDebug(debug bool)
}
 96
// baseProvider holds the state that NewProvider hands to every concrete
// provider constructor. It contains:
//   - the provider Config;
//   - the API key, base URL and extra headers, resolved by newBaseProvider;
//   - request settings set through Options (cache, system message,
//     max tokens, thinking, reasoning effort, debug);
//   - the resolver used to expand configuration values.
type baseProvider struct {
	baseURL            string
	debug              bool
	config             Config
	apiKey             string
	disableCache       bool
	systemMessage      string
	systemPromptPrefix string
	maxTokens          int64
	think              bool
	reasoningEffort    string
	resolver           resolver.Resolver
	extraHeaders       map[string]string
	extraBody          map[string]any
	extraParams        map[string]string
}

// Option configures a baseProvider. newBaseProvider applies options before
// it resolves any configuration values.
type Option func(*baseProvider)
115
116func WithDisableCache(disableCache bool) Option {
117	return func(options *baseProvider) {
118		options.disableCache = disableCache
119	}
120}
121
122func WithSystemMessage(systemMessage string) Option {
123	return func(options *baseProvider) {
124		options.systemMessage = systemMessage
125	}
126}
127
128func WithMaxTokens(maxTokens int64) Option {
129	return func(options *baseProvider) {
130		options.maxTokens = maxTokens
131	}
132}
133
134func WithThinking(think bool) Option {
135	return func(options *baseProvider) {
136		options.think = think
137	}
138}
139
140func WithReasoningEffort(reasoningEffort string) Option {
141	return func(options *baseProvider) {
142		options.reasoningEffort = reasoningEffort
143	}
144}
145
146func WithDebug(debug bool) Option {
147	return func(options *baseProvider) {
148		options.debug = debug
149	}
150}
151
152func WithResolver(resolver resolver.Resolver) Option {
153	return func(options *baseProvider) {
154		options.resolver = resolver
155	}
156}
157
158func newBaseProvider(cfg Config, opts ...Option) (*baseProvider, error) {
159	provider := &baseProvider{
160		baseURL:            cfg.BaseURL,
161		config:             cfg,
162		apiKey:             cfg.APIKey,
163		extraHeaders:       cfg.ExtraHeaders,
164		extraBody:          cfg.ExtraBody,
165		systemPromptPrefix: cfg.SystemPromptPrefix,
166		resolver:           resolver.New(),
167	}
168	for _, o := range opts {
169		o(provider)
170	}
171
172	resolvedAPIKey, err := provider.resolver.ResolveValue(cfg.APIKey)
173	if err != nil {
174		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
175	}
176
177	resolvedBaseURL, err := provider.resolver.ResolveValue(cfg.BaseURL)
178	if err != nil {
179		resolvedBaseURL = ""
180	}
181	// Resolve extra headers
182	resolvedExtraHeaders := make(map[string]string)
183	for key, value := range cfg.ExtraHeaders {
184		resolvedValue, err := provider.resolver.ResolveValue(value)
185		if err != nil {
186			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
187		}
188		resolvedExtraHeaders[key] = resolvedValue
189	}
190
191	provider.apiKey = resolvedAPIKey
192	provider.baseURL = resolvedBaseURL
193	provider.extraHeaders = resolvedExtraHeaders
194	return provider, nil
195}
196
197func NewProvider(cfg Config, opts ...Option) (Provider, error) {
198	base, err := newBaseProvider(cfg, opts...)
199	if err != nil {
200		return nil, err
201	}
202	switch cfg.Type {
203	case catwalk.TypeAnthropic:
204		return NewAnthropicProvider(base, false), nil
205	case catwalk.TypeOpenAI:
206		return NewOpenAIProvider(base), nil
207	case catwalk.TypeGemini:
208		return NewGeminiProvider(base), nil
209	case catwalk.TypeBedrock:
210		return NewBedrockProvider(base), nil
211	case catwalk.TypeAzure:
212		return NewAzureProvider(base), nil
213	case catwalk.TypeVertexAI:
214		return NewVertexAIProvider(base), nil
215	}
216	return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
217}
218
219func (p *baseProvider) cleanMessages(messages []message.Message) (cleaned []message.Message) {
220	for _, msg := range messages {
221		// The message has no content
222		if len(msg.Parts) == 0 {
223			continue
224		}
225		cleaned = append(cleaned, msg)
226	}
227	return
228}
229
230func (o *baseProvider) Model(model string) *catwalk.Model {
231	for _, m := range o.config.Models {
232		if m.ID == model {
233			return &m
234		}
235	}
236	return nil
237}
238
239func (o *baseProvider) SetDebug(debug bool) {
240	o.debug = debug
241}
242
243func (c *Config) TestConnection(resolver resolver.Resolver) error {
244	testURL := ""
245	headers := make(map[string]string)
246	apiKey, _ := resolver.ResolveValue(c.APIKey)
247	switch c.Type {
248	case catwalk.TypeOpenAI:
249		baseURL, _ := resolver.ResolveValue(c.BaseURL)
250		if baseURL == "" {
251			baseURL = "https://api.openai.com/v1"
252		}
253		testURL = baseURL + "/models"
254		headers["Authorization"] = "Bearer " + apiKey
255	case catwalk.TypeAnthropic:
256		baseURL, _ := resolver.ResolveValue(c.BaseURL)
257		if baseURL == "" {
258			baseURL = "https://api.anthropic.com/v1"
259		}
260		testURL = baseURL + "/models"
261		headers["x-api-key"] = apiKey
262		headers["anthropic-version"] = "2023-06-01"
263	}
264	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
265	defer cancel()
266	client := &http.Client{}
267	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
268	if err != nil {
269		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
270	}
271	for k, v := range headers {
272		req.Header.Set(k, v)
273	}
274	for k, v := range c.ExtraHeaders {
275		req.Header.Set(k, v)
276	}
277	b, err := client.Do(req)
278	if err != nil {
279		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
280	}
281	if b.StatusCode != http.StatusOK {
282		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
283	}
284	_ = b.Body.Close()
285	return nil
286}