1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/google/uuid"
13 "github.com/opencode-ai/opencode/internal/config"
14 "github.com/opencode-ai/opencode/internal/llm/tools"
15 "github.com/opencode-ai/opencode/internal/logging"
16 "github.com/opencode-ai/opencode/internal/message"
17 "google.golang.org/genai"
18)
19
// geminiOptions holds the Gemini-specific settings that GeminiOption
// functions can adjust.
type geminiOptions struct {
	// disableCache is switched on by WithGeminiDisableCache. Nothing in this
	// part of the file reads it.
	disableCache bool
}
23
// GeminiOption mutates a geminiOptions value. newGeminiClient applies every
// configured option, in order, to a fresh geminiOptions when building a client.
type GeminiOption func(*geminiOptions)
25
// geminiClient talks to the Gemini API through the genai SDK.
type geminiClient struct {
	// providerOptions are the options handed to newGeminiClient; send and
	// stream read the model, token limit and system message from them.
	providerOptions providerClientOptions
	// options are the Gemini-specific settings resolved from the GeminiOption
	// list in providerOptions.
	options geminiOptions
	// client is the genai SDK client used to create chats.
	client *genai.Client
}
31
// GeminiClient is the type returned by newGeminiClient. It is declared with
// ProviderClient as its underlying type.
type GeminiClient ProviderClient
33
34func newGeminiClient(opts providerClientOptions) GeminiClient {
35 geminiOpts := geminiOptions{}
36 for _, o := range opts.geminiOptions {
37 o(&geminiOpts)
38 }
39
40 client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
41 if err != nil {
42 logging.Error("Failed to create Gemini client", "error", err)
43 return nil
44 }
45
46 return &geminiClient{
47 providerOptions: opts,
48 options: geminiOpts,
49 client: client,
50 }
51}
52
53func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
54 var history []*genai.Content
55 for _, msg := range messages {
56 switch msg.Role {
57 case message.User:
58 var parts []*genai.Part
59 parts = append(parts, &genai.Part{Text: msg.Content().String()})
60 for _, binaryContent := range msg.BinaryContent() {
61 imageFormat := strings.Split(binaryContent.MIMEType, "/")
62 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
63 MIMEType: imageFormat[1],
64 Data: binaryContent.Data,
65 }})
66 }
67 history = append(history, &genai.Content{
68 Parts: parts,
69 Role: "user",
70 })
71 case message.Assistant:
72 var assistantParts []*genai.Part
73
74 if msg.Content().String() != "" {
75 assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
76 }
77
78 if len(msg.ToolCalls()) > 0 {
79 for _, call := range msg.ToolCalls() {
80 args, _ := parseJsonToMap(call.Input)
81 assistantParts = append(assistantParts, &genai.Part{
82 FunctionCall: &genai.FunctionCall{
83 Name: call.Name,
84 Args: args,
85 },
86 })
87 }
88 }
89
90 if len(assistantParts) > 0 {
91 history = append(history, &genai.Content{
92 Role: "model",
93 Parts: assistantParts,
94 })
95 }
96
97 case message.Tool:
98 for _, result := range msg.ToolResults() {
99 response := map[string]interface{}{"result": result.Content}
100 parsed, err := parseJsonToMap(result.Content)
101 if err == nil {
102 response = parsed
103 }
104
105 var toolCall message.ToolCall
106 for _, m := range messages {
107 if m.Role == message.Assistant {
108 for _, call := range m.ToolCalls() {
109 if call.ID == result.ToolCallID {
110 toolCall = call
111 break
112 }
113 }
114 }
115 }
116
117 history = append(history, &genai.Content{
118 Parts: []*genai.Part{
119 {
120 FunctionResponse: &genai.FunctionResponse{
121 Name: toolCall.Name,
122 Response: response,
123 },
124 },
125 },
126 Role: "function",
127 })
128 }
129 }
130 }
131
132 return history
133}
134
135func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
136 geminiTool := &genai.Tool{}
137 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
138
139 for _, tool := range tools {
140 info := tool.Info()
141 declaration := &genai.FunctionDeclaration{
142 Name: info.Name,
143 Description: info.Description,
144 Parameters: &genai.Schema{
145 Type: genai.TypeObject,
146 Properties: convertSchemaProperties(info.Parameters),
147 Required: info.Required,
148 },
149 }
150
151 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
152 }
153
154 return []*genai.Tool{geminiTool}
155}
156
157func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
158 switch {
159 case reason == genai.FinishReasonStop:
160 return message.FinishReasonEndTurn
161 case reason == genai.FinishReasonMaxTokens:
162 return message.FinishReasonMaxTokens
163 default:
164 return message.FinishReasonUnknown
165 }
166}
167
168func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
169 // Convert messages
170 geminiMessages := g.convertMessages(messages)
171
172 cfg := config.Get()
173 if cfg.Debug {
174 jsonData, _ := json.Marshal(geminiMessages)
175 logging.Debug("Prepared messages", "messages", string(jsonData))
176 }
177
178 history := geminiMessages[:len(geminiMessages)-1] // All but last message
179 lastMsg := geminiMessages[len(geminiMessages)-1]
180 config := &genai.GenerateContentConfig{
181 MaxOutputTokens: int32(g.providerOptions.maxTokens),
182 SystemInstruction: &genai.Content{
183 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
184 },
185 }
186 if len(tools) > 0 {
187 config.Tools = g.convertTools(tools)
188 }
189 chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
190
191 attempts := 0
192 for {
193 attempts++
194 var toolCalls []message.ToolCall
195
196 var lastMsgParts []genai.Part
197 for _, part := range lastMsg.Parts {
198 lastMsgParts = append(lastMsgParts, *part)
199 }
200 resp, err := chat.SendMessage(ctx, lastMsgParts...)
201 // If there is an error we are going to see if we can retry the call
202 if err != nil {
203 retry, after, retryErr := g.shouldRetry(attempts, err)
204 if retryErr != nil {
205 return nil, retryErr
206 }
207 if retry {
208 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
209 select {
210 case <-ctx.Done():
211 return nil, ctx.Err()
212 case <-time.After(time.Duration(after) * time.Millisecond):
213 continue
214 }
215 }
216 return nil, retryErr
217 }
218
219 content := ""
220
221 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
222 for _, part := range resp.Candidates[0].Content.Parts {
223 switch {
224 case part.Text != "":
225 content = string(part.Text)
226 case part.FunctionCall != nil:
227 id := "call_" + uuid.New().String()
228 args, _ := json.Marshal(part.FunctionCall.Args)
229 toolCalls = append(toolCalls, message.ToolCall{
230 ID: id,
231 Name: part.FunctionCall.Name,
232 Input: string(args),
233 Type: "function",
234 Finished: true,
235 })
236 }
237 }
238 }
239 finishReason := message.FinishReasonEndTurn
240 if len(resp.Candidates) > 0 {
241 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
242 }
243 if len(toolCalls) > 0 {
244 finishReason = message.FinishReasonToolUse
245 }
246
247 return &ProviderResponse{
248 Content: content,
249 ToolCalls: toolCalls,
250 Usage: g.usage(resp),
251 FinishReason: finishReason,
252 }, nil
253 }
254}
255
256func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
257 // Convert messages
258 geminiMessages := g.convertMessages(messages)
259
260 cfg := config.Get()
261 if cfg.Debug {
262 jsonData, _ := json.Marshal(geminiMessages)
263 logging.Debug("Prepared messages", "messages", string(jsonData))
264 }
265
266 history := geminiMessages[:len(geminiMessages)-1] // All but last message
267 lastMsg := geminiMessages[len(geminiMessages)-1]
268 config := &genai.GenerateContentConfig{
269 MaxOutputTokens: int32(g.providerOptions.maxTokens),
270 SystemInstruction: &genai.Content{
271 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
272 },
273 }
274 if len(tools) > 0 {
275 config.Tools = g.convertTools(tools)
276 }
277 chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
278
279 attempts := 0
280 eventChan := make(chan ProviderEvent)
281
282 go func() {
283 defer close(eventChan)
284
285 for {
286 attempts++
287
288 currentContent := ""
289 toolCalls := []message.ToolCall{}
290 var finalResp *genai.GenerateContentResponse
291
292 eventChan <- ProviderEvent{Type: EventContentStart}
293
294 var lastMsgParts []genai.Part
295
296 for _, part := range lastMsg.Parts {
297 lastMsgParts = append(lastMsgParts, *part)
298 }
299 for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
300 if err != nil {
301 retry, after, retryErr := g.shouldRetry(attempts, err)
302 if retryErr != nil {
303 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
304 return
305 }
306 if retry {
307 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
308 select {
309 case <-ctx.Done():
310 if ctx.Err() != nil {
311 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
312 }
313
314 return
315 case <-time.After(time.Duration(after) * time.Millisecond):
316 break
317 }
318 } else {
319 eventChan <- ProviderEvent{Type: EventError, Error: err}
320 return
321 }
322 }
323
324 finalResp = resp
325
326 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
327 for _, part := range resp.Candidates[0].Content.Parts {
328 switch {
329 case part.Text != "":
330 delta := string(part.Text)
331 if delta != "" {
332 eventChan <- ProviderEvent{
333 Type: EventContentDelta,
334 Content: delta,
335 }
336 currentContent += delta
337 }
338 case part.FunctionCall != nil:
339 id := "call_" + uuid.New().String()
340 args, _ := json.Marshal(part.FunctionCall.Args)
341 newCall := message.ToolCall{
342 ID: id,
343 Name: part.FunctionCall.Name,
344 Input: string(args),
345 Type: "function",
346 Finished: true,
347 }
348
349 isNew := true
350 for _, existing := range toolCalls {
351 if existing.Name == newCall.Name && existing.Input == newCall.Input {
352 isNew = false
353 break
354 }
355 }
356
357 if isNew {
358 toolCalls = append(toolCalls, newCall)
359 }
360 }
361 }
362 }
363 }
364
365 eventChan <- ProviderEvent{Type: EventContentStop}
366
367 if finalResp != nil {
368
369 finishReason := message.FinishReasonEndTurn
370 if len(finalResp.Candidates) > 0 {
371 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
372 }
373 if len(toolCalls) > 0 {
374 finishReason = message.FinishReasonToolUse
375 }
376 eventChan <- ProviderEvent{
377 Type: EventComplete,
378 Response: &ProviderResponse{
379 Content: currentContent,
380 ToolCalls: toolCalls,
381 Usage: g.usage(finalResp),
382 FinishReason: finishReason,
383 },
384 }
385 return
386 }
387
388 }
389 }()
390
391 return eventChan
392}
393
394func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
395 // Check if error is a rate limit error
396 if attempts > maxRetries {
397 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
398 }
399
400 // Gemini doesn't have a standard error type we can check against
401 // So we'll check the error message for rate limit indicators
402 if errors.Is(err, io.EOF) {
403 return false, 0, err
404 }
405
406 errMsg := err.Error()
407 isRateLimit := false
408
409 // Check for common rate limit error messages
410 if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
411 isRateLimit = true
412 }
413
414 if !isRateLimit {
415 return false, 0, err
416 }
417
418 // Calculate backoff with jitter
419 backoffMs := 2000 * (1 << (attempts - 1))
420 jitterMs := int(float64(backoffMs) * 0.2)
421 retryMs := backoffMs + jitterMs
422
423 return true, int64(retryMs), nil
424}
425
426func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
427 var toolCalls []message.ToolCall
428
429 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
430 for _, part := range resp.Candidates[0].Content.Parts {
431 if part.FunctionCall != nil {
432 id := "call_" + uuid.New().String()
433 args, _ := json.Marshal(part.FunctionCall.Args)
434 toolCalls = append(toolCalls, message.ToolCall{
435 ID: id,
436 Name: part.FunctionCall.Name,
437 Input: string(args),
438 Type: "function",
439 })
440 }
441 }
442 }
443
444 return toolCalls
445}
446
447func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
448 if resp == nil || resp.UsageMetadata == nil {
449 return TokenUsage{}
450 }
451
452 return TokenUsage{
453 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
454 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
455 CacheCreationTokens: 0, // Not directly provided by Gemini
456 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
457 }
458}
459
460func WithGeminiDisableCache() GeminiOption {
461 return func(options *geminiOptions) {
462 options.disableCache = true
463 }
464}
465
466// Helper functions
// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
//
// It returns the decoding error when jsonStr is not valid JSON or is not an
// object. The JSON literal null decodes without error into a nil map.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
472
473func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
474 properties := make(map[string]*genai.Schema)
475
476 for name, param := range parameters {
477 properties[name] = convertToSchema(param)
478 }
479
480 return properties
481}
482
483func convertToSchema(param interface{}) *genai.Schema {
484 schema := &genai.Schema{Type: genai.TypeString}
485
486 paramMap, ok := param.(map[string]interface{})
487 if !ok {
488 return schema
489 }
490
491 if desc, ok := paramMap["description"].(string); ok {
492 schema.Description = desc
493 }
494
495 typeVal, hasType := paramMap["type"]
496 if !hasType {
497 return schema
498 }
499
500 typeStr, ok := typeVal.(string)
501 if !ok {
502 return schema
503 }
504
505 schema.Type = mapJSONTypeToGenAI(typeStr)
506
507 switch typeStr {
508 case "array":
509 schema.Items = processArrayItems(paramMap)
510 case "object":
511 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
512 schema.Properties = convertSchemaProperties(props)
513 }
514 }
515
516 return schema
517}
518
519func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
520 items, ok := paramMap["items"].(map[string]interface{})
521 if !ok {
522 return nil
523 }
524
525 return convertToSchema(items)
526}
527
528func mapJSONTypeToGenAI(jsonType string) genai.Type {
529 switch jsonType {
530 case "string":
531 return genai.TypeString
532 case "number":
533 return genai.TypeNumber
534 case "integer":
535 return genai.TypeInteger
536 case "boolean":
537 return genai.TypeBoolean
538 case "array":
539 return genai.TypeArray
540 case "object":
541 return genai.TypeObject
542 default:
543 return genai.TypeString // Default to string for unknown types
544 }
545}
546
// contains reports whether s contains at least one of substrs, ignoring case.
// It returns false when no substrings are given; an empty substring always
// matches.
func contains(s string, substrs ...string) bool {
	// Lower-case the haystack once instead of once per candidate.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}