gemini.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/google/generative-ai-go/genai"
	"github.com/google/uuid"
	"github.com/opencode-ai/opencode/internal/config"
	"github.com/opencode-ai/opencode/internal/llm/tools"
	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/opencode-ai/opencode/internal/message"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

type geminiOptions struct {
	disableCache bool
}

type GeminiOption func(*geminiOptions)

type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

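// newGeminiClient constructs a Gemini-backed provider client from the shared
// provider options. It returns nil if the underlying genai client cannot be
// created, so callers must be prepared for a nil result.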
func newGeminiClient(opts providerClientOptions) GeminiClient {
	geminiOpts := geminiOptions{}
	for _, o := range opts.geminiOptions {
		o(&geminiOpts)
	}

	client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
	if err != nil {
		logging.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		options:         geminiOpts,
		client:          client,
	}
}

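// convertMessages maps the internal message history to genai.Content values:
// user text and attachments become "user" parts, assistant text and tool calls
// become "model" parts, and tool results are emitted as function responses
// matched back to the originating tool call by its ID.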
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []genai.Part
			parts = append(parts, genai.Text(msg.Content().String()))
			for _, binaryContent := range msg.BinaryContent() {
				imageFormat := strings.Split(binaryContent.MIMEType, "/")
				parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data))
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []genai.Part{},
			}

			if msg.Content().String() != "" {
				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJsonToMap(call.Input)
					content.Parts = append(content.Parts, genai.FunctionCall{
						Name: call.Name,
						Args: args,
					})
				}
			}

			history = append(history, content)

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]interface{}{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []genai.Part{genai.FunctionResponse{
						Name:     toolCall.Name,
						Response: response,
					}},
					Role: "function",
				})
			}
		}
	}

	return history
}

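// convertTools flattens all tool definitions into a single genai.Tool whose
// FunctionDeclarations mirror each tool's name, description, and JSON-schema
// parameters.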
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

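// finishReason translates Gemini finish reasons into the provider-agnostic
// message.FinishReason values; anything else maps to unknown.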
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

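// send performs a single, non-streaming generation request. It replays the
// converted history through a chat session, sends the last message, and retries
// with backoff when shouldRetry classifies the error as a rate limit.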
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
	model.SystemInstruction = &genai.Content{
		Parts: []genai.Part{
			genai.Text(g.providerOptions.systemMessage),
		},
	}
	// Convert tools
	if len(tools) > 0 {
		model.Tools = g.convertTools(tools)
	}

	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall
		chat := model.StartChat()
		chat.History = geminiMessages[:len(geminiMessages)-1] // All but the last message

		lastMsg := geminiMessages[len(geminiMessages)-1]

		resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
		// If the request failed, check whether it can be retried
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch p := part.(type) {
				case genai.Text:
					content = string(p)
				case genai.FunctionCall:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(p.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     p.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

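// stream performs a streaming generation request. Events are delivered on the
// returned channel: content deltas as they arrive, followed by a completion
// event carrying the aggregated response, or an error event if the request
// ultimately fails.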
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
	model.SystemInstruction = &genai.Content{
		Parts: []genai.Part{
			genai.Text(g.providerOptions.systemMessage),
		},
	}
	// Convert tools
	if len(tools) > 0 {
		model.Tools = g.convertTools(tools)
	}

	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

	retryLoop:
		for {
			attempts++
			chat := model.StartChat()
			chat.History = geminiMessages[:len(geminiMessages)-1]
			lastMsg := geminiMessages[len(geminiMessages)-1]

			iter := chat.SendMessageStream(ctx, lastMsg.Parts...)

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			for {
				resp, err := iter.Next()
				if err == iterator.Done {
					break
				}
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}
							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// The failed stream cannot be resumed; restart the whole request.
							continue retryLoop
						}
					} else {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch p := part.(type) {
						case genai.Text:
							delta := string(p)
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case genai.FunctionCall:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(p.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     p.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							// The stream may repeat the same function call; record it only once.
							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

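// shouldRetry reports whether a failed request should be retried. Errors whose
// message mentions rate limiting or quota exhaustion are retryable and get a
// backoff delay in milliseconds; everything else (including io.EOF and
// exceeding maxRetries) is returned as a terminal error.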
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	// Give up once the retry budget is exhausted
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Gemini doesn't expose a typed rate-limit error, so inspect the error
	// message for rate limit indicators instead
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	errMsg := err.Error()
	isRateLimit := false

	// Check for common rate limit error messages
	if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
		isRateLimit = true
	}

	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff (2s, 4s, 8s, ...) plus a fixed 20% cushion
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

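// toolCalls extracts function calls from a non-streaming response and assigns
// each one a freshly generated call ID.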
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			if funcCall, ok := part.(genai.FunctionCall); ok {
				id := "call_" + uuid.New().String()
				args, _ := json.Marshal(funcCall.Args)
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    id,
					Name:  funcCall.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	return toolCalls
}

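// usage converts Gemini's UsageMetadata into the provider-agnostic TokenUsage.
// Gemini reports cached-content tokens but not cache-creation tokens, so the
// latter is always zero.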
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

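// WithGeminiDisableCache disables caching for the Gemini client.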
func WithGeminiDisableCache() GeminiOption {
	return func(options *geminiOptions) {
		options.disableCache = true
	}
}

// Helper functions

// parseJsonToMap unmarshals a JSON string into a generic map.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

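// convertSchemaProperties converts a map of JSON-schema style parameter
// definitions into genai.Schema properties.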
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

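// convertToSchema converts a single JSON-schema style parameter definition into
// a genai.Schema, recursing into array items and object properties. For example,
// {"type": "array", "items": {"type": "string"}, "description": "file paths"}
// becomes a TypeArray schema with TypeString items. Anything unrecognized falls
// back to a string schema.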
func convertToSchema(param interface{}) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]interface{})
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

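// processArrayItems converts an array parameter's "items" definition into the
// schema used for genai.Schema.Items; it returns nil if no items are declared.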
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
	items, ok := paramMap["items"].(map[string]interface{})
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

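// mapJSONTypeToGenAI maps JSON-schema type names onto genai.Type constants.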
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

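// contains reports whether s contains any of the given substrings,
// case-insensitively.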
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}