1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/google/uuid"
13 "github.com/opencode-ai/opencode/internal/config"
14 "github.com/opencode-ai/opencode/internal/llm/tools"
15 "github.com/opencode-ai/opencode/internal/logging"
16 "github.com/opencode-ai/opencode/internal/message"
17 "google.golang.org/genai"
18)
19
20type geminiOptions struct {
21 disableCache bool
22}
23
24type GeminiOption func(*geminiOptions)
25
26type geminiClient struct {
27 providerOptions providerClientOptions
28 options geminiOptions
29 client *genai.Client
30}
31
32type GeminiClient ProviderClient
33
34func newGeminiClient(opts providerClientOptions) GeminiClient {
35 geminiOpts := geminiOptions{}
36 for _, o := range opts.geminiOptions {
37 o(&geminiOpts)
38 }
39
40 client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
41 if err != nil {
42 logging.Error("Failed to create Gemini client", "error", err)
43 return nil
44 }
45
46 return &geminiClient{
47 providerOptions: opts,
48 options: geminiOpts,
49 client: client,
50 }
51}
52
53func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
54 var history []*genai.Content
55 for _, msg := range messages {
56 switch msg.Role {
57 case message.User:
58 var parts []*genai.Part
59 parts = append(parts, &genai.Part{Text: msg.Content().String()})
60 for _, binaryContent := range msg.BinaryContent() {
61 imageFormat := strings.Split(binaryContent.MIMEType, "/")
62 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
63 MIMEType: imageFormat[1],
64 Data: binaryContent.Data,
65 }})
66 }
67 history = append(history, &genai.Content{
68 Parts: parts,
69 Role: "user",
70 })
71 case message.Assistant:
72 content := &genai.Content{
73 Role: "model",
74 Parts: []*genai.Part{},
75 }
76
77 if msg.Content().String() != "" {
78 content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
79 }
80
81 if len(msg.ToolCalls()) > 0 {
82 for _, call := range msg.ToolCalls() {
83 args, _ := parseJsonToMap(call.Input)
84 content.Parts = append(content.Parts, &genai.Part{
85 FunctionCall: &genai.FunctionCall{
86 Name: call.Name,
87 Args: args,
88 },
89 })
90 }
91 }
92
93 history = append(history, content)
94
95 case message.Tool:
96 for _, result := range msg.ToolResults() {
97 response := map[string]interface{}{"result": result.Content}
98 parsed, err := parseJsonToMap(result.Content)
99 if err == nil {
100 response = parsed
101 }
102
103 var toolCall message.ToolCall
104 for _, m := range messages {
105 if m.Role == message.Assistant {
106 for _, call := range m.ToolCalls() {
107 if call.ID == result.ToolCallID {
108 toolCall = call
109 break
110 }
111 }
112 }
113 }
114
115 history = append(history, &genai.Content{
116 Parts: []*genai.Part{
117 {
118 FunctionResponse: &genai.FunctionResponse{
119 Name: toolCall.Name,
120 Response: response,
121 },
122 },
123 },
124 Role: "function",
125 })
126 }
127 }
128 }
129
130 return history
131}
132
133func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
134 geminiTool := &genai.Tool{}
135 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
136
137 for _, tool := range tools {
138 info := tool.Info()
139 declaration := &genai.FunctionDeclaration{
140 Name: info.Name,
141 Description: info.Description,
142 Parameters: &genai.Schema{
143 Type: genai.TypeObject,
144 Properties: convertSchemaProperties(info.Parameters),
145 Required: info.Required,
146 },
147 }
148
149 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
150 }
151
152 return []*genai.Tool{geminiTool}
153}
154
155func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
156 switch {
157 case reason == genai.FinishReasonStop:
158 return message.FinishReasonEndTurn
159 case reason == genai.FinishReasonMaxTokens:
160 return message.FinishReasonMaxTokens
161 default:
162 return message.FinishReasonUnknown
163 }
164}
165
// send issues one non-streaming request to Gemini and returns the complete
// response.
//
// All converted messages except the last become chat history, and the last
// one is sent as the new turn. Rate-limit errors are retried with the delay
// chosen by shouldRetry. Any other error, or an exhausted retry budget, is
// returned to the caller.
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	chat, parts, err := g.prepareChat(ctx, messages, tools)
	if err != nil {
		return nil, err
	}

	for attempts := 1; ; attempts++ {
		resp, err := chat.SendMessage(ctx, parts...)
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if !retry {
				return nil, err
			}
			logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(time.Duration(after) * time.Millisecond):
				continue
			}
		}

		// Gemini may split the reply across several text parts; keep all of them.
		var content strings.Builder
		var toolCalls []message.ToolCall
		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content.WriteString(part.Text)
				case part.FunctionCall != nil:
					toolCalls = append(toolCalls, g.functionCallToToolCall(part.FunctionCall))
				}
			}
		}

		return &ProviderResponse{
			Content:      content.String(),
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: g.responseFinishReason(resp, len(toolCalls) > 0),
		}, nil
	}
}

// stream sends the conversation to Gemini and reports the reply as
// ProviderEvents on the returned channel. The channel is closed when the
// worker goroutine finishes.
//
// A successful attempt emits EventContentStart, zero or more
// EventContentDelta, EventContentStop and finally EventComplete. Any failure
// emits a single EventError and ends the stream. Rate-limit errors restart the
// whole request, beginning with a new EventContentStart; deltas already sent
// by the failed attempt are not retracted.
//
// The channel is unbuffered, so the goroutine blocks until each event is read.
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Prepare synchronously so messages are not read after stream returns;
	// setup failures are delivered on the channel.
	chat, parts, setupErr := g.prepareChat(ctx, messages, tools)

	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		if setupErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: setupErr}
			return
		}

		for attempts := 1; ; attempts++ {
			eventChan <- ProviderEvent{Type: EventContentStart}

			var (
				currentContent strings.Builder
				finalResp      *genai.GenerateContentResponse
				streamErr      error
			)
			toolCalls := []message.ToolCall{}

			for resp, err := range chat.SendMessageStream(ctx, parts...) {
				if err != nil {
					// Leave the iterator first; retry handling happens below.
					// (A bare break inside a select would only exit the select.)
					streamErr = err
					break
				}
				if resp == nil {
					continue
				}

				finalResp = resp
				if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil {
					continue
				}
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: part.Text,
						}
						currentContent.WriteString(part.Text)
					case part.FunctionCall != nil:
						// The same call can appear in several chunks; keep one copy
						// per distinct name+input.
						call := g.functionCallToToolCall(part.FunctionCall)
						if !geminiContainsToolCall(toolCalls, call) {
							toolCalls = append(toolCalls, call)
						}
					}
				}
			}

			if streamErr != nil {
				retry, after, retryErr := g.shouldRetry(attempts, streamErr)
				if retryErr != nil {
					eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
					return
				}
				if !retry {
					eventChan <- ProviderEvent{Type: EventError, Error: streamErr}
					return
				}
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			// A stream that produced no chunks is reported instead of re-sent.
			if finalResp == nil {
				eventChan <- ProviderEvent{Type: EventError, Error: errors.New("gemini returned an empty response stream")}
				return
			}

			eventChan <- ProviderEvent{
				Type: EventComplete,
				Response: &ProviderResponse{
					Content:      currentContent.String(),
					ToolCalls:    toolCalls,
					Usage:        g.usage(finalResp),
					FinishReason: g.responseFinishReason(finalResp, len(toolCalls) > 0),
				},
			}
			return
		}
	}()

	return eventChan
}

// prepareChat converts messages, opens a Gemini chat session and returns it
// together with the parts of the final message, which is the turn to send.
// All earlier converted messages are installed as the session history. It
// fails when there is nothing to send or when the session cannot be created.
func (g *geminiClient) prepareChat(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*genai.Chat, []genai.Part, error) {
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		// Debug-only output; a marshal failure just logs an empty payload.
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	if len(geminiMessages) == 0 {
		return nil, nil, errors.New("gemini: no messages to send")
	}
	history := geminiMessages[:len(geminiMessages)-1]
	lastMsg := geminiMessages[len(geminiMessages)-1]

	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		genConfig.Tools = g.convertTools(tools)
	}

	chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, genConfig, history)
	if err != nil {
		return nil, nil, fmt.Errorf("creating gemini chat: %w", err)
	}

	// SendMessage/SendMessageStream take parts by value.
	parts := make([]genai.Part, 0, len(lastMsg.Parts))
	for _, part := range lastMsg.Parts {
		parts = append(parts, *part)
	}
	return chat, parts, nil
}

// functionCallToToolCall converts a Gemini FunctionCall into a finished
// message.ToolCall with a locally generated "call_<uuid>" ID. Args are
// re-encoded as JSON; a marshal error would leave Input empty.
func (g *geminiClient) functionCallToToolCall(fc *genai.FunctionCall) message.ToolCall {
	args, _ := json.Marshal(fc.Args)
	return message.ToolCall{
		ID:       "call_" + uuid.New().String(),
		Name:     fc.Name,
		Input:    string(args),
		Type:     "function",
		Finished: true,
	}
}

// responseFinishReason derives the finish reason for resp. It is
// FinishReasonToolUse whenever tool calls were produced, otherwise the first
// candidate's mapped reason, and FinishReasonEndTurn when there is no
// candidate.
func (g *geminiClient) responseFinishReason(resp *genai.GenerateContentResponse, hasToolCalls bool) message.FinishReason {
	if hasToolCalls {
		return message.FinishReasonToolUse
	}
	if len(resp.Candidates) > 0 {
		return g.finishReason(resp.Candidates[0].FinishReason)
	}
	return message.FinishReasonEndTurn
}

// geminiContainsToolCall reports whether calls already holds a call with the
// same Name and Input as call. IDs are ignored because they are generated
// per chunk.
func geminiContainsToolCall(calls []message.ToolCall, call message.ToolCall) bool {
	for _, existing := range calls {
		if existing.Name == call.Name && existing.Input == call.Input {
			return true
		}
	}
	return false
}
391
// shouldRetry decides whether a failed Gemini request should be retried.
//
// It returns:
//   - (true, delayMs, nil) for a rate-limit error while attempts <= maxRetries;
//   - (false, 0, err) for io.EOF or any error that is not a rate limit;
//   - (false, 0, wrapped) for a rate-limit error once attempts > maxRetries.
//     The original error is wrapped with %w so callers can still inspect it.
//
// The delay starts at 2000 ms, doubles with each attempt and adds a fixed 20%
// margin (2400, 4800, 9600, ... ms).
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	// io.EOF marks the end of a response, not a transient failure.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// Rate limits are detected from the error text. "resource_exhausted"
	// matches the status Gemini reports for HTTP 429 responses.
	if !contains(err.Error(), "rate limit", "quota exceeded", "too many requests", "resource_exhausted") {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries: %w", maxRetries, err)
	}

	// Clamp so a zero or negative attempt count cannot produce a negative
	// shift, which would panic at run time.
	shift := max(attempts-1, 0)
	backoffMs := 2000 * (1 << shift)
	marginMs := backoffMs / 5 // fixed 20% padding; not randomized
	return true, int64(backoffMs + marginMs), nil
}
423
// toolCalls extracts every FunctionCall part of the first candidate in resp.
// Each becomes a message.ToolCall with a freshly generated "call_<uuid>" ID
// and its Args re-encoded as JSON. It returns nil when resp is nil or has no
// candidate content.
//
// NOTE(review): unlike the tool calls built in send/stream, these are not
// marked Finished, and nothing in this file calls this method. Confirm
// whether it is still needed.
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
	if resp == nil || len(resp.Candidates) == 0 {
		return nil
	}
	candidate := resp.Candidates[0]
	if candidate == nil || candidate.Content == nil {
		return nil
	}

	var toolCalls []message.ToolCall
	for _, part := range candidate.Content.Parts {
		if part == nil || part.FunctionCall == nil {
			continue
		}
		// Args came from decoded JSON; a marshal error would leave Input empty.
		args, _ := json.Marshal(part.FunctionCall.Args)
		toolCalls = append(toolCalls, message.ToolCall{
			ID:    "call_" + uuid.New().String(),
			Name:  part.FunctionCall.Name,
			Input: string(args),
			Type:  "function",
		})
	}

	return toolCalls
}
444
445func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
446 if resp == nil || resp.UsageMetadata == nil {
447 return TokenUsage{}
448 }
449
450 return TokenUsage{
451 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
452 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
453 CacheCreationTokens: 0, // Not directly provided by Gemini
454 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
455 }
456}
457
458func WithGeminiDisableCache() GeminiOption {
459 return func(options *geminiOptions) {
460 options.disableCache = true
461 }
462}
463
464// Helper functions
// parseJsonToMap decodes jsonStr, which is expected to be a JSON object, into
// a map. The JSON literal null decodes to a nil map with no error. Input that
// is invalid or is not an object returns the json.Unmarshal error.
func parseJsonToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
470
471func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
472 properties := make(map[string]*genai.Schema)
473
474 for name, param := range parameters {
475 properties[name] = convertToSchema(param)
476 }
477
478 return properties
479}
480
481func convertToSchema(param interface{}) *genai.Schema {
482 schema := &genai.Schema{Type: genai.TypeString}
483
484 paramMap, ok := param.(map[string]interface{})
485 if !ok {
486 return schema
487 }
488
489 if desc, ok := paramMap["description"].(string); ok {
490 schema.Description = desc
491 }
492
493 typeVal, hasType := paramMap["type"]
494 if !hasType {
495 return schema
496 }
497
498 typeStr, ok := typeVal.(string)
499 if !ok {
500 return schema
501 }
502
503 schema.Type = mapJSONTypeToGenAI(typeStr)
504
505 switch typeStr {
506 case "array":
507 schema.Items = processArrayItems(paramMap)
508 case "object":
509 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
510 schema.Properties = convertSchemaProperties(props)
511 }
512 }
513
514 return schema
515}
516
517func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
518 items, ok := paramMap["items"].(map[string]interface{})
519 if !ok {
520 return nil
521 }
522
523 return convertToSchema(items)
524}
525
526func mapJSONTypeToGenAI(jsonType string) genai.Type {
527 switch jsonType {
528 case "string":
529 return genai.TypeString
530 case "number":
531 return genai.TypeNumber
532 case "integer":
533 return genai.TypeInteger
534 case "boolean":
535 return genai.TypeBoolean
536 case "array":
537 return genai.TypeArray
538 case "object":
539 return genai.TypeObject
540 default:
541 return genai.TypeString // Default to string for unknown types
542 }
543}
544
// contains reports whether s contains any of substrs, compared
// case-insensitively via strings.ToLower. It returns false when no substrs
// are given; an empty substring always matches.
func contains(s string, substrs ...string) bool {
	// Lower-case the haystack once rather than once per candidate substring.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}