gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/google/generative-ai-go/genai"
 13	"github.com/google/uuid"
 14	"github.com/opencode-ai/opencode/internal/config"
 15	"github.com/opencode-ai/opencode/internal/llm/tools"
 16	"github.com/opencode-ai/opencode/internal/logging"
 17	"github.com/opencode-ai/opencode/internal/message"
 18	"google.golang.org/api/iterator"
 19	"google.golang.org/api/option"
 20)
 21
// geminiOptions holds Gemini-specific settings collected by applying
// GeminiOption functional options in newGeminiClient.
type geminiOptions struct {
	// disableCache is set by WithGeminiDisableCache. Nothing in this file
	// reads it — presumably consumed elsewhere in the package; verify.
	disableCache bool
}

// GeminiOption mutates geminiOptions; options are applied in order.
type GeminiOption func(*geminiOptions)
 27
// geminiClient implements the provider client on top of Google's
// generative-ai-go SDK.
type geminiClient struct {
	providerOptions providerClientOptions // shared provider config (model, apiKey, maxTokens, system message, ...)
	options         geminiOptions         // Gemini-specific options
	client          *genai.Client         // underlying SDK client
}

// GeminiClient is the Gemini-backed ProviderClient.
type GeminiClient ProviderClient
 35
// newGeminiClient builds a geminiClient from the given provider options.
// It applies all Gemini functional options, then constructs the SDK client
// with the configured API key. On failure it logs the error and returns nil
// — no error value is surfaced, so callers must nil-check the result.
func newGeminiClient(opts providerClientOptions) GeminiClient {
	geminiOpts := geminiOptions{}
	for _, o := range opts.geminiOptions {
		o(&geminiOpts)
	}

	// NOTE(review): context.Background() is used because no caller-supplied
	// context reaches client construction here.
	client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
	if err != nil {
		logging.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		options:         geminiOpts,
		client:          client,
	}
}
 54
 55func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 56	var history []*genai.Content
 57
 58	// Add system message first
 59	history = append(history, &genai.Content{
 60		Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
 61		Role:  "user",
 62	})
 63
 64	// Add a system response to acknowledge the system message
 65	history = append(history, &genai.Content{
 66		Parts: []genai.Part{genai.Text("I'll help you with that.")},
 67		Role:  "model",
 68	})
 69
 70	for _, msg := range messages {
 71		switch msg.Role {
 72		case message.User:
 73			history = append(history, &genai.Content{
 74				Parts: []genai.Part{genai.Text(msg.Content().String())},
 75				Role:  "user",
 76			})
 77
 78		case message.Assistant:
 79			content := &genai.Content{
 80				Role:  "model",
 81				Parts: []genai.Part{},
 82			}
 83
 84			if msg.Content().String() != "" {
 85				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
 86			}
 87
 88			if len(msg.ToolCalls()) > 0 {
 89				for _, call := range msg.ToolCalls() {
 90					args, _ := parseJsonToMap(call.Input)
 91					content.Parts = append(content.Parts, genai.FunctionCall{
 92						Name: call.Name,
 93						Args: args,
 94					})
 95				}
 96			}
 97
 98			history = append(history, content)
 99
100		case message.Tool:
101			for _, result := range msg.ToolResults() {
102				response := map[string]interface{}{"result": result.Content}
103				parsed, err := parseJsonToMap(result.Content)
104				if err == nil {
105					response = parsed
106				}
107
108				var toolCall message.ToolCall
109				for _, m := range messages {
110					if m.Role == message.Assistant {
111						for _, call := range m.ToolCalls() {
112							if call.ID == result.ToolCallID {
113								toolCall = call
114								break
115							}
116						}
117					}
118				}
119
120				history = append(history, &genai.Content{
121					Parts: []genai.Part{genai.FunctionResponse{
122						Name:     toolCall.Name,
123						Response: response,
124					}},
125					Role: "function",
126				})
127			}
128		}
129	}
130
131	return history
132}
133
134func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
135	geminiTools := make([]*genai.Tool, 0, len(tools))
136
137	for _, tool := range tools {
138		info := tool.Info()
139		declaration := &genai.FunctionDeclaration{
140			Name:        info.Name,
141			Description: info.Description,
142			Parameters: &genai.Schema{
143				Type:       genai.TypeObject,
144				Properties: convertSchemaProperties(info.Parameters),
145				Required:   info.Required,
146			},
147		}
148
149		geminiTools = append(geminiTools, &genai.Tool{
150			FunctionDeclarations: []*genai.FunctionDeclaration{declaration},
151		})
152	}
153
154	return geminiTools
155}
156
157func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
158	reasonStr := reason.String()
159	switch {
160	case reasonStr == "STOP":
161		return message.FinishReasonEndTurn
162	case reasonStr == "MAX_TOKENS":
163		return message.FinishReasonMaxTokens
164	case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
165		return message.FinishReasonToolUse
166	default:
167		return message.FinishReasonUnknown
168	}
169}
170
171func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
172	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
173	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
174
175	// Convert tools
176	if len(tools) > 0 {
177		model.Tools = g.convertTools(tools)
178	}
179
180	// Convert messages
181	geminiMessages := g.convertMessages(messages)
182
183	cfg := config.Get()
184	if cfg.Debug {
185		jsonData, _ := json.Marshal(geminiMessages)
186		logging.Debug("Prepared messages", "messages", string(jsonData))
187	}
188
189	attempts := 0
190	for {
191		attempts++
192		chat := model.StartChat()
193		chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
194
195		lastMsg := geminiMessages[len(geminiMessages)-1]
196		var lastText string
197		for _, part := range lastMsg.Parts {
198			if text, ok := part.(genai.Text); ok {
199				lastText = string(text)
200				break
201			}
202		}
203
204		resp, err := chat.SendMessage(ctx, genai.Text(lastText))
205		// If there is an error we are going to see if we can retry the call
206		if err != nil {
207			retry, after, retryErr := g.shouldRetry(attempts, err)
208			if retryErr != nil {
209				return nil, retryErr
210			}
211			if retry {
212				logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
213				select {
214				case <-ctx.Done():
215					return nil, ctx.Err()
216				case <-time.After(time.Duration(after) * time.Millisecond):
217					continue
218				}
219			}
220			return nil, retryErr
221		}
222
223		content := ""
224		var toolCalls []message.ToolCall
225
226		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
227			for _, part := range resp.Candidates[0].Content.Parts {
228				switch p := part.(type) {
229				case genai.Text:
230					content = string(p)
231				case genai.FunctionCall:
232					id := "call_" + uuid.New().String()
233					args, _ := json.Marshal(p.Args)
234					toolCalls = append(toolCalls, message.ToolCall{
235						ID:    id,
236						Name:  p.Name,
237						Input: string(args),
238						Type:  "function",
239					})
240				}
241			}
242		}
243
244		return &ProviderResponse{
245			Content:      content,
246			ToolCalls:    toolCalls,
247			Usage:        g.usage(resp),
248			FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
249		}, nil
250	}
251}
252
253func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
254	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
255	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
256
257	// Convert tools
258	if len(tools) > 0 {
259		model.Tools = g.convertTools(tools)
260	}
261
262	// Convert messages
263	geminiMessages := g.convertMessages(messages)
264
265	cfg := config.Get()
266	if cfg.Debug {
267		jsonData, _ := json.Marshal(geminiMessages)
268		logging.Debug("Prepared messages", "messages", string(jsonData))
269	}
270
271	attempts := 0
272	eventChan := make(chan ProviderEvent)
273
274	go func() {
275		defer close(eventChan)
276
277		for {
278			attempts++
279			chat := model.StartChat()
280			chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
281
282			lastMsg := geminiMessages[len(geminiMessages)-1]
283			var lastText string
284			for _, part := range lastMsg.Parts {
285				if text, ok := part.(genai.Text); ok {
286					lastText = string(text)
287					break
288				}
289			}
290
291			iter := chat.SendMessageStream(ctx, genai.Text(lastText))
292
293			currentContent := ""
294			toolCalls := []message.ToolCall{}
295			var finalResp *genai.GenerateContentResponse
296
297			eventChan <- ProviderEvent{Type: EventContentStart}
298
299			for {
300				resp, err := iter.Next()
301				if err == iterator.Done {
302					break
303				}
304				if err != nil {
305					retry, after, retryErr := g.shouldRetry(attempts, err)
306					if retryErr != nil {
307						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
308						return
309					}
310					if retry {
311						logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
312						select {
313						case <-ctx.Done():
314							if ctx.Err() != nil {
315								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
316							}
317
318							return
319						case <-time.After(time.Duration(after) * time.Millisecond):
320							break
321						}
322					} else {
323						eventChan <- ProviderEvent{Type: EventError, Error: err}
324						return
325					}
326				}
327
328				finalResp = resp
329
330				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
331					for _, part := range resp.Candidates[0].Content.Parts {
332						switch p := part.(type) {
333						case genai.Text:
334							newText := string(p)
335							delta := newText[len(currentContent):]
336							if delta != "" {
337								eventChan <- ProviderEvent{
338									Type:    EventContentDelta,
339									Content: delta,
340								}
341								currentContent = newText
342							}
343						case genai.FunctionCall:
344							id := "call_" + uuid.New().String()
345							args, _ := json.Marshal(p.Args)
346							newCall := message.ToolCall{
347								ID:    id,
348								Name:  p.Name,
349								Input: string(args),
350								Type:  "function",
351							}
352
353							isNew := true
354							for _, existing := range toolCalls {
355								if existing.Name == newCall.Name && existing.Input == newCall.Input {
356									isNew = false
357									break
358								}
359							}
360
361							if isNew {
362								toolCalls = append(toolCalls, newCall)
363							}
364						}
365					}
366				}
367			}
368
369			eventChan <- ProviderEvent{Type: EventContentStop}
370
371			if finalResp != nil {
372				eventChan <- ProviderEvent{
373					Type: EventComplete,
374					Response: &ProviderResponse{
375						Content:      currentContent,
376						ToolCalls:    toolCalls,
377						Usage:        g.usage(finalResp),
378						FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
379					},
380				}
381				return
382			}
383
384			// If we get here, we need to retry
385			if attempts > maxRetries {
386				eventChan <- ProviderEvent{
387					Type:  EventError,
388					Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
389				}
390				return
391			}
392
393			// Wait before retrying
394			select {
395			case <-ctx.Done():
396				if ctx.Err() != nil {
397					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
398				}
399				return
400			case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
401				continue
402			}
403		}
404	}()
405
406	return eventChan
407}
408
409func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
410	// Check if error is a rate limit error
411	if attempts > maxRetries {
412		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
413	}
414
415	// Gemini doesn't have a standard error type we can check against
416	// So we'll check the error message for rate limit indicators
417	if errors.Is(err, io.EOF) {
418		return false, 0, err
419	}
420
421	errMsg := err.Error()
422	isRateLimit := false
423
424	// Check for common rate limit error messages
425	if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
426		isRateLimit = true
427	}
428
429	if !isRateLimit {
430		return false, 0, err
431	}
432
433	// Calculate backoff with jitter
434	backoffMs := 2000 * (1 << (attempts - 1))
435	jitterMs := int(float64(backoffMs) * 0.2)
436	retryMs := backoffMs + jitterMs
437
438	return true, int64(retryMs), nil
439}
440
441func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
442	var toolCalls []message.ToolCall
443
444	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
445		for _, part := range resp.Candidates[0].Content.Parts {
446			if funcCall, ok := part.(genai.FunctionCall); ok {
447				id := "call_" + uuid.New().String()
448				args, _ := json.Marshal(funcCall.Args)
449				toolCalls = append(toolCalls, message.ToolCall{
450					ID:    id,
451					Name:  funcCall.Name,
452					Input: string(args),
453					Type:  "function",
454				})
455			}
456		}
457	}
458
459	return toolCalls
460}
461
462func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
463	if resp == nil || resp.UsageMetadata == nil {
464		return TokenUsage{}
465	}
466
467	return TokenUsage{
468		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
469		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
470		CacheCreationTokens: 0, // Not directly provided by Gemini
471		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
472	}
473}
474
475func WithGeminiDisableCache() GeminiOption {
476	return func(options *geminiOptions) {
477		options.disableCache = true
478	}
479}
480
// parseJsonToMap decodes a JSON object string into a generic map. On decode
// failure the (possibly nil) map is returned alongside the error, matching
// json.Unmarshal semantics.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var parsed map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &parsed)
	return parsed, err
}
487
488func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
489	properties := make(map[string]*genai.Schema)
490
491	for name, param := range parameters {
492		properties[name] = convertToSchema(param)
493	}
494
495	return properties
496}
497
498func convertToSchema(param interface{}) *genai.Schema {
499	schema := &genai.Schema{Type: genai.TypeString}
500
501	paramMap, ok := param.(map[string]interface{})
502	if !ok {
503		return schema
504	}
505
506	if desc, ok := paramMap["description"].(string); ok {
507		schema.Description = desc
508	}
509
510	typeVal, hasType := paramMap["type"]
511	if !hasType {
512		return schema
513	}
514
515	typeStr, ok := typeVal.(string)
516	if !ok {
517		return schema
518	}
519
520	schema.Type = mapJSONTypeToGenAI(typeStr)
521
522	switch typeStr {
523	case "array":
524		schema.Items = processArrayItems(paramMap)
525	case "object":
526		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
527			schema.Properties = convertSchemaProperties(props)
528		}
529	}
530
531	return schema
532}
533
534func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
535	items, ok := paramMap["items"].(map[string]interface{})
536	if !ok {
537		return nil
538	}
539
540	return convertToSchema(items)
541}
542
543func mapJSONTypeToGenAI(jsonType string) genai.Type {
544	switch jsonType {
545	case "string":
546		return genai.TypeString
547	case "number":
548		return genai.TypeNumber
549	case "integer":
550		return genai.TypeInteger
551	case "boolean":
552		return genai.TypeBoolean
553	case "array":
554		return genai.TypeArray
555	case "object":
556		return genai.TypeObject
557	default:
558		return genai.TypeString // Default to string for unknown types
559	}
560}
561
// contains reports whether s contains any of substrs, compared
// case-insensitively. An empty substrs list yields false.
func contains(s string, substrs ...string) bool {
	// Hoist the loop-invariant lowercasing of s (previously recomputed on
	// every iteration).
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}