provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/opencode-ai/opencode/internal/llm/models"
  8	"github.com/opencode-ai/opencode/internal/llm/tools"
  9	"github.com/opencode-ai/opencode/internal/message"
 10)
 11
 12type EventType string
 13
 14const maxRetries = 8
 15
 16const (
 17	EventContentStart  EventType = "content_start"
 18	EventToolUseStart  EventType = "tool_use_start"
 19	EventToolUseDelta  EventType = "tool_use_delta"
 20	EventToolUseStop   EventType = "tool_use_stop"
 21	EventContentDelta  EventType = "content_delta"
 22	EventThinkingDelta EventType = "thinking_delta"
 23	EventContentStop   EventType = "content_stop"
 24	EventComplete      EventType = "complete"
 25	EventError         EventType = "error"
 26	EventWarning       EventType = "warning"
 27)
 28
 29type TokenUsage struct {
 30	InputTokens         int64
 31	OutputTokens        int64
 32	CacheCreationTokens int64
 33	CacheReadTokens     int64
 34}
 35
 36type ProviderResponse struct {
 37	Content      string
 38	ToolCalls    []message.ToolCall
 39	Usage        TokenUsage
 40	FinishReason message.FinishReason
 41}
 42
 43type ProviderEvent struct {
 44	Type EventType
 45
 46	Content  string
 47	Thinking string
 48	Response *ProviderResponse
 49	ToolCall *message.ToolCall
 50	Error    error
 51}
 52type Provider interface {
 53	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 54
 55	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 56
 57	Model() models.Model
 58}
 59
 60type providerClientOptions struct {
 61	apiKey        string
 62	model         models.Model
 63	maxTokens     int64
 64	systemMessage string
 65
 66	anthropicOptions []AnthropicOption
 67	openaiOptions    []OpenAIOption
 68	geminiOptions    []GeminiOption
 69	bedrockOptions   []BedrockOption
 70}
 71
 72type ProviderClientOption func(*providerClientOptions)
 73
 74type ProviderClient interface {
 75	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 76	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 77}
 78
 79type baseProvider[C ProviderClient] struct {
 80	options providerClientOptions
 81	client  C
 82}
 83
 84func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
 85	clientOptions := providerClientOptions{}
 86	for _, o := range opts {
 87		o(&clientOptions)
 88	}
 89	switch providerName {
 90	case models.ProviderAnthropic:
 91		return &baseProvider[AnthropicClient]{
 92			options: clientOptions,
 93			client:  newAnthropicClient(clientOptions),
 94		}, nil
 95	case models.ProviderOpenAI:
 96		return &baseProvider[OpenAIClient]{
 97			options: clientOptions,
 98			client:  newOpenAIClient(clientOptions),
 99		}, nil
100	case models.ProviderGemini:
101		return &baseProvider[GeminiClient]{
102			options: clientOptions,
103			client:  newGeminiClient(clientOptions),
104		}, nil
105	case models.ProviderBedrock:
106		return &baseProvider[BedrockClient]{
107			options: clientOptions,
108			client:  newBedrockClient(clientOptions),
109		}, nil
110	case models.ProviderGROQ:
111		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
112			WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
113		)
114		return &baseProvider[OpenAIClient]{
115			options: clientOptions,
116			client:  newOpenAIClient(clientOptions),
117		}, nil
118	case models.ProviderAzure:
119		return &baseProvider[AzureClient]{
120			options: clientOptions,
121			client:  newAzureClient(clientOptions),
122		}, nil
123	case models.ProviderOpenRouter:
124		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
125			WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
126			WithOpenAIExtraHeaders(map[string]string{
127				"HTTP-Referer": "opencode.ai",
128				"X-Title":      "OpenCode",
129			}),
130		)
131		return &baseProvider[OpenAIClient]{
132			options: clientOptions,
133			client:  newOpenAIClient(clientOptions),
134		}, nil
135	case models.ProviderXAI:
136		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
137			WithOpenAIBaseURL("https://api.x.ai/v1"),
138		)
139		return &baseProvider[OpenAIClient]{
140			options: clientOptions,
141			client:  newOpenAIClient(clientOptions),
142		}, nil
143
144	case models.ProviderMock:
145		// TODO: implement mock client for test
146		panic("not implemented")
147	}
148	return nil, fmt.Errorf("provider not supported: %s", providerName)
149}
150
151func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
152	for _, msg := range messages {
153		// The message has no content
154		if len(msg.Parts) == 0 {
155			continue
156		}
157		cleaned = append(cleaned, msg)
158	}
159	return
160}
161
162func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
163	messages = p.cleanMessages(messages)
164	return p.client.send(ctx, messages, tools)
165}
166
167func (p *baseProvider[C]) Model() models.Model {
168	return p.options.model
169}
170
171func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
172	messages = p.cleanMessages(messages)
173	return p.client.stream(ctx, messages, tools)
174}
175
176func WithAPIKey(apiKey string) ProviderClientOption {
177	return func(options *providerClientOptions) {
178		options.apiKey = apiKey
179	}
180}
181
182func WithModel(model models.Model) ProviderClientOption {
183	return func(options *providerClientOptions) {
184		options.model = model
185	}
186}
187
188func WithMaxTokens(maxTokens int64) ProviderClientOption {
189	return func(options *providerClientOptions) {
190		options.maxTokens = maxTokens
191	}
192}
193
194func WithSystemMessage(systemMessage string) ProviderClientOption {
195	return func(options *providerClientOptions) {
196		options.systemMessage = systemMessage
197	}
198}
199
200func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
201	return func(options *providerClientOptions) {
202		options.anthropicOptions = anthropicOptions
203	}
204}
205
206func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
207	return func(options *providerClientOptions) {
208		options.openaiOptions = openaiOptions
209	}
210}
211
212func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
213	return func(options *providerClientOptions) {
214		options.geminiOptions = geminiOptions
215	}
216}
217
218func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
219	return func(options *providerClientOptions) {
220		options.bedrockOptions = bedrockOptions
221	}
222}