1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/kujtimiihoxha/opencode/internal/llm/models"
8 "github.com/kujtimiihoxha/opencode/internal/llm/tools"
9 "github.com/kujtimiihoxha/opencode/internal/message"
10)
11
// EventType identifies the kind of a ProviderEvent; it is stored in
// ProviderEvent.Type for events delivered on a StreamResponse channel.
type EventType string
13
// maxRetries is a package-level retry limit.
// NOTE(review): not referenced in this file; presumably consumed by the
// backend-specific clients when retrying failed requests — confirm.
const maxRetries = 8
15
// Event types carried in ProviderEvent.Type. The string values are part of
// the runtime behavior and must not be changed casually.
const (
	// EventContentStart presumably marks the beginning of response content.
	EventContentStart EventType = "content_start"
	// EventContentDelta presumably carries an incremental chunk of text in
	// ProviderEvent.Content.
	EventContentDelta EventType = "content_delta"
	// EventThinkingDelta presumably carries an incremental chunk of model
	// reasoning in ProviderEvent.Thinking.
	EventThinkingDelta EventType = "thinking_delta"
	// EventContentStop presumably marks the end of response content.
	EventContentStop EventType = "content_stop"
	// EventComplete presumably signals the end of the stream, with the final
	// result in ProviderEvent.Response — confirm against the clients.
	EventComplete EventType = "complete"
	// EventError presumably reports a failure in ProviderEvent.Error.
	EventError EventType = "error"
	// EventWarning presumably reports a non-fatal condition — confirm.
	EventWarning EventType = "warning"
)
25
// TokenUsage holds the token accounting attached to a ProviderResponse
// (see ProviderResponse.Usage).
type TokenUsage struct {
	InputTokens  int64
	OutputTokens int64
	// NOTE(review): the cache fields presumably count tokens written to and
	// read from a provider-side prompt cache; exact semantics depend on the
	// backend that fills them in — confirm.
	CacheCreationTokens int64
	CacheReadTokens     int64
}
32
// ProviderResponse is the result of a provider call. SendMessages returns it
// directly, and stream events may reference one via ProviderEvent.Response.
type ProviderResponse struct {
	// Content is the textual content of the response.
	Content string
	// ToolCalls lists any tool invocations present in the response.
	ToolCalls []message.ToolCall
	// Usage holds the token counts associated with this response.
	Usage TokenUsage
	// FinishReason records why the response ended.
	FinishReason message.FinishReason
}
39
// ProviderEvent is a single item delivered on the channel returned by
// StreamResponse. Type says what kind of event it is; the remaining fields
// are presumably populated only as relevant to that type (see the EventType
// constants).
type ProviderEvent struct {
	Type EventType

	Content  string
	Thinking string
	Response *ProviderResponse

	Error error
}
// Provider is the backend-agnostic interface for talking to an LLM provider.
// Values are created with NewProvider.
type Provider interface {
	// SendMessages sends the conversation and the available tools and returns
	// a single ProviderResponse.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the conversation and the available tools and
	// returns a channel of ProviderEvent values describing the response.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model this provider was configured with.
	Model() models.Model
}
56
// providerClientOptions collects the configuration produced by applying
// ProviderClientOption values in NewProvider. The same value is stored on the
// resulting baseProvider and passed to the backend-specific client
// constructor.
type providerClientOptions struct {
	apiKey        string
	model         models.Model
	maxTokens     int64
	systemMessage string

	// Backend-specific options; each is set by its With*Options function and
	// presumably only read by the matching client constructor.
	anthropicOptions []AnthropicOption
	openaiOptions    []OpenAIOption
	geminiOptions    []GeminiOption
	bedrockOptions   []BedrockOption
}
68
// ProviderClientOption is a functional option for NewProvider; each one
// mutates the providerClientOptions being assembled.
type ProviderClientOption func(*providerClientOptions)
70
// ProviderClient is implemented by each backend-specific client. Its methods
// are unexported, so only types in this package can satisfy it; baseProvider
// uses it as the constraint for its client type parameter.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}
75
// baseProvider implements Provider for any backend client C. It keeps the
// options it was built with and delegates requests to client after filtering
// out messages with no parts.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
80
81func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
82 clientOptions := providerClientOptions{}
83 for _, o := range opts {
84 o(&clientOptions)
85 }
86 switch providerName {
87 case models.ProviderAnthropic:
88 return &baseProvider[AnthropicClient]{
89 options: clientOptions,
90 client: newAnthropicClient(clientOptions),
91 }, nil
92 case models.ProviderOpenAI:
93 return &baseProvider[OpenAIClient]{
94 options: clientOptions,
95 client: newOpenAIClient(clientOptions),
96 }, nil
97 case models.ProviderGemini:
98 return &baseProvider[GeminiClient]{
99 options: clientOptions,
100 client: newGeminiClient(clientOptions),
101 }, nil
102 case models.ProviderBedrock:
103 return &baseProvider[BedrockClient]{
104 options: clientOptions,
105 client: newBedrockClient(clientOptions),
106 }, nil
107 case models.ProviderMock:
108 // TODO: implement mock client for test
109 panic("not implemented")
110 }
111 return nil, fmt.Errorf("provider not supported: %s", providerName)
112}
113
114func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
115 for _, msg := range messages {
116 // The message has no content
117 if len(msg.Parts) == 0 {
118 continue
119 }
120 cleaned = append(cleaned, msg)
121 }
122 return
123}
124
125func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
126 messages = p.cleanMessages(messages)
127 return p.client.send(ctx, messages, tools)
128}
129
130func (p *baseProvider[C]) Model() models.Model {
131 return p.options.model
132}
133
134func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
135 messages = p.cleanMessages(messages)
136 return p.client.stream(ctx, messages, tools)
137}
138
139func WithAPIKey(apiKey string) ProviderClientOption {
140 return func(options *providerClientOptions) {
141 options.apiKey = apiKey
142 }
143}
144
145func WithModel(model models.Model) ProviderClientOption {
146 return func(options *providerClientOptions) {
147 options.model = model
148 }
149}
150
151func WithMaxTokens(maxTokens int64) ProviderClientOption {
152 return func(options *providerClientOptions) {
153 options.maxTokens = maxTokens
154 }
155}
156
157func WithSystemMessage(systemMessage string) ProviderClientOption {
158 return func(options *providerClientOptions) {
159 options.systemMessage = systemMessage
160 }
161}
162
163func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
164 return func(options *providerClientOptions) {
165 options.anthropicOptions = anthropicOptions
166 }
167}
168
169func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
170 return func(options *providerClientOptions) {
171 options.openaiOptions = openaiOptions
172 }
173}
174
175func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
176 return func(options *providerClientOptions) {
177 options.geminiOptions = geminiOptions
178 }
179}
180
181func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
182 return func(options *providerClientOptions) {
183 options.bedrockOptions = bedrockOptions
184 }
185}