1package provider
2
3import (
4 "context"
5 "fmt"
6 "os"
7
8 "github.com/opencode-ai/opencode/internal/llm/models"
9 "github.com/opencode-ai/opencode/internal/llm/tools"
10 "github.com/opencode-ai/opencode/internal/message"
11)
12
// EventType identifies what a ProviderEvent carries; it is the Type field of
// every value sent on the StreamResponse channel.
type EventType string

// maxRetries is a package-wide retry limit. Nothing in this file reads it —
// presumably the per-backend clients use it in their retry loops; confirm.
const maxRetries = 8

// Event kinds carried by ProviderEvent.Type.
const (
	// Text and thinking content.
	EventContentStart  EventType = "content_start"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"

	// Tool calls.
	EventToolUseStart EventType = "tool_use_start"
	EventToolUseDelta EventType = "tool_use_delta"
	EventToolUseStop  EventType = "tool_use_stop"

	// Completion, errors and warnings.
	EventComplete EventType = "complete"
	EventError    EventType = "error"
	EventWarning  EventType = "warning"
)
29
// TokenUsage holds the token counts reported for one provider call; it is the
// Usage field of ProviderResponse.
type TokenUsage struct {
	// Prompt-side and completion-side token counts.
	InputTokens, OutputTokens int64
	// Prompt-cache counters (tokens written to / read from the cache) —
	// presumably zero for backends that do not report caching; confirm.
	CacheCreationTokens, CacheReadTokens int64
}
36
// ProviderResponse is the result of a provider call. SendMessages returns it
// directly, and streaming can attach one to an event via ProviderEvent.Response
// (presumably on EventComplete — confirm against the clients).
type ProviderResponse struct {
	// Content is the text produced by the model.
	Content string
	// ToolCalls lists the tool invocations the model requested, if any.
	ToolCalls []message.ToolCall
	// Usage reports token accounting for the call.
	Usage TokenUsage
	// FinishReason records why generation stopped, as reported by the backend
	// client.
	FinishReason message.FinishReason
}
43
44type ProviderEvent struct {
45 Type EventType
46
47 Content string
48 Thinking string
49 Response *ProviderResponse
50 ToolCall *message.ToolCall
51 Error error
52}
53type Provider interface {
54 SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
55
56 StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
57
58 Model() models.Model
59}
60
// providerClientOptions collects the settings produced by ProviderClientOption
// values. NewProvider hands the whole struct to the selected backend client
// constructor; each client presumably reads only the fields relevant to it —
// confirm per client.
type providerClientOptions struct {
	apiKey        string
	model         models.Model
	maxTokens     int64
	systemMessage string

	// Backend-specific option lists, each set by the matching With*Options
	// helper (a later call replaces an earlier one).
	anthropicOptions []AnthropicOption
	openaiOptions    []OpenAIOption
	geminiOptions    []GeminiOption
	bedrockOptions   []BedrockOption
	copilotOptions   []CopilotOption
}

// ProviderClientOption mutates a providerClientOptions. NewProvider applies
// options in the order given, so for a field set by several options the last
// one wins.
type ProviderClientOption func(*providerClientOptions)
75
76type ProviderClient interface {
77 send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
78 stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
79}
80
81type baseProvider[C ProviderClient] struct {
82 options providerClientOptions
83 client C
84}
85
86func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
87 clientOptions := providerClientOptions{}
88 for _, o := range opts {
89 o(&clientOptions)
90 }
91 switch providerName {
92 case models.ProviderCopilot:
93 return &baseProvider[CopilotClient]{
94 options: clientOptions,
95 client: newCopilotClient(clientOptions),
96 }, nil
97 case models.ProviderAnthropic:
98 return &baseProvider[AnthropicClient]{
99 options: clientOptions,
100 client: newAnthropicClient(clientOptions),
101 }, nil
102 case models.ProviderOpenAI:
103 return &baseProvider[OpenAIClient]{
104 options: clientOptions,
105 client: newOpenAIClient(clientOptions),
106 }, nil
107 case models.ProviderGemini:
108 return &baseProvider[GeminiClient]{
109 options: clientOptions,
110 client: newGeminiClient(clientOptions),
111 }, nil
112 case models.ProviderBedrock:
113 return &baseProvider[BedrockClient]{
114 options: clientOptions,
115 client: newBedrockClient(clientOptions),
116 }, nil
117 case models.ProviderGROQ:
118 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
119 WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
120 )
121 return &baseProvider[OpenAIClient]{
122 options: clientOptions,
123 client: newOpenAIClient(clientOptions),
124 }, nil
125 case models.ProviderAzure:
126 return &baseProvider[AzureClient]{
127 options: clientOptions,
128 client: newAzureClient(clientOptions),
129 }, nil
130 case models.ProviderVertexAI:
131 return &baseProvider[VertexAIClient]{
132 options: clientOptions,
133 client: newVertexAIClient(clientOptions),
134 }, nil
135 case models.ProviderOpenRouter:
136 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
137 WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
138 WithOpenAIExtraHeaders(map[string]string{
139 "HTTP-Referer": "opencode.ai",
140 "X-Title": "OpenCode",
141 }),
142 )
143 return &baseProvider[OpenAIClient]{
144 options: clientOptions,
145 client: newOpenAIClient(clientOptions),
146 }, nil
147 case models.ProviderXAI:
148 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
149 WithOpenAIBaseURL("https://api.x.ai/v1"),
150 )
151 return &baseProvider[OpenAIClient]{
152 options: clientOptions,
153 client: newOpenAIClient(clientOptions),
154 }, nil
155 case models.ProviderLocal:
156 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
157 WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
158 )
159 return &baseProvider[OpenAIClient]{
160 options: clientOptions,
161 client: newOpenAIClient(clientOptions),
162 }, nil
163 case models.ProviderMock:
164 // TODO: implement mock client for test
165 panic("not implemented")
166 }
167 return nil, fmt.Errorf("provider not supported: %s", providerName)
168}
169
170func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
171 for _, msg := range messages {
172 // The message has no content
173 if len(msg.Parts) == 0 {
174 continue
175 }
176 cleaned = append(cleaned, msg)
177 }
178 return
179}
180
181func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
182 messages = p.cleanMessages(messages)
183 return p.client.send(ctx, messages, tools)
184}
185
186func (p *baseProvider[C]) Model() models.Model {
187 return p.options.model
188}
189
190func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
191 messages = p.cleanMessages(messages)
192 return p.client.stream(ctx, messages, tools)
193}
194
195func WithAPIKey(apiKey string) ProviderClientOption {
196 return func(options *providerClientOptions) {
197 options.apiKey = apiKey
198 }
199}
200
201func WithModel(model models.Model) ProviderClientOption {
202 return func(options *providerClientOptions) {
203 options.model = model
204 }
205}
206
207func WithMaxTokens(maxTokens int64) ProviderClientOption {
208 return func(options *providerClientOptions) {
209 options.maxTokens = maxTokens
210 }
211}
212
213func WithSystemMessage(systemMessage string) ProviderClientOption {
214 return func(options *providerClientOptions) {
215 options.systemMessage = systemMessage
216 }
217}
218
219func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
220 return func(options *providerClientOptions) {
221 options.anthropicOptions = anthropicOptions
222 }
223}
224
225func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
226 return func(options *providerClientOptions) {
227 options.openaiOptions = openaiOptions
228 }
229}
230
231func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
232 return func(options *providerClientOptions) {
233 options.geminiOptions = geminiOptions
234 }
235}
236
237func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
238 return func(options *providerClientOptions) {
239 options.bedrockOptions = bedrockOptions
240 }
241}
242
243func WithCopilotOptions(copilotOptions ...CopilotOption) ProviderClientOption {
244 return func(options *providerClientOptions) {
245 options.copilotOptions = copilotOptions
246 }
247}