gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/google/generative-ai-go/genai"
 13	"github.com/google/uuid"
 14	"github.com/opencode-ai/opencode/internal/config"
 15	"github.com/opencode-ai/opencode/internal/llm/tools"
 16	"github.com/opencode-ai/opencode/internal/logging"
 17	"github.com/opencode-ai/opencode/internal/message"
 18	"google.golang.org/api/iterator"
 19	"google.golang.org/api/option"
 20)
 21
// geminiOptions holds Gemini-specific client configuration, populated via
// GeminiOption functional options.
type geminiOptions struct {
	disableCache bool // set by WithGeminiDisableCache; disables context caching
}
 25
// GeminiOption mutates geminiOptions; used as a functional option when constructing the Gemini client.
type GeminiOption func(*geminiOptions)
 27
// geminiClient implements the provider client against Google's Gemini API
// via the generative-ai-go SDK.
type geminiClient struct {
	providerOptions providerClientOptions // shared provider config (model, API key, system message, ...)
	options         geminiOptions         // Gemini-specific options
	client          *genai.Client         // underlying SDK client
}
 33
// GeminiClient is the Gemini-backed implementation of ProviderClient.
type GeminiClient ProviderClient
 35
 36func newGeminiClient(opts providerClientOptions) GeminiClient {
 37	geminiOpts := geminiOptions{}
 38	for _, o := range opts.geminiOptions {
 39		o(&geminiOpts)
 40	}
 41
 42	client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
 43	if err != nil {
 44		logging.Error("Failed to create Gemini client", "error", err)
 45		return nil
 46	}
 47
 48	return &geminiClient{
 49		providerOptions: opts,
 50		options:         geminiOpts,
 51		client:          client,
 52	}
 53}
 54
 55func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 56	var history []*genai.Content
 57
 58	// Add system message first
 59	history = append(history, &genai.Content{
 60		Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
 61		Role:  "user",
 62	})
 63
 64	// Add a system response to acknowledge the system message
 65	history = append(history, &genai.Content{
 66		Parts: []genai.Part{genai.Text("I'll help you with that.")},
 67		Role:  "model",
 68	})
 69
 70	for _, msg := range messages {
 71		switch msg.Role {
 72		case message.User:
 73			history = append(history, &genai.Content{
 74				Parts: []genai.Part{genai.Text(msg.Content().String())},
 75				Role:  "user",
 76			})
 77
 78		case message.Assistant:
 79			content := &genai.Content{
 80				Role:  "model",
 81				Parts: []genai.Part{},
 82			}
 83
 84			if msg.Content().String() != "" {
 85				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
 86			}
 87
 88			if len(msg.ToolCalls()) > 0 {
 89				for _, call := range msg.ToolCalls() {
 90					args, _ := parseJsonToMap(call.Input)
 91					content.Parts = append(content.Parts, genai.FunctionCall{
 92						Name: call.Name,
 93						Args: args,
 94					})
 95				}
 96			}
 97
 98			history = append(history, content)
 99
100		case message.Tool:
101			for _, result := range msg.ToolResults() {
102				response := map[string]interface{}{"result": result.Content}
103				parsed, err := parseJsonToMap(result.Content)
104				if err == nil {
105					response = parsed
106				}
107
108				var toolCall message.ToolCall
109				for _, m := range messages {
110					if m.Role == message.Assistant {
111						for _, call := range m.ToolCalls() {
112							if call.ID == result.ToolCallID {
113								toolCall = call
114								break
115							}
116						}
117					}
118				}
119
120				history = append(history, &genai.Content{
121					Parts: []genai.Part{genai.FunctionResponse{
122						Name:     toolCall.Name,
123						Response: response,
124					}},
125					Role: "function",
126				})
127			}
128		}
129	}
130
131	return history
132}
133
134func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
135	geminiTool := &genai.Tool{}
136	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
137
138	for _, tool := range tools {
139		info := tool.Info()
140		declaration := &genai.FunctionDeclaration{
141			Name:        info.Name,
142			Description: info.Description,
143			Parameters: &genai.Schema{
144				Type:       genai.TypeObject,
145				Properties: convertSchemaProperties(info.Parameters),
146				Required:   info.Required,
147			},
148		}
149
150		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
151	}
152
153	return []*genai.Tool{geminiTool}
154}
155
156func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
157	reasonStr := reason.String()
158	switch {
159	case reasonStr == "STOP":
160		return message.FinishReasonEndTurn
161	case reasonStr == "MAX_TOKENS":
162		return message.FinishReasonMaxTokens
163	case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
164		return message.FinishReasonToolUse
165	default:
166		return message.FinishReasonUnknown
167	}
168}
169
170func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
171	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
172	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
173
174	// Convert tools
175	if len(tools) > 0 {
176		model.Tools = g.convertTools(tools)
177	}
178
179	// Convert messages
180	geminiMessages := g.convertMessages(messages)
181
182	cfg := config.Get()
183	if cfg.Debug {
184		jsonData, _ := json.Marshal(geminiMessages)
185		logging.Debug("Prepared messages", "messages", string(jsonData))
186	}
187
188	attempts := 0
189	for {
190		attempts++
191		chat := model.StartChat()
192		chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
193
194		lastMsg := geminiMessages[len(geminiMessages)-1]
195		var lastText string
196		for _, part := range lastMsg.Parts {
197			if text, ok := part.(genai.Text); ok {
198				lastText = string(text)
199				break
200			}
201		}
202
203		resp, err := chat.SendMessage(ctx, genai.Text(lastText))
204		// If there is an error we are going to see if we can retry the call
205		if err != nil {
206			retry, after, retryErr := g.shouldRetry(attempts, err)
207			if retryErr != nil {
208				return nil, retryErr
209			}
210			if retry {
211				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
212				select {
213				case <-ctx.Done():
214					return nil, ctx.Err()
215				case <-time.After(time.Duration(after) * time.Millisecond):
216					continue
217				}
218			}
219			return nil, retryErr
220		}
221
222		content := ""
223		var toolCalls []message.ToolCall
224
225		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
226			for _, part := range resp.Candidates[0].Content.Parts {
227				switch p := part.(type) {
228				case genai.Text:
229					content = string(p)
230				case genai.FunctionCall:
231					id := "call_" + uuid.New().String()
232					args, _ := json.Marshal(p.Args)
233					toolCalls = append(toolCalls, message.ToolCall{
234						ID:    id,
235						Name:  p.Name,
236						Input: string(args),
237						Type:  "function",
238					})
239				}
240			}
241		}
242
243		return &ProviderResponse{
244			Content:      content,
245			ToolCalls:    toolCalls,
246			Usage:        g.usage(resp),
247			FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
248		}, nil
249	}
250}
251
252func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
253	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
254	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
255
256	// Convert tools
257	if len(tools) > 0 {
258		model.Tools = g.convertTools(tools)
259	}
260
261	// Convert messages
262	geminiMessages := g.convertMessages(messages)
263
264	cfg := config.Get()
265	if cfg.Debug {
266		jsonData, _ := json.Marshal(geminiMessages)
267		logging.Debug("Prepared messages", "messages", string(jsonData))
268	}
269
270	attempts := 0
271	eventChan := make(chan ProviderEvent)
272
273	go func() {
274		defer close(eventChan)
275
276		for {
277			attempts++
278			chat := model.StartChat()
279			chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
280
281			lastMsg := geminiMessages[len(geminiMessages)-1]
282			var lastText string
283			for _, part := range lastMsg.Parts {
284				if text, ok := part.(genai.Text); ok {
285					lastText = string(text)
286					break
287				}
288			}
289
290			iter := chat.SendMessageStream(ctx, genai.Text(lastText))
291
292			currentContent := ""
293			toolCalls := []message.ToolCall{}
294			var finalResp *genai.GenerateContentResponse
295
296			eventChan <- ProviderEvent{Type: EventContentStart}
297
298			for {
299				resp, err := iter.Next()
300				if err == iterator.Done {
301					break
302				}
303				if err != nil {
304					retry, after, retryErr := g.shouldRetry(attempts, err)
305					if retryErr != nil {
306						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
307						return
308					}
309					if retry {
310						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
311						select {
312						case <-ctx.Done():
313							if ctx.Err() != nil {
314								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
315							}
316
317							return
318						case <-time.After(time.Duration(after) * time.Millisecond):
319							break
320						}
321					} else {
322						eventChan <- ProviderEvent{Type: EventError, Error: err}
323						return
324					}
325				}
326
327				finalResp = resp
328
329				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
330					for _, part := range resp.Candidates[0].Content.Parts {
331						switch p := part.(type) {
332						case genai.Text:
333							newText := string(p)
334							delta := newText[len(currentContent):]
335							if delta != "" {
336								eventChan <- ProviderEvent{
337									Type:    EventContentDelta,
338									Content: delta,
339								}
340								currentContent = newText
341							}
342						case genai.FunctionCall:
343							id := "call_" + uuid.New().String()
344							args, _ := json.Marshal(p.Args)
345							newCall := message.ToolCall{
346								ID:    id,
347								Name:  p.Name,
348								Input: string(args),
349								Type:  "function",
350							}
351
352							isNew := true
353							for _, existing := range toolCalls {
354								if existing.Name == newCall.Name && existing.Input == newCall.Input {
355									isNew = false
356									break
357								}
358							}
359
360							if isNew {
361								toolCalls = append(toolCalls, newCall)
362							}
363						}
364					}
365				}
366			}
367
368			eventChan <- ProviderEvent{Type: EventContentStop}
369
370			if finalResp != nil {
371				eventChan <- ProviderEvent{
372					Type: EventComplete,
373					Response: &ProviderResponse{
374						Content:      currentContent,
375						ToolCalls:    toolCalls,
376						Usage:        g.usage(finalResp),
377						FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
378					},
379				}
380				return
381			}
382
383			// If we get here, we need to retry
384			if attempts > maxRetries {
385				eventChan <- ProviderEvent{
386					Type:  EventError,
387					Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
388				}
389				return
390			}
391
392			// Wait before retrying
393			select {
394			case <-ctx.Done():
395				if ctx.Err() != nil {
396					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
397				}
398				return
399			case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
400				continue
401			}
402		}
403	}()
404
405	return eventChan
406}
407
408func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
409	// Check if error is a rate limit error
410	if attempts > maxRetries {
411		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
412	}
413
414	// Gemini doesn't have a standard error type we can check against
415	// So we'll check the error message for rate limit indicators
416	if errors.Is(err, io.EOF) {
417		return false, 0, err
418	}
419
420	errMsg := err.Error()
421	isRateLimit := false
422
423	// Check for common rate limit error messages
424	if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
425		isRateLimit = true
426	}
427
428	if !isRateLimit {
429		return false, 0, err
430	}
431
432	// Calculate backoff with jitter
433	backoffMs := 2000 * (1 << (attempts - 1))
434	jitterMs := int(float64(backoffMs) * 0.2)
435	retryMs := backoffMs + jitterMs
436
437	return true, int64(retryMs), nil
438}
439
440func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
441	var toolCalls []message.ToolCall
442
443	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
444		for _, part := range resp.Candidates[0].Content.Parts {
445			if funcCall, ok := part.(genai.FunctionCall); ok {
446				id := "call_" + uuid.New().String()
447				args, _ := json.Marshal(funcCall.Args)
448				toolCalls = append(toolCalls, message.ToolCall{
449					ID:    id,
450					Name:  funcCall.Name,
451					Input: string(args),
452					Type:  "function",
453				})
454			}
455		}
456	}
457
458	return toolCalls
459}
460
461func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
462	if resp == nil || resp.UsageMetadata == nil {
463		return TokenUsage{}
464	}
465
466	return TokenUsage{
467		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
468		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
469		CacheCreationTokens: 0, // Not directly provided by Gemini
470		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
471	}
472}
473
474func WithGeminiDisableCache() GeminiOption {
475	return func(options *geminiOptions) {
476		options.disableCache = true
477	}
478}
479
480// Helper functions
// parseJsonToMap decodes a JSON object string into a generic map. On decode
// failure the (possibly nil) map is returned alongside the error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		return parsed, err
	}
	return parsed, nil
}
486
487func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
488	properties := make(map[string]*genai.Schema)
489
490	for name, param := range parameters {
491		properties[name] = convertToSchema(param)
492	}
493
494	return properties
495}
496
497func convertToSchema(param interface{}) *genai.Schema {
498	schema := &genai.Schema{Type: genai.TypeString}
499
500	paramMap, ok := param.(map[string]interface{})
501	if !ok {
502		return schema
503	}
504
505	if desc, ok := paramMap["description"].(string); ok {
506		schema.Description = desc
507	}
508
509	typeVal, hasType := paramMap["type"]
510	if !hasType {
511		return schema
512	}
513
514	typeStr, ok := typeVal.(string)
515	if !ok {
516		return schema
517	}
518
519	schema.Type = mapJSONTypeToGenAI(typeStr)
520
521	switch typeStr {
522	case "array":
523		schema.Items = processArrayItems(paramMap)
524	case "object":
525		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
526			schema.Properties = convertSchemaProperties(props)
527		}
528	}
529
530	return schema
531}
532
533func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
534	items, ok := paramMap["items"].(map[string]interface{})
535	if !ok {
536		return nil
537	}
538
539	return convertToSchema(items)
540}
541
542func mapJSONTypeToGenAI(jsonType string) genai.Type {
543	switch jsonType {
544	case "string":
545		return genai.TypeString
546	case "number":
547		return genai.TypeNumber
548	case "integer":
549		return genai.TypeInteger
550	case "boolean":
551		return genai.TypeBoolean
552	case "array":
553		return genai.TypeArray
554	case "object":
555		return genai.TypeObject
556	default:
557		return genai.TypeString // Default to string for unknown types
558	}
559}
560
// contains reports whether s contains any of substrs, case-insensitively.
// Used to sniff rate-limit indicators in error messages.
func contains(s string, substrs ...string) bool {
	// Lowercase the haystack once instead of on every iteration.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}