provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/kujtimiihoxha/opencode/internal/llm/models"
  8	"github.com/kujtimiihoxha/opencode/internal/llm/tools"
  9	"github.com/kujtimiihoxha/opencode/internal/message"
 10)
 11
 12type EventType string
 13
 14const maxRetries = 8
 15
 16const (
 17	EventContentStart  EventType = "content_start"
 18	EventContentDelta  EventType = "content_delta"
 19	EventThinkingDelta EventType = "thinking_delta"
 20	EventContentStop   EventType = "content_stop"
 21	EventComplete      EventType = "complete"
 22	EventError         EventType = "error"
 23	EventWarning       EventType = "warning"
 24)
 25
 26type TokenUsage struct {
 27	InputTokens         int64
 28	OutputTokens        int64
 29	CacheCreationTokens int64
 30	CacheReadTokens     int64
 31}
 32
 33type ProviderResponse struct {
 34	Content      string
 35	ToolCalls    []message.ToolCall
 36	Usage        TokenUsage
 37	FinishReason message.FinishReason
 38}
 39
 40type ProviderEvent struct {
 41	Type EventType
 42
 43	Content  string
 44	Thinking string
 45	Response *ProviderResponse
 46
 47	Error error
 48}
 49type Provider interface {
 50	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 51
 52	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 53
 54	Model() models.Model
 55}
 56
 57type providerClientOptions struct {
 58	apiKey        string
 59	model         models.Model
 60	maxTokens     int64
 61	systemMessage string
 62
 63	anthropicOptions []AnthropicOption
 64	openaiOptions    []OpenAIOption
 65	geminiOptions    []GeminiOption
 66	bedrockOptions   []BedrockOption
 67}
 68
 69type ProviderClientOption func(*providerClientOptions)
 70
 71type ProviderClient interface {
 72	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 73	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 74}
 75
 76type baseProvider[C ProviderClient] struct {
 77	options providerClientOptions
 78	client  C
 79}
 80
 81func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
 82	clientOptions := providerClientOptions{}
 83	for _, o := range opts {
 84		o(&clientOptions)
 85	}
 86	switch providerName {
 87	case models.ProviderAnthropic:
 88		return &baseProvider[AnthropicClient]{
 89			options: clientOptions,
 90			client:  newAnthropicClient(clientOptions),
 91		}, nil
 92	case models.ProviderOpenAI:
 93		return &baseProvider[OpenAIClient]{
 94			options: clientOptions,
 95			client:  newOpenAIClient(clientOptions),
 96		}, nil
 97	case models.ProviderGemini:
 98		return &baseProvider[GeminiClient]{
 99			options: clientOptions,
100			client:  newGeminiClient(clientOptions),
101		}, nil
102	case models.ProviderBedrock:
103		return &baseProvider[BedrockClient]{
104			options: clientOptions,
105			client:  newBedrockClient(clientOptions),
106		}, nil
107	case models.ProviderMock:
108		// TODO: implement mock client for test
109		panic("not implemented")
110	}
111	return nil, fmt.Errorf("provider not supported: %s", providerName)
112}
113
114func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
115	for _, msg := range messages {
116		// The message has no content
117		if len(msg.Parts) == 0 {
118			continue
119		}
120		cleaned = append(cleaned, msg)
121	}
122	return
123}
124
125func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
126	messages = p.cleanMessages(messages)
127	return p.client.send(ctx, messages, tools)
128}
129
130func (p *baseProvider[C]) Model() models.Model {
131	return p.options.model
132}
133
134func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
135	messages = p.cleanMessages(messages)
136	return p.client.stream(ctx, messages, tools)
137}
138
139func WithAPIKey(apiKey string) ProviderClientOption {
140	return func(options *providerClientOptions) {
141		options.apiKey = apiKey
142	}
143}
144
145func WithModel(model models.Model) ProviderClientOption {
146	return func(options *providerClientOptions) {
147		options.model = model
148	}
149}
150
151func WithMaxTokens(maxTokens int64) ProviderClientOption {
152	return func(options *providerClientOptions) {
153		options.maxTokens = maxTokens
154	}
155}
156
157func WithSystemMessage(systemMessage string) ProviderClientOption {
158	return func(options *providerClientOptions) {
159		options.systemMessage = systemMessage
160	}
161}
162
163func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
164	return func(options *providerClientOptions) {
165		options.anthropicOptions = anthropicOptions
166	}
167}
168
169func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
170	return func(options *providerClientOptions) {
171		options.openaiOptions = openaiOptions
172	}
173}
174
175func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
176	return func(options *providerClientOptions) {
177		options.geminiOptions = geminiOptions
178	}
179}
180
181func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
182	return func(options *providerClientOptions) {
183		options.bedrockOptions = bedrockOptions
184	}
185}