1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/google/generative-ai-go/genai"
13 "github.com/google/uuid"
14 "github.com/opencode-ai/opencode/internal/config"
15 "github.com/opencode-ai/opencode/internal/llm/tools"
16 "github.com/opencode-ai/opencode/internal/logging"
17 "github.com/opencode-ai/opencode/internal/message"
18 "google.golang.org/api/iterator"
19 "google.golang.org/api/option"
20)
21
// geminiOptions collects the Gemini-specific settings that GeminiOption
// functions can toggle.
type geminiOptions struct {
	// disableCache is switched on by WithGeminiDisableCache.
	disableCache bool
}
25
26type GeminiOption func(*geminiOptions)
27
28type geminiClient struct {
29 providerOptions providerClientOptions
30 options geminiOptions
31 client *genai.Client
32}
33
34type GeminiClient ProviderClient
35
36func newGeminiClient(opts providerClientOptions) GeminiClient {
37 geminiOpts := geminiOptions{}
38 for _, o := range opts.geminiOptions {
39 o(&geminiOpts)
40 }
41
42 client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
43 if err != nil {
44 logging.Error("Failed to create Gemini client", "error", err)
45 return nil
46 }
47
48 return &geminiClient{
49 providerOptions: opts,
50 options: geminiOpts,
51 client: client,
52 }
53}
54
55func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
56 var history []*genai.Content
57 for _, msg := range messages {
58 switch msg.Role {
59 case message.User:
60 var parts []genai.Part
61 parts = append(parts, genai.Text(msg.Content().String()))
62 for _, binaryContent := range msg.BinaryContent() {
63 imageFormat := strings.Split(binaryContent.MIMEType, "/")
64 parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data))
65 }
66 history = append(history, &genai.Content{
67 Parts: parts,
68 Role: "user",
69 })
70 case message.Assistant:
71 content := &genai.Content{
72 Role: "model",
73 Parts: []genai.Part{},
74 }
75
76 if msg.Content().String() != "" {
77 content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
78 }
79
80 if len(msg.ToolCalls()) > 0 {
81 for _, call := range msg.ToolCalls() {
82 args, _ := parseJsonToMap(call.Input)
83 content.Parts = append(content.Parts, genai.FunctionCall{
84 Name: call.Name,
85 Args: args,
86 })
87 }
88 }
89
90 history = append(history, content)
91
92 case message.Tool:
93 for _, result := range msg.ToolResults() {
94 response := map[string]interface{}{"result": result.Content}
95 parsed, err := parseJsonToMap(result.Content)
96 if err == nil {
97 response = parsed
98 }
99
100 var toolCall message.ToolCall
101 for _, m := range messages {
102 if m.Role == message.Assistant {
103 for _, call := range m.ToolCalls() {
104 if call.ID == result.ToolCallID {
105 toolCall = call
106 break
107 }
108 }
109 }
110 }
111
112 history = append(history, &genai.Content{
113 Parts: []genai.Part{genai.FunctionResponse{
114 Name: toolCall.Name,
115 Response: response,
116 }},
117 Role: "function",
118 })
119 }
120 }
121 }
122
123 return history
124}
125
126func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
127 geminiTool := &genai.Tool{}
128 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
129
130 for _, tool := range tools {
131 info := tool.Info()
132 declaration := &genai.FunctionDeclaration{
133 Name: info.Name,
134 Description: info.Description,
135 Parameters: &genai.Schema{
136 Type: genai.TypeObject,
137 Properties: convertSchemaProperties(info.Parameters),
138 Required: info.Required,
139 },
140 }
141
142 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
143 }
144
145 return []*genai.Tool{geminiTool}
146}
147
148func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
149 switch {
150 case reason == genai.FinishReasonStop:
151 return message.FinishReasonEndTurn
152 case reason == genai.FinishReasonMaxTokens:
153 return message.FinishReasonMaxTokens
154 default:
155 return message.FinishReasonUnknown
156 }
157}
158
159func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
160 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
161 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
162 model.SystemInstruction = &genai.Content{
163 Parts: []genai.Part{
164 genai.Text(g.providerOptions.systemMessage),
165 },
166 }
167 // Convert tools
168 if len(tools) > 0 {
169 model.Tools = g.convertTools(tools)
170 }
171
172 // Convert messages
173 geminiMessages := g.convertMessages(messages)
174
175 cfg := config.Get()
176 if cfg.Debug {
177 jsonData, _ := json.Marshal(geminiMessages)
178 logging.Debug("Prepared messages", "messages", string(jsonData))
179 }
180
181 attempts := 0
182 for {
183 attempts++
184 var toolCalls []message.ToolCall
185 chat := model.StartChat()
186 chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
187
188 lastMsg := geminiMessages[len(geminiMessages)-1]
189
190 resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
191 // If there is an error we are going to see if we can retry the call
192 if err != nil {
193 retry, after, retryErr := g.shouldRetry(attempts, err)
194 if retryErr != nil {
195 return nil, retryErr
196 }
197 if retry {
198 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
199 select {
200 case <-ctx.Done():
201 return nil, ctx.Err()
202 case <-time.After(time.Duration(after) * time.Millisecond):
203 continue
204 }
205 }
206 return nil, retryErr
207 }
208
209 content := ""
210
211 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
212 for _, part := range resp.Candidates[0].Content.Parts {
213 switch p := part.(type) {
214 case genai.Text:
215 content = string(p)
216 case genai.FunctionCall:
217 id := "call_" + uuid.New().String()
218 args, _ := json.Marshal(p.Args)
219 toolCalls = append(toolCalls, message.ToolCall{
220 ID: id,
221 Name: p.Name,
222 Input: string(args),
223 Type: "function",
224 Finished: true,
225 })
226 }
227 }
228 }
229 finishReason := message.FinishReasonEndTurn
230 if len(resp.Candidates) > 0 {
231 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
232 }
233 if len(toolCalls) > 0 {
234 finishReason = message.FinishReasonToolUse
235 }
236
237 return &ProviderResponse{
238 Content: content,
239 ToolCalls: toolCalls,
240 Usage: g.usage(resp),
241 FinishReason: finishReason,
242 }, nil
243 }
244}
245
246func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
247 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
248 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
249 model.SystemInstruction = &genai.Content{
250 Parts: []genai.Part{
251 genai.Text(g.providerOptions.systemMessage),
252 },
253 }
254 // Convert tools
255 if len(tools) > 0 {
256 model.Tools = g.convertTools(tools)
257 }
258
259 // Convert messages
260 geminiMessages := g.convertMessages(messages)
261
262 cfg := config.Get()
263 if cfg.Debug {
264 jsonData, _ := json.Marshal(geminiMessages)
265 logging.Debug("Prepared messages", "messages", string(jsonData))
266 }
267
268 attempts := 0
269 eventChan := make(chan ProviderEvent)
270
271 go func() {
272 defer close(eventChan)
273
274 for {
275 attempts++
276 chat := model.StartChat()
277 chat.History = geminiMessages[:len(geminiMessages)-1]
278 lastMsg := geminiMessages[len(geminiMessages)-1]
279
280 iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
281
282 currentContent := ""
283 toolCalls := []message.ToolCall{}
284 var finalResp *genai.GenerateContentResponse
285
286 eventChan <- ProviderEvent{Type: EventContentStart}
287
288 for {
289 resp, err := iter.Next()
290 if err == iterator.Done {
291 break
292 }
293 if err != nil {
294 retry, after, retryErr := g.shouldRetry(attempts, err)
295 if retryErr != nil {
296 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
297 return
298 }
299 if retry {
300 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
301 select {
302 case <-ctx.Done():
303 if ctx.Err() != nil {
304 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
305 }
306
307 return
308 case <-time.After(time.Duration(after) * time.Millisecond):
309 break
310 }
311 } else {
312 eventChan <- ProviderEvent{Type: EventError, Error: err}
313 return
314 }
315 }
316
317 finalResp = resp
318
319 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
320 for _, part := range resp.Candidates[0].Content.Parts {
321 switch p := part.(type) {
322 case genai.Text:
323 delta := string(p)
324 if delta != "" {
325 eventChan <- ProviderEvent{
326 Type: EventContentDelta,
327 Content: delta,
328 }
329 currentContent += delta
330 }
331 case genai.FunctionCall:
332 id := "call_" + uuid.New().String()
333 args, _ := json.Marshal(p.Args)
334 newCall := message.ToolCall{
335 ID: id,
336 Name: p.Name,
337 Input: string(args),
338 Type: "function",
339 Finished: true,
340 }
341
342 isNew := true
343 for _, existing := range toolCalls {
344 if existing.Name == newCall.Name && existing.Input == newCall.Input {
345 isNew = false
346 break
347 }
348 }
349
350 if isNew {
351 toolCalls = append(toolCalls, newCall)
352 }
353 }
354 }
355 }
356 }
357
358 eventChan <- ProviderEvent{Type: EventContentStop}
359
360 if finalResp != nil {
361
362 finishReason := message.FinishReasonEndTurn
363 if len(finalResp.Candidates) > 0 {
364 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
365 }
366 if len(toolCalls) > 0 {
367 finishReason = message.FinishReasonToolUse
368 }
369 eventChan <- ProviderEvent{
370 Type: EventComplete,
371 Response: &ProviderResponse{
372 Content: currentContent,
373 ToolCalls: toolCalls,
374 Usage: g.usage(finalResp),
375 FinishReason: finishReason,
376 },
377 }
378 return
379 }
380
381 }
382 }()
383
384 return eventChan
385}
386
387func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
388 // Check if error is a rate limit error
389 if attempts > maxRetries {
390 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
391 }
392
393 // Gemini doesn't have a standard error type we can check against
394 // So we'll check the error message for rate limit indicators
395 if errors.Is(err, io.EOF) {
396 return false, 0, err
397 }
398
399 errMsg := err.Error()
400 isRateLimit := false
401
402 // Check for common rate limit error messages
403 if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
404 isRateLimit = true
405 }
406
407 if !isRateLimit {
408 return false, 0, err
409 }
410
411 // Calculate backoff with jitter
412 backoffMs := 2000 * (1 << (attempts - 1))
413 jitterMs := int(float64(backoffMs) * 0.2)
414 retryMs := backoffMs + jitterMs
415
416 return true, int64(retryMs), nil
417}
418
419func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
420 var toolCalls []message.ToolCall
421
422 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
423 for _, part := range resp.Candidates[0].Content.Parts {
424 if funcCall, ok := part.(genai.FunctionCall); ok {
425 id := "call_" + uuid.New().String()
426 args, _ := json.Marshal(funcCall.Args)
427 toolCalls = append(toolCalls, message.ToolCall{
428 ID: id,
429 Name: funcCall.Name,
430 Input: string(args),
431 Type: "function",
432 })
433 }
434 }
435 }
436
437 return toolCalls
438}
439
440func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
441 if resp == nil || resp.UsageMetadata == nil {
442 return TokenUsage{}
443 }
444
445 return TokenUsage{
446 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
447 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
448 CacheCreationTokens: 0, // Not directly provided by Gemini
449 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
450 }
451}
452
453func WithGeminiDisableCache() GeminiOption {
454 return func(options *geminiOptions) {
455 options.disableCache = true
456 }
457}
458
459// Helper functions
// parseJsonToMap decodes jsonStr as a JSON object. The map is returned
// together with any error from json.Unmarshal; input that is not a JSON
// object (or the literal null) leaves the map nil.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
465
466func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
467 properties := make(map[string]*genai.Schema)
468
469 for name, param := range parameters {
470 properties[name] = convertToSchema(param)
471 }
472
473 return properties
474}
475
476func convertToSchema(param interface{}) *genai.Schema {
477 schema := &genai.Schema{Type: genai.TypeString}
478
479 paramMap, ok := param.(map[string]interface{})
480 if !ok {
481 return schema
482 }
483
484 if desc, ok := paramMap["description"].(string); ok {
485 schema.Description = desc
486 }
487
488 typeVal, hasType := paramMap["type"]
489 if !hasType {
490 return schema
491 }
492
493 typeStr, ok := typeVal.(string)
494 if !ok {
495 return schema
496 }
497
498 schema.Type = mapJSONTypeToGenAI(typeStr)
499
500 switch typeStr {
501 case "array":
502 schema.Items = processArrayItems(paramMap)
503 case "object":
504 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
505 schema.Properties = convertSchemaProperties(props)
506 }
507 }
508
509 return schema
510}
511
512func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
513 items, ok := paramMap["items"].(map[string]interface{})
514 if !ok {
515 return nil
516 }
517
518 return convertToSchema(items)
519}
520
521func mapJSONTypeToGenAI(jsonType string) genai.Type {
522 switch jsonType {
523 case "string":
524 return genai.TypeString
525 case "number":
526 return genai.TypeNumber
527 case "integer":
528 return genai.TypeInteger
529 case "boolean":
530 return genai.TypeBoolean
531 case "array":
532 return genai.TypeArray
533 case "object":
534 return genai.TypeObject
535 default:
536 return genai.TypeString // Default to string for unknown types
537 }
538}
539
// contains reports whether s includes at least one of substrs, ignoring case.
// With no substrings it returns false; an empty substring always matches.
func contains(s string, substrs ...string) bool {
	haystack := strings.ToLower(s)
	for i := range substrs {
		if strings.Contains(haystack, strings.ToLower(substrs[i])) {
			return true
		}
	}
	return false
}