gemini.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/opencode-ai/opencode/internal/config"
	"github.com/opencode-ai/opencode/internal/llm/tools"
	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/opencode-ai/opencode/internal/message"
	"google.golang.org/genai"
)

type geminiOptions struct {
	disableCache bool
}

type GeminiOption func(*geminiOptions)

type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

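// newGeminiClient builds a Gemini-backed ProviderClient from the shared
// provider options. The underlying genai client is created eagerly; if that
// fails, the error is logged and nil is returned.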
func newGeminiClient(opts providerClientOptions) GeminiClient {
	geminiOpts := geminiOptions{}
	for _, o := range opts.geminiOptions {
		o(&geminiOpts)
	}

	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		logging.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		options:         geminiOpts,
		client:          client,
	}
}

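// convertMessages maps internal chat messages onto genai.Content values:
// user text and attachments become "user" content, assistant text and tool
// calls become "model" content, and tool results are emitted as function
// responses.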
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				// Gemini expects the full IANA MIME type (e.g. "image/png"),
				// not just the subtype.
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJsonToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  "model",
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]interface{}{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

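// convertTools flattens the registered tools into a single genai.Tool whose
// function declarations mirror each tool's name, description, and JSON-schema
// parameters.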
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

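// finishReason translates genai finish reasons into the provider-agnostic
// message.FinishReason values.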
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

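// send issues a single blocking generate call. Earlier messages are replayed
// as chat history and the final message is sent as the new turn; rate-limited
// calls are retried with exponential backoff.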
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content += string(part.Text)
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

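// stream issues a streaming generate call and forwards content deltas, tool
// calls, and the final response over the returned event channel.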
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, chatErr := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		// Surface chat-creation failures instead of dereferencing a nil chat.
		if chatErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: chatErr}
			return
		}

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}
							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// Skip the errored chunk; the outer loop restarts the stream.
							continue
						}
					}
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := string(part.Text)
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

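// shouldRetry reports whether an error looks like a rate-limit failure worth
// retrying and, if so, how long to back off in milliseconds. The backoff is
// exponential with a 20% cushion: attempt 1 waits roughly 2.4s, attempt 2
// roughly 4.8s, and so on.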
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	// Check if error is a rate limit error
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Gemini doesn't have a standard error type we can check against
	// So we'll check the error message for rate limit indicators
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	errMsg := err.Error()
	isRateLimit := false

	// Check for common rate limit error messages
	if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
		isRateLimit = true
	}

	if !isRateLimit {
		return false, 0, err
	}

	// Calculate backoff with jitter
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

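// toolCalls extracts function calls from a response and assigns each one a
// locally generated call ID.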
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			if part.FunctionCall != nil {
				id := "call_" + uuid.New().String()
				args, _ := json.Marshal(part.FunctionCall.Args)
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    id,
					Name:  part.FunctionCall.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	return toolCalls
}

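// usage converts Gemini usage metadata into the provider-agnostic TokenUsage;
// cache creation tokens are not reported by the API.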
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

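// WithGeminiDisableCache returns an option that disables caching for the
// Gemini client.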
func WithGeminiDisableCache() GeminiOption {
	return func(options *geminiOptions) {
		options.disableCache = true
	}
}

// Helper functions
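// parseJsonToMap unmarshals a JSON object string into a generic map.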
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

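// convertSchemaProperties converts a JSON-schema "properties" map into genai
// schema definitions.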
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

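// convertToSchema converts a single JSON-schema property definition into a
// genai.Schema, recursing into array items and nested object properties.
// For example, {"type": "array", "items": {"type": "string"}} becomes a
// TypeArray schema whose Items has TypeString; unknown or missing types fall
// back to TypeString.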
func convertToSchema(param interface{}) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]interface{})
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

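// processArrayItems converts the "items" definition of an array schema, if
// present.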
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
	items, ok := paramMap["items"].(map[string]interface{})
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

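// mapJSONTypeToGenAI maps JSON-schema type names to genai schema types.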
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

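// contains reports whether s contains any of the given substrings,
// case-insensitively.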
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}