1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/google/generative-ai-go/genai"
13 "github.com/google/uuid"
14 "github.com/opencode-ai/opencode/internal/config"
15 "github.com/opencode-ai/opencode/internal/llm/tools"
16 "github.com/opencode-ai/opencode/internal/logging"
17 "github.com/opencode-ai/opencode/internal/message"
18 "google.golang.org/api/iterator"
19 "google.golang.org/api/option"
20)
21
// geminiOptions holds Gemini-specific configuration toggles, populated by
// GeminiOption functions passed through providerClientOptions.
type geminiOptions struct {
	// disableCache is set by WithGeminiDisableCache. NOTE(review): it is not
	// consulted anywhere in this file — confirm it is read elsewhere.
	disableCache bool
}

// GeminiOption mutates geminiOptions during client construction.
type GeminiOption func(*geminiOptions)

// geminiClient talks to the Google generative-ai-go SDK and adapts it to the
// provider-neutral ProviderClient surface.
type geminiClient struct {
	providerOptions providerClientOptions // shared provider config (model, apiKey, systemMessage, maxTokens, ...)
	options         geminiOptions         // Gemini-specific options
	client          *genai.Client         // underlying SDK client
}

// GeminiClient is the exported alias under which this provider is used.
type GeminiClient ProviderClient
35
36func newGeminiClient(opts providerClientOptions) GeminiClient {
37 geminiOpts := geminiOptions{}
38 for _, o := range opts.geminiOptions {
39 o(&geminiOpts)
40 }
41
42 client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
43 if err != nil {
44 logging.Error("Failed to create Gemini client", "error", err)
45 return nil
46 }
47
48 return &geminiClient{
49 providerOptions: opts,
50 options: geminiOpts,
51 client: client,
52 }
53}
54
55func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
56 var history []*genai.Content
57
58 // Add system message first
59 history = append(history, &genai.Content{
60 Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
61 Role: "user",
62 })
63
64 // Add a system response to acknowledge the system message
65 history = append(history, &genai.Content{
66 Parts: []genai.Part{genai.Text("I'll help you with that.")},
67 Role: "model",
68 })
69
70 for _, msg := range messages {
71 switch msg.Role {
72 case message.User:
73 history = append(history, &genai.Content{
74 Parts: []genai.Part{genai.Text(msg.Content().String())},
75 Role: "user",
76 })
77
78 case message.Assistant:
79 content := &genai.Content{
80 Role: "model",
81 Parts: []genai.Part{},
82 }
83
84 if msg.Content().String() != "" {
85 content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
86 }
87
88 if len(msg.ToolCalls()) > 0 {
89 for _, call := range msg.ToolCalls() {
90 args, _ := parseJsonToMap(call.Input)
91 content.Parts = append(content.Parts, genai.FunctionCall{
92 Name: call.Name,
93 Args: args,
94 })
95 }
96 }
97
98 history = append(history, content)
99
100 case message.Tool:
101 for _, result := range msg.ToolResults() {
102 response := map[string]interface{}{"result": result.Content}
103 parsed, err := parseJsonToMap(result.Content)
104 if err == nil {
105 response = parsed
106 }
107
108 var toolCall message.ToolCall
109 for _, m := range messages {
110 if m.Role == message.Assistant {
111 for _, call := range m.ToolCalls() {
112 if call.ID == result.ToolCallID {
113 toolCall = call
114 break
115 }
116 }
117 }
118 }
119
120 history = append(history, &genai.Content{
121 Parts: []genai.Part{genai.FunctionResponse{
122 Name: toolCall.Name,
123 Response: response,
124 }},
125 Role: "function",
126 })
127 }
128 }
129 }
130
131 return history
132}
133
134func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
135 geminiTool := &genai.Tool{}
136 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
137
138 for _, tool := range tools {
139 info := tool.Info()
140 declaration := &genai.FunctionDeclaration{
141 Name: info.Name,
142 Description: info.Description,
143 Parameters: &genai.Schema{
144 Type: genai.TypeObject,
145 Properties: convertSchemaProperties(info.Parameters),
146 Required: info.Required,
147 },
148 }
149
150 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
151 }
152
153 return []*genai.Tool{geminiTool}
154}
155
156func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
157 reasonStr := reason.String()
158 switch {
159 case reasonStr == "STOP":
160 return message.FinishReasonEndTurn
161 case reasonStr == "MAX_TOKENS":
162 return message.FinishReasonMaxTokens
163 case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
164 return message.FinishReasonToolUse
165 default:
166 return message.FinishReasonUnknown
167 }
168}
169
170func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
171 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
172 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
173
174 // Convert tools
175 if len(tools) > 0 {
176 model.Tools = g.convertTools(tools)
177 }
178
179 // Convert messages
180 geminiMessages := g.convertMessages(messages)
181
182 cfg := config.Get()
183 if cfg.Debug {
184 jsonData, _ := json.Marshal(geminiMessages)
185 logging.Debug("Prepared messages", "messages", string(jsonData))
186 }
187
188 attempts := 0
189 for {
190 attempts++
191 chat := model.StartChat()
192 chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
193
194 lastMsg := geminiMessages[len(geminiMessages)-1]
195 var lastText string
196 for _, part := range lastMsg.Parts {
197 if text, ok := part.(genai.Text); ok {
198 lastText = string(text)
199 break
200 }
201 }
202
203 resp, err := chat.SendMessage(ctx, genai.Text(lastText))
204 // If there is an error we are going to see if we can retry the call
205 if err != nil {
206 retry, after, retryErr := g.shouldRetry(attempts, err)
207 if retryErr != nil {
208 return nil, retryErr
209 }
210 if retry {
211 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
212 select {
213 case <-ctx.Done():
214 return nil, ctx.Err()
215 case <-time.After(time.Duration(after) * time.Millisecond):
216 continue
217 }
218 }
219 return nil, retryErr
220 }
221
222 content := ""
223 var toolCalls []message.ToolCall
224
225 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
226 for _, part := range resp.Candidates[0].Content.Parts {
227 switch p := part.(type) {
228 case genai.Text:
229 content = string(p)
230 case genai.FunctionCall:
231 id := "call_" + uuid.New().String()
232 args, _ := json.Marshal(p.Args)
233 toolCalls = append(toolCalls, message.ToolCall{
234 ID: id,
235 Name: p.Name,
236 Input: string(args),
237 Type: "function",
238 })
239 }
240 }
241 }
242
243 return &ProviderResponse{
244 Content: content,
245 ToolCalls: toolCalls,
246 Usage: g.usage(resp),
247 FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
248 }, nil
249 }
250}
251
252func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
253 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
254 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
255
256 // Convert tools
257 if len(tools) > 0 {
258 model.Tools = g.convertTools(tools)
259 }
260
261 // Convert messages
262 geminiMessages := g.convertMessages(messages)
263
264 cfg := config.Get()
265 if cfg.Debug {
266 jsonData, _ := json.Marshal(geminiMessages)
267 logging.Debug("Prepared messages", "messages", string(jsonData))
268 }
269
270 attempts := 0
271 eventChan := make(chan ProviderEvent)
272
273 go func() {
274 defer close(eventChan)
275
276 for {
277 attempts++
278 chat := model.StartChat()
279 chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
280
281 lastMsg := geminiMessages[len(geminiMessages)-1]
282 var lastText string
283 for _, part := range lastMsg.Parts {
284 if text, ok := part.(genai.Text); ok {
285 lastText = string(text)
286 break
287 }
288 }
289
290 iter := chat.SendMessageStream(ctx, genai.Text(lastText))
291
292 currentContent := ""
293 toolCalls := []message.ToolCall{}
294 var finalResp *genai.GenerateContentResponse
295
296 eventChan <- ProviderEvent{Type: EventContentStart}
297
298 for {
299 resp, err := iter.Next()
300 if err == iterator.Done {
301 break
302 }
303 if err != nil {
304 retry, after, retryErr := g.shouldRetry(attempts, err)
305 if retryErr != nil {
306 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
307 return
308 }
309 if retry {
310 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
311 select {
312 case <-ctx.Done():
313 if ctx.Err() != nil {
314 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
315 }
316
317 return
318 case <-time.After(time.Duration(after) * time.Millisecond):
319 break
320 }
321 } else {
322 eventChan <- ProviderEvent{Type: EventError, Error: err}
323 return
324 }
325 }
326
327 finalResp = resp
328
329 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
330 for _, part := range resp.Candidates[0].Content.Parts {
331 switch p := part.(type) {
332 case genai.Text:
333 newText := string(p)
334 delta := newText[len(currentContent):]
335 if delta != "" {
336 eventChan <- ProviderEvent{
337 Type: EventContentDelta,
338 Content: delta,
339 }
340 currentContent = newText
341 }
342 case genai.FunctionCall:
343 id := "call_" + uuid.New().String()
344 args, _ := json.Marshal(p.Args)
345 newCall := message.ToolCall{
346 ID: id,
347 Name: p.Name,
348 Input: string(args),
349 Type: "function",
350 }
351
352 isNew := true
353 for _, existing := range toolCalls {
354 if existing.Name == newCall.Name && existing.Input == newCall.Input {
355 isNew = false
356 break
357 }
358 }
359
360 if isNew {
361 toolCalls = append(toolCalls, newCall)
362 }
363 }
364 }
365 }
366 }
367
368 eventChan <- ProviderEvent{Type: EventContentStop}
369
370 if finalResp != nil {
371 eventChan <- ProviderEvent{
372 Type: EventComplete,
373 Response: &ProviderResponse{
374 Content: currentContent,
375 ToolCalls: toolCalls,
376 Usage: g.usage(finalResp),
377 FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
378 },
379 }
380 return
381 }
382
383 // If we get here, we need to retry
384 if attempts > maxRetries {
385 eventChan <- ProviderEvent{
386 Type: EventError,
387 Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
388 }
389 return
390 }
391
392 // Wait before retrying
393 select {
394 case <-ctx.Done():
395 if ctx.Err() != nil {
396 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
397 }
398 return
399 case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
400 continue
401 }
402 }
403 }()
404
405 return eventChan
406}
407
408func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
409 // Check if error is a rate limit error
410 if attempts > maxRetries {
411 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
412 }
413
414 // Gemini doesn't have a standard error type we can check against
415 // So we'll check the error message for rate limit indicators
416 if errors.Is(err, io.EOF) {
417 return false, 0, err
418 }
419
420 errMsg := err.Error()
421 isRateLimit := false
422
423 // Check for common rate limit error messages
424 if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
425 isRateLimit = true
426 }
427
428 if !isRateLimit {
429 return false, 0, err
430 }
431
432 // Calculate backoff with jitter
433 backoffMs := 2000 * (1 << (attempts - 1))
434 jitterMs := int(float64(backoffMs) * 0.2)
435 retryMs := backoffMs + jitterMs
436
437 return true, int64(retryMs), nil
438}
439
440func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
441 var toolCalls []message.ToolCall
442
443 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
444 for _, part := range resp.Candidates[0].Content.Parts {
445 if funcCall, ok := part.(genai.FunctionCall); ok {
446 id := "call_" + uuid.New().String()
447 args, _ := json.Marshal(funcCall.Args)
448 toolCalls = append(toolCalls, message.ToolCall{
449 ID: id,
450 Name: funcCall.Name,
451 Input: string(args),
452 Type: "function",
453 })
454 }
455 }
456 }
457
458 return toolCalls
459}
460
461func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
462 if resp == nil || resp.UsageMetadata == nil {
463 return TokenUsage{}
464 }
465
466 return TokenUsage{
467 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
468 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
469 CacheCreationTokens: 0, // Not directly provided by Gemini
470 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
471 }
472}
473
474func WithGeminiDisableCache() GeminiOption {
475 return func(options *geminiOptions) {
476 options.disableCache = true
477 }
478}
479
480// Helper functions
// parseJsonToMap decodes a JSON object string into a generic string-keyed
// map. On failure the (possibly nil/partial) map is returned alongside the
// decode error, mirroring json.Unmarshal's contract.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &decoded)
	return decoded, err
}
486
487func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
488 properties := make(map[string]*genai.Schema)
489
490 for name, param := range parameters {
491 properties[name] = convertToSchema(param)
492 }
493
494 return properties
495}
496
497func convertToSchema(param interface{}) *genai.Schema {
498 schema := &genai.Schema{Type: genai.TypeString}
499
500 paramMap, ok := param.(map[string]interface{})
501 if !ok {
502 return schema
503 }
504
505 if desc, ok := paramMap["description"].(string); ok {
506 schema.Description = desc
507 }
508
509 typeVal, hasType := paramMap["type"]
510 if !hasType {
511 return schema
512 }
513
514 typeStr, ok := typeVal.(string)
515 if !ok {
516 return schema
517 }
518
519 schema.Type = mapJSONTypeToGenAI(typeStr)
520
521 switch typeStr {
522 case "array":
523 schema.Items = processArrayItems(paramMap)
524 case "object":
525 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
526 schema.Properties = convertSchemaProperties(props)
527 }
528 }
529
530 return schema
531}
532
533func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
534 items, ok := paramMap["items"].(map[string]interface{})
535 if !ok {
536 return nil
537 }
538
539 return convertToSchema(items)
540}
541
542func mapJSONTypeToGenAI(jsonType string) genai.Type {
543 switch jsonType {
544 case "string":
545 return genai.TypeString
546 case "number":
547 return genai.TypeNumber
548 case "integer":
549 return genai.TypeInteger
550 case "boolean":
551 return genai.TypeBoolean
552 case "array":
553 return genai.TypeArray
554 case "object":
555 return genai.TypeObject
556 default:
557 return genai.TypeString // Default to string for unknown types
558 }
559}
560
// contains reports whether s contains any of substrs, case-insensitively.
//
// FIX: strings.ToLower(s) was recomputed (allocating) on every loop
// iteration; it is now hoisted out of the loop.
func contains(s string, substrs ...string) bool {
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}