1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/google/generative-ai-go/genai"
13 "github.com/google/uuid"
14 "github.com/opencode-ai/opencode/internal/config"
15 "github.com/opencode-ai/opencode/internal/llm/tools"
16 "github.com/opencode-ai/opencode/internal/logging"
17 "github.com/opencode-ai/opencode/internal/message"
18 "google.golang.org/api/iterator"
19 "google.golang.org/api/option"
20)
21
// geminiOptions holds Gemini-specific client settings, populated via
// GeminiOption functional options.
type geminiOptions struct {
	// disableCache turns off prompt caching; set via WithGeminiDisableCache.
	disableCache bool
}

// GeminiOption mutates geminiOptions; passed to newGeminiClient through
// providerClientOptions.geminiOptions.
type GeminiOption func(*geminiOptions)

// geminiClient implements the provider client against the Google
// generative-ai-go SDK.
type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

// GeminiClient is the exported alias for this provider's client interface.
type GeminiClient ProviderClient
35
36func newGeminiClient(opts providerClientOptions) GeminiClient {
37 geminiOpts := geminiOptions{}
38 for _, o := range opts.geminiOptions {
39 o(&geminiOpts)
40 }
41
42 client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
43 if err != nil {
44 logging.Error("Failed to create Gemini client", "error", err)
45 return nil
46 }
47
48 return &geminiClient{
49 providerOptions: opts,
50 options: geminiOpts,
51 client: client,
52 }
53}
54
55func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
56 var history []*genai.Content
57 for _, msg := range messages {
58 switch msg.Role {
59 case message.User:
60 history = append(history, &genai.Content{
61 Parts: []genai.Part{genai.Text(msg.Content().String())},
62 Role: "user",
63 })
64
65 case message.Assistant:
66 content := &genai.Content{
67 Role: "model",
68 Parts: []genai.Part{},
69 }
70
71 if msg.Content().String() != "" {
72 content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
73 }
74
75 if len(msg.ToolCalls()) > 0 {
76 for _, call := range msg.ToolCalls() {
77 args, _ := parseJsonToMap(call.Input)
78 content.Parts = append(content.Parts, genai.FunctionCall{
79 Name: call.Name,
80 Args: args,
81 })
82 }
83 }
84
85 history = append(history, content)
86
87 case message.Tool:
88 for _, result := range msg.ToolResults() {
89 response := map[string]interface{}{"result": result.Content}
90 parsed, err := parseJsonToMap(result.Content)
91 if err == nil {
92 response = parsed
93 }
94
95 var toolCall message.ToolCall
96 for _, m := range messages {
97 if m.Role == message.Assistant {
98 for _, call := range m.ToolCalls() {
99 if call.ID == result.ToolCallID {
100 toolCall = call
101 break
102 }
103 }
104 }
105 }
106
107 history = append(history, &genai.Content{
108 Parts: []genai.Part{genai.FunctionResponse{
109 Name: toolCall.Name,
110 Response: response,
111 }},
112 Role: "function",
113 })
114 }
115 }
116 }
117
118 return history
119}
120
121func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
122 geminiTool := &genai.Tool{}
123 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
124
125 for _, tool := range tools {
126 info := tool.Info()
127 declaration := &genai.FunctionDeclaration{
128 Name: info.Name,
129 Description: info.Description,
130 Parameters: &genai.Schema{
131 Type: genai.TypeObject,
132 Properties: convertSchemaProperties(info.Parameters),
133 Required: info.Required,
134 },
135 }
136
137 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
138 }
139
140 return []*genai.Tool{geminiTool}
141}
142
143func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
144 switch {
145 case reason == genai.FinishReasonStop:
146 return message.FinishReasonEndTurn
147 case reason == genai.FinishReasonMaxTokens:
148 return message.FinishReasonMaxTokens
149 default:
150 return message.FinishReasonUnknown
151 }
152}
153
154func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
155 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
156 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
157 model.SystemInstruction = &genai.Content{
158 Parts: []genai.Part{
159 genai.Text(g.providerOptions.systemMessage),
160 },
161 }
162 // Convert tools
163 if len(tools) > 0 {
164 model.Tools = g.convertTools(tools)
165 }
166
167 // Convert messages
168 geminiMessages := g.convertMessages(messages)
169
170 cfg := config.Get()
171 if cfg.Debug {
172 jsonData, _ := json.Marshal(geminiMessages)
173 logging.Debug("Prepared messages", "messages", string(jsonData))
174 }
175
176 attempts := 0
177 for {
178 attempts++
179 var toolCalls []message.ToolCall
180 chat := model.StartChat()
181 chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
182
183 lastMsg := geminiMessages[len(geminiMessages)-1]
184
185 resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
186 // If there is an error we are going to see if we can retry the call
187 if err != nil {
188 retry, after, retryErr := g.shouldRetry(attempts, err)
189 if retryErr != nil {
190 return nil, retryErr
191 }
192 if retry {
193 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
194 select {
195 case <-ctx.Done():
196 return nil, ctx.Err()
197 case <-time.After(time.Duration(after) * time.Millisecond):
198 continue
199 }
200 }
201 return nil, retryErr
202 }
203
204 content := ""
205
206 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
207 for _, part := range resp.Candidates[0].Content.Parts {
208 switch p := part.(type) {
209 case genai.Text:
210 content = string(p)
211 case genai.FunctionCall:
212 id := "call_" + uuid.New().String()
213 args, _ := json.Marshal(p.Args)
214 toolCalls = append(toolCalls, message.ToolCall{
215 ID: id,
216 Name: p.Name,
217 Input: string(args),
218 Type: "function",
219 Finished: true,
220 })
221 }
222 }
223 }
224 finishReason := message.FinishReasonEndTurn
225 if len(resp.Candidates) > 0 {
226 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
227 }
228 if len(toolCalls) > 0 {
229 finishReason = message.FinishReasonToolUse
230 }
231
232 return &ProviderResponse{
233 Content: content,
234 ToolCalls: toolCalls,
235 Usage: g.usage(resp),
236 FinishReason: finishReason,
237 }, nil
238 }
239}
240
241func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
242 model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
243 model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
244 model.SystemInstruction = &genai.Content{
245 Parts: []genai.Part{
246 genai.Text(g.providerOptions.systemMessage),
247 },
248 }
249 // Convert tools
250 if len(tools) > 0 {
251 model.Tools = g.convertTools(tools)
252 }
253
254 // Convert messages
255 geminiMessages := g.convertMessages(messages)
256
257 cfg := config.Get()
258 if cfg.Debug {
259 jsonData, _ := json.Marshal(geminiMessages)
260 logging.Debug("Prepared messages", "messages", string(jsonData))
261 }
262
263 attempts := 0
264 eventChan := make(chan ProviderEvent)
265
266 go func() {
267 defer close(eventChan)
268
269 for {
270 attempts++
271 chat := model.StartChat()
272 chat.History = geminiMessages[:len(geminiMessages)-1]
273 lastMsg := geminiMessages[len(geminiMessages)-1]
274
275 iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
276
277 currentContent := ""
278 toolCalls := []message.ToolCall{}
279 var finalResp *genai.GenerateContentResponse
280
281 eventChan <- ProviderEvent{Type: EventContentStart}
282
283 for {
284 resp, err := iter.Next()
285 if err == iterator.Done {
286 break
287 }
288 if err != nil {
289 retry, after, retryErr := g.shouldRetry(attempts, err)
290 if retryErr != nil {
291 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
292 return
293 }
294 if retry {
295 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
296 select {
297 case <-ctx.Done():
298 if ctx.Err() != nil {
299 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
300 }
301
302 return
303 case <-time.After(time.Duration(after) * time.Millisecond):
304 break
305 }
306 } else {
307 eventChan <- ProviderEvent{Type: EventError, Error: err}
308 return
309 }
310 }
311
312 finalResp = resp
313
314 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
315 for _, part := range resp.Candidates[0].Content.Parts {
316 switch p := part.(type) {
317 case genai.Text:
318 delta := string(p)
319 if delta != "" {
320 eventChan <- ProviderEvent{
321 Type: EventContentDelta,
322 Content: delta,
323 }
324 currentContent += delta
325 }
326 case genai.FunctionCall:
327 id := "call_" + uuid.New().String()
328 args, _ := json.Marshal(p.Args)
329 newCall := message.ToolCall{
330 ID: id,
331 Name: p.Name,
332 Input: string(args),
333 Type: "function",
334 Finished: true,
335 }
336
337 isNew := true
338 for _, existing := range toolCalls {
339 if existing.Name == newCall.Name && existing.Input == newCall.Input {
340 isNew = false
341 break
342 }
343 }
344
345 if isNew {
346 toolCalls = append(toolCalls, newCall)
347 }
348 }
349 }
350 }
351 }
352
353 eventChan <- ProviderEvent{Type: EventContentStop}
354
355 if finalResp != nil {
356
357 finishReason := message.FinishReasonEndTurn
358 if len(finalResp.Candidates) > 0 {
359 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
360 }
361 if len(toolCalls) > 0 {
362 finishReason = message.FinishReasonToolUse
363 }
364 eventChan <- ProviderEvent{
365 Type: EventComplete,
366 Response: &ProviderResponse{
367 Content: currentContent,
368 ToolCalls: toolCalls,
369 Usage: g.usage(finalResp),
370 FinishReason: finishReason,
371 },
372 }
373 return
374 }
375
376 }
377 }()
378
379 return eventChan
380}
381
382func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
383 // Check if error is a rate limit error
384 if attempts > maxRetries {
385 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
386 }
387
388 // Gemini doesn't have a standard error type we can check against
389 // So we'll check the error message for rate limit indicators
390 if errors.Is(err, io.EOF) {
391 return false, 0, err
392 }
393
394 errMsg := err.Error()
395 isRateLimit := false
396
397 // Check for common rate limit error messages
398 if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
399 isRateLimit = true
400 }
401
402 if !isRateLimit {
403 return false, 0, err
404 }
405
406 // Calculate backoff with jitter
407 backoffMs := 2000 * (1 << (attempts - 1))
408 jitterMs := int(float64(backoffMs) * 0.2)
409 retryMs := backoffMs + jitterMs
410
411 return true, int64(retryMs), nil
412}
413
414func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
415 var toolCalls []message.ToolCall
416
417 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
418 for _, part := range resp.Candidates[0].Content.Parts {
419 if funcCall, ok := part.(genai.FunctionCall); ok {
420 id := "call_" + uuid.New().String()
421 args, _ := json.Marshal(funcCall.Args)
422 toolCalls = append(toolCalls, message.ToolCall{
423 ID: id,
424 Name: funcCall.Name,
425 Input: string(args),
426 Type: "function",
427 })
428 }
429 }
430 }
431
432 return toolCalls
433}
434
435func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
436 if resp == nil || resp.UsageMetadata == nil {
437 return TokenUsage{}
438 }
439
440 return TokenUsage{
441 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
442 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
443 CacheCreationTokens: 0, // Not directly provided by Gemini
444 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
445 }
446}
447
448func WithGeminiDisableCache() GeminiOption {
449 return func(options *geminiOptions) {
450 options.disableCache = true
451 }
452}
453
454// Helper functions
// parseJsonToMap decodes a JSON object string into a generic map.
// On decode failure it returns a nil map alongside the error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}
460
461func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
462 properties := make(map[string]*genai.Schema)
463
464 for name, param := range parameters {
465 properties[name] = convertToSchema(param)
466 }
467
468 return properties
469}
470
471func convertToSchema(param interface{}) *genai.Schema {
472 schema := &genai.Schema{Type: genai.TypeString}
473
474 paramMap, ok := param.(map[string]interface{})
475 if !ok {
476 return schema
477 }
478
479 if desc, ok := paramMap["description"].(string); ok {
480 schema.Description = desc
481 }
482
483 typeVal, hasType := paramMap["type"]
484 if !hasType {
485 return schema
486 }
487
488 typeStr, ok := typeVal.(string)
489 if !ok {
490 return schema
491 }
492
493 schema.Type = mapJSONTypeToGenAI(typeStr)
494
495 switch typeStr {
496 case "array":
497 schema.Items = processArrayItems(paramMap)
498 case "object":
499 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
500 schema.Properties = convertSchemaProperties(props)
501 }
502 }
503
504 return schema
505}
506
507func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
508 items, ok := paramMap["items"].(map[string]interface{})
509 if !ok {
510 return nil
511 }
512
513 return convertToSchema(items)
514}
515
516func mapJSONTypeToGenAI(jsonType string) genai.Type {
517 switch jsonType {
518 case "string":
519 return genai.TypeString
520 case "number":
521 return genai.TypeNumber
522 case "integer":
523 return genai.TypeInteger
524 case "boolean":
525 return genai.TypeBoolean
526 case "array":
527 return genai.TypeArray
528 case "object":
529 return genai.TypeObject
530 default:
531 return genai.TypeString // Default to string for unknown types
532 }
533}
534
// contains reports whether s contains any of substrs, case-insensitively.
func contains(s string, substrs ...string) bool {
	// Lowercase the haystack once instead of on every iteration.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}