1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/google/generative-ai-go/genai"
13 "github.com/google/uuid"
14 "github.com/opencode-ai/opencode/internal/config"
15 "github.com/opencode-ai/opencode/internal/llm/tools"
16 "github.com/opencode-ai/opencode/internal/logging"
17 "github.com/opencode-ai/opencode/internal/message"
18 "google.golang.org/api/iterator"
19 "google.golang.org/api/option"
20)
21
// geminiOptions holds Gemini-specific tuning applied at client construction.
type geminiOptions struct {
	disableCache bool // when true, prompt caching is not requested
}

// GeminiOption mutates geminiOptions during newGeminiClient.
type GeminiOption func(*geminiOptions)

// geminiClient is the Gemini-backed provider client implementation.
type geminiClient struct {
	providerOptions providerClientOptions // shared provider configuration (model, key, limits)
	options         geminiOptions         // Gemini-specific options
	client          *genai.Client         // underlying Google genai SDK client
}

// GeminiClient is the provider-facing alias for this client's interface.
type GeminiClient ProviderClient
35
36func newGeminiClient(opts providerClientOptions) GeminiClient {
37 geminiOpts := geminiOptions{}
38 for _, o := range opts.geminiOptions {
39 o(&geminiOpts)
40 }
41
42 client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
43 if err != nil {
44 logging.Error("Failed to create Gemini client", "error", err)
45 return nil
46 }
47
48 return &geminiClient{
49 providerOptions: opts,
50 options: geminiOpts,
51 client: client,
52 }
53}
54
55func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
56 var history []*genai.Content
57 for _, msg := range messages {
58 switch msg.Role {
59 case message.User:
60 history = append(history, &genai.Content{
61 Parts: []genai.Part{genai.Text(msg.Content().String())},
62 Role: "user",
63 })
64
65 case message.Assistant:
66 content := &genai.Content{
67 Role: "model",
68 Parts: []genai.Part{},
69 }
70
71 if msg.Content().String() != "" {
72 content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
73 }
74
75 if len(msg.ToolCalls()) > 0 {
76 for _, call := range msg.ToolCalls() {
77 args, _ := parseJsonToMap(call.Input)
78 content.Parts = append(content.Parts, genai.FunctionCall{
79 Name: call.Name,
80 Args: args,
81 })
82 }
83 }
84
85 history = append(history, content)
86
87 case message.Tool:
88 for _, result := range msg.ToolResults() {
89 response := map[string]interface{}{"result": result.Content}
90 parsed, err := parseJsonToMap(result.Content)
91 if err == nil {
92 response = parsed
93 }
94
95 var toolCall message.ToolCall
96 for _, m := range messages {
97 if m.Role == message.Assistant {
98 for _, call := range m.ToolCalls() {
99 if call.ID == result.ToolCallID {
100 toolCall = call
101 break
102 }
103 }
104 }
105 }
106
107 history = append(history, &genai.Content{
108 Parts: []genai.Part{genai.FunctionResponse{
109 Name: toolCall.Name,
110 Response: response,
111 }},
112 Role: "function",
113 })
114 }
115 }
116 }
117
118 return history
119}
120
121func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
122 geminiTool := &genai.Tool{}
123 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
124
125 for _, tool := range tools {
126 info := tool.Info()
127 declaration := &genai.FunctionDeclaration{
128 Name: info.Name,
129 Description: info.Description,
130 Parameters: &genai.Schema{
131 Type: genai.TypeObject,
132 Properties: convertSchemaProperties(info.Parameters),
133 Required: info.Required,
134 },
135 }
136
137 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
138 }
139
140 return []*genai.Tool{geminiTool}
141}
142
143func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
144 switch {
145 case reason == genai.FinishReasonStop:
146 return message.FinishReasonEndTurn
147 case reason == genai.FinishReasonMaxTokens:
148 return message.FinishReasonMaxTokens
149 default:
150 return message.FinishReasonUnknown
151 }
152}
153
154func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
155 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
156 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
157 model.SystemInstruction = &genai.Content{
158 Parts: []genai.Part{
159 genai.Text(g.providerOptions.systemMessage),
160 },
161 }
162 // Convert tools
163 if len(tools) > 0 {
164 model.Tools = g.convertTools(tools)
165 }
166
167 // Convert messages
168 geminiMessages := g.convertMessages(messages)
169
170 cfg := config.Get()
171 if cfg.Debug {
172 jsonData, _ := json.Marshal(geminiMessages)
173 logging.Debug("Prepared messages", "messages", string(jsonData))
174 }
175
176 attempts := 0
177 for {
178 attempts++
179 var toolCalls []message.ToolCall
180 chat := model.StartChat()
181 chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
182
183 lastMsg := geminiMessages[len(geminiMessages)-1]
184
185 resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
186 // If there is an error we are going to see if we can retry the call
187 if err != nil {
188 retry, after, retryErr := g.shouldRetry(attempts, err)
189 if retryErr != nil {
190 return nil, retryErr
191 }
192 if retry {
193 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
194 select {
195 case <-ctx.Done():
196 return nil, ctx.Err()
197 case <-time.After(time.Duration(after) * time.Millisecond):
198 continue
199 }
200 }
201 return nil, retryErr
202 }
203
204 content := ""
205
206 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
207 for _, part := range resp.Candidates[0].Content.Parts {
208 switch p := part.(type) {
209 case genai.Text:
210 content = string(p)
211 case genai.FunctionCall:
212 id := "call_" + uuid.New().String()
213 args, _ := json.Marshal(p.Args)
214 toolCalls = append(toolCalls, message.ToolCall{
215 ID: id,
216 Name: p.Name,
217 Input: string(args),
218 Type: "function",
219 Finished: true,
220 })
221 }
222 }
223 }
224 finishReason := g.finishReason(resp.Candidates[0].FinishReason)
225 if len(toolCalls) > 0 {
226 finishReason = message.FinishReasonToolUse
227 }
228
229 return &ProviderResponse{
230 Content: content,
231 ToolCalls: toolCalls,
232 Usage: g.usage(resp),
233 FinishReason: finishReason,
234 }, nil
235 }
236}
237
238func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
239 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
240 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
241 model.SystemInstruction = &genai.Content{
242 Parts: []genai.Part{
243 genai.Text(g.providerOptions.systemMessage),
244 },
245 }
246 // Convert tools
247 if len(tools) > 0 {
248 model.Tools = g.convertTools(tools)
249 }
250
251 // Convert messages
252 geminiMessages := g.convertMessages(messages)
253
254 cfg := config.Get()
255 if cfg.Debug {
256 jsonData, _ := json.Marshal(geminiMessages)
257 logging.Debug("Prepared messages", "messages", string(jsonData))
258 }
259
260 attempts := 0
261 eventChan := make(chan ProviderEvent)
262
263 go func() {
264 defer close(eventChan)
265
266 for {
267 attempts++
268 chat := model.StartChat()
269 chat.History = geminiMessages[:len(geminiMessages)-1]
270 lastMsg := geminiMessages[len(geminiMessages)-1]
271
272 iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
273
274 currentContent := ""
275 toolCalls := []message.ToolCall{}
276 var finalResp *genai.GenerateContentResponse
277
278 eventChan <- ProviderEvent{Type: EventContentStart}
279
280 for {
281 resp, err := iter.Next()
282 if err == iterator.Done {
283 break
284 }
285 if err != nil {
286 retry, after, retryErr := g.shouldRetry(attempts, err)
287 if retryErr != nil {
288 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
289 return
290 }
291 if retry {
292 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
293 select {
294 case <-ctx.Done():
295 if ctx.Err() != nil {
296 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
297 }
298
299 return
300 case <-time.After(time.Duration(after) * time.Millisecond):
301 break
302 }
303 } else {
304 eventChan <- ProviderEvent{Type: EventError, Error: err}
305 return
306 }
307 }
308
309 finalResp = resp
310
311 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
312 for _, part := range resp.Candidates[0].Content.Parts {
313 switch p := part.(type) {
314 case genai.Text:
315 delta := string(p)
316 if delta != "" {
317 eventChan <- ProviderEvent{
318 Type: EventContentDelta,
319 Content: delta,
320 }
321 currentContent += delta
322 }
323 case genai.FunctionCall:
324 id := "call_" + uuid.New().String()
325 args, _ := json.Marshal(p.Args)
326 newCall := message.ToolCall{
327 ID: id,
328 Name: p.Name,
329 Input: string(args),
330 Type: "function",
331 Finished: true,
332 }
333
334 isNew := true
335 for _, existing := range toolCalls {
336 if existing.Name == newCall.Name && existing.Input == newCall.Input {
337 isNew = false
338 break
339 }
340 }
341
342 if isNew {
343 toolCalls = append(toolCalls, newCall)
344 }
345 }
346 }
347 }
348 }
349
350 eventChan <- ProviderEvent{Type: EventContentStop}
351
352 if finalResp != nil {
353 finishReason := g.finishReason(finalResp.Candidates[0].FinishReason)
354 if len(toolCalls) > 0 {
355 finishReason = message.FinishReasonToolUse
356 }
357 eventChan <- ProviderEvent{
358 Type: EventComplete,
359 Response: &ProviderResponse{
360 Content: currentContent,
361 ToolCalls: toolCalls,
362 Usage: g.usage(finalResp),
363 FinishReason: finishReason,
364 },
365 }
366 return
367 }
368
369 }
370 }()
371
372 return eventChan
373}
374
375func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
376 // Check if error is a rate limit error
377 if attempts > maxRetries {
378 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
379 }
380
381 // Gemini doesn't have a standard error type we can check against
382 // So we'll check the error message for rate limit indicators
383 if errors.Is(err, io.EOF) {
384 return false, 0, err
385 }
386
387 errMsg := err.Error()
388 isRateLimit := false
389
390 // Check for common rate limit error messages
391 if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
392 isRateLimit = true
393 }
394
395 if !isRateLimit {
396 return false, 0, err
397 }
398
399 // Calculate backoff with jitter
400 backoffMs := 2000 * (1 << (attempts - 1))
401 jitterMs := int(float64(backoffMs) * 0.2)
402 retryMs := backoffMs + jitterMs
403
404 return true, int64(retryMs), nil
405}
406
407func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
408 var toolCalls []message.ToolCall
409
410 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
411 for _, part := range resp.Candidates[0].Content.Parts {
412 if funcCall, ok := part.(genai.FunctionCall); ok {
413 id := "call_" + uuid.New().String()
414 args, _ := json.Marshal(funcCall.Args)
415 toolCalls = append(toolCalls, message.ToolCall{
416 ID: id,
417 Name: funcCall.Name,
418 Input: string(args),
419 Type: "function",
420 })
421 }
422 }
423 }
424
425 return toolCalls
426}
427
428func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
429 if resp == nil || resp.UsageMetadata == nil {
430 return TokenUsage{}
431 }
432
433 return TokenUsage{
434 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
435 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
436 CacheCreationTokens: 0, // Not directly provided by Gemini
437 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
438 }
439}
440
441func WithGeminiDisableCache() GeminiOption {
442 return func(options *geminiOptions) {
443 options.disableCache = true
444 }
445}
446
// Helper functions

// parseJsonToMap decodes a JSON object string into a generic string-keyed map.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		return parsed, err
	}
	return parsed, nil
}
453
454func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
455 properties := make(map[string]*genai.Schema)
456
457 for name, param := range parameters {
458 properties[name] = convertToSchema(param)
459 }
460
461 return properties
462}
463
464func convertToSchema(param interface{}) *genai.Schema {
465 schema := &genai.Schema{Type: genai.TypeString}
466
467 paramMap, ok := param.(map[string]interface{})
468 if !ok {
469 return schema
470 }
471
472 if desc, ok := paramMap["description"].(string); ok {
473 schema.Description = desc
474 }
475
476 typeVal, hasType := paramMap["type"]
477 if !hasType {
478 return schema
479 }
480
481 typeStr, ok := typeVal.(string)
482 if !ok {
483 return schema
484 }
485
486 schema.Type = mapJSONTypeToGenAI(typeStr)
487
488 switch typeStr {
489 case "array":
490 schema.Items = processArrayItems(paramMap)
491 case "object":
492 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
493 schema.Properties = convertSchemaProperties(props)
494 }
495 }
496
497 return schema
498}
499
500func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
501 items, ok := paramMap["items"].(map[string]interface{})
502 if !ok {
503 return nil
504 }
505
506 return convertToSchema(items)
507}
508
509func mapJSONTypeToGenAI(jsonType string) genai.Type {
510 switch jsonType {
511 case "string":
512 return genai.TypeString
513 case "number":
514 return genai.TypeNumber
515 case "integer":
516 return genai.TypeInteger
517 case "boolean":
518 return genai.TypeBoolean
519 case "array":
520 return genai.TypeArray
521 case "object":
522 return genai.TypeObject
523 default:
524 return genai.TypeString // Default to string for unknown types
525 }
526}
527
// contains reports whether s contains any of substrs, case-insensitively.
func contains(s string, substrs ...string) bool {
	// PERF: lowercase the haystack once instead of on every iteration.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}