gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/google/generative-ai-go/genai"
 13	"github.com/google/uuid"
 14	"github.com/opencode-ai/opencode/internal/config"
 15	"github.com/opencode-ai/opencode/internal/llm/tools"
 16	"github.com/opencode-ai/opencode/internal/logging"
 17	"github.com/opencode-ai/opencode/internal/message"
 18	"google.golang.org/api/iterator"
 19	"google.golang.org/api/option"
 20)
 21
// geminiOptions holds Gemini-specific configuration, populated via
// GeminiOption functional options.
type geminiOptions struct {
	disableCache bool // set by WithGeminiDisableCache; opts out of prompt caching
}

// GeminiOption mutates geminiOptions; passed to newGeminiClient through
// providerClientOptions.geminiOptions.
type GeminiOption func(*geminiOptions)

// geminiClient implements ProviderClient on top of Google's
// generative-ai-go SDK.
type geminiClient struct {
	providerOptions providerClientOptions // shared provider settings (API key, model, max tokens, system message)
	options         geminiOptions         // Gemini-specific settings
	client          *genai.Client         // underlying SDK client
}

// GeminiClient is the provider-facing alias for this client's interface.
type GeminiClient ProviderClient
 35
 36func newGeminiClient(opts providerClientOptions) GeminiClient {
 37	geminiOpts := geminiOptions{}
 38	for _, o := range opts.geminiOptions {
 39		o(&geminiOpts)
 40	}
 41
 42	client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
 43	if err != nil {
 44		logging.Error("Failed to create Gemini client", "error", err)
 45		return nil
 46	}
 47
 48	return &geminiClient{
 49		providerOptions: opts,
 50		options:         geminiOpts,
 51		client:          client,
 52	}
 53}
 54
 55func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 56	var history []*genai.Content
 57	for _, msg := range messages {
 58		switch msg.Role {
 59		case message.User:
 60			history = append(history, &genai.Content{
 61				Parts: []genai.Part{genai.Text(msg.Content().String())},
 62				Role:  "user",
 63			})
 64
 65		case message.Assistant:
 66			content := &genai.Content{
 67				Role:  "model",
 68				Parts: []genai.Part{},
 69			}
 70
 71			if msg.Content().String() != "" {
 72				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
 73			}
 74
 75			if len(msg.ToolCalls()) > 0 {
 76				for _, call := range msg.ToolCalls() {
 77					args, _ := parseJsonToMap(call.Input)
 78					content.Parts = append(content.Parts, genai.FunctionCall{
 79						Name: call.Name,
 80						Args: args,
 81					})
 82				}
 83			}
 84
 85			history = append(history, content)
 86
 87		case message.Tool:
 88			for _, result := range msg.ToolResults() {
 89				response := map[string]interface{}{"result": result.Content}
 90				parsed, err := parseJsonToMap(result.Content)
 91				if err == nil {
 92					response = parsed
 93				}
 94
 95				var toolCall message.ToolCall
 96				for _, m := range messages {
 97					if m.Role == message.Assistant {
 98						for _, call := range m.ToolCalls() {
 99							if call.ID == result.ToolCallID {
100								toolCall = call
101								break
102							}
103						}
104					}
105				}
106
107				history = append(history, &genai.Content{
108					Parts: []genai.Part{genai.FunctionResponse{
109						Name:     toolCall.Name,
110						Response: response,
111					}},
112					Role: "function",
113				})
114			}
115		}
116	}
117
118	return history
119}
120
121func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
122	geminiTool := &genai.Tool{}
123	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
124
125	for _, tool := range tools {
126		info := tool.Info()
127		declaration := &genai.FunctionDeclaration{
128			Name:        info.Name,
129			Description: info.Description,
130			Parameters: &genai.Schema{
131				Type:       genai.TypeObject,
132				Properties: convertSchemaProperties(info.Parameters),
133				Required:   info.Required,
134			},
135		}
136
137		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
138	}
139
140	return []*genai.Tool{geminiTool}
141}
142
143func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
144	switch {
145	case reason == genai.FinishReasonStop:
146		return message.FinishReasonEndTurn
147	case reason == genai.FinishReasonMaxTokens:
148		return message.FinishReasonMaxTokens
149	default:
150		return message.FinishReasonUnknown
151	}
152}
153
154func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
155	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
156	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
157	model.SystemInstruction = &genai.Content{
158		Parts: []genai.Part{
159			genai.Text(g.providerOptions.systemMessage),
160		},
161	}
162	// Convert tools
163	if len(tools) > 0 {
164		model.Tools = g.convertTools(tools)
165	}
166
167	// Convert messages
168	geminiMessages := g.convertMessages(messages)
169
170	cfg := config.Get()
171	if cfg.Debug {
172		jsonData, _ := json.Marshal(geminiMessages)
173		logging.Debug("Prepared messages", "messages", string(jsonData))
174	}
175
176	attempts := 0
177	for {
178		attempts++
179		var toolCalls []message.ToolCall
180		chat := model.StartChat()
181		chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
182
183		lastMsg := geminiMessages[len(geminiMessages)-1]
184
185		resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
186		// If there is an error we are going to see if we can retry the call
187		if err != nil {
188			retry, after, retryErr := g.shouldRetry(attempts, err)
189			if retryErr != nil {
190				return nil, retryErr
191			}
192			if retry {
193				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
194				select {
195				case <-ctx.Done():
196					return nil, ctx.Err()
197				case <-time.After(time.Duration(after) * time.Millisecond):
198					continue
199				}
200			}
201			return nil, retryErr
202		}
203
204		content := ""
205
206		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
207			for _, part := range resp.Candidates[0].Content.Parts {
208				switch p := part.(type) {
209				case genai.Text:
210					content = string(p)
211				case genai.FunctionCall:
212					id := "call_" + uuid.New().String()
213					args, _ := json.Marshal(p.Args)
214					toolCalls = append(toolCalls, message.ToolCall{
215						ID:       id,
216						Name:     p.Name,
217						Input:    string(args),
218						Type:     "function",
219						Finished: true,
220					})
221				}
222			}
223		}
224		finishReason := message.FinishReasonEndTurn
225		if len(resp.Candidates) > 0 {
226			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
227		}
228		if len(toolCalls) > 0 {
229			finishReason = message.FinishReasonToolUse
230		}
231
232		return &ProviderResponse{
233			Content:      content,
234			ToolCalls:    toolCalls,
235			Usage:        g.usage(resp),
236			FinishReason: finishReason,
237		}, nil
238	}
239}
240
241func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
242	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
243	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
244	model.SystemInstruction = &genai.Content{
245		Parts: []genai.Part{
246			genai.Text(g.providerOptions.systemMessage),
247		},
248	}
249	// Convert tools
250	if len(tools) > 0 {
251		model.Tools = g.convertTools(tools)
252	}
253
254	// Convert messages
255	geminiMessages := g.convertMessages(messages)
256
257	cfg := config.Get()
258	if cfg.Debug {
259		jsonData, _ := json.Marshal(geminiMessages)
260		logging.Debug("Prepared messages", "messages", string(jsonData))
261	}
262
263	attempts := 0
264	eventChan := make(chan ProviderEvent)
265
266	go func() {
267		defer close(eventChan)
268
269		for {
270			attempts++
271			chat := model.StartChat()
272			chat.History = geminiMessages[:len(geminiMessages)-1]
273			lastMsg := geminiMessages[len(geminiMessages)-1]
274
275			iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
276
277			currentContent := ""
278			toolCalls := []message.ToolCall{}
279			var finalResp *genai.GenerateContentResponse
280
281			eventChan <- ProviderEvent{Type: EventContentStart}
282
283			for {
284				resp, err := iter.Next()
285				if err == iterator.Done {
286					break
287				}
288				if err != nil {
289					retry, after, retryErr := g.shouldRetry(attempts, err)
290					if retryErr != nil {
291						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
292						return
293					}
294					if retry {
295						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
296						select {
297						case <-ctx.Done():
298							if ctx.Err() != nil {
299								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
300							}
301
302							return
303						case <-time.After(time.Duration(after) * time.Millisecond):
304							break
305						}
306					} else {
307						eventChan <- ProviderEvent{Type: EventError, Error: err}
308						return
309					}
310				}
311
312				finalResp = resp
313
314				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
315					for _, part := range resp.Candidates[0].Content.Parts {
316						switch p := part.(type) {
317						case genai.Text:
318							delta := string(p)
319							if delta != "" {
320								eventChan <- ProviderEvent{
321									Type:    EventContentDelta,
322									Content: delta,
323								}
324								currentContent += delta
325							}
326						case genai.FunctionCall:
327							id := "call_" + uuid.New().String()
328							args, _ := json.Marshal(p.Args)
329							newCall := message.ToolCall{
330								ID:       id,
331								Name:     p.Name,
332								Input:    string(args),
333								Type:     "function",
334								Finished: true,
335							}
336
337							isNew := true
338							for _, existing := range toolCalls {
339								if existing.Name == newCall.Name && existing.Input == newCall.Input {
340									isNew = false
341									break
342								}
343							}
344
345							if isNew {
346								toolCalls = append(toolCalls, newCall)
347							}
348						}
349					}
350				}
351			}
352
353			eventChan <- ProviderEvent{Type: EventContentStop}
354
355			if finalResp != nil {
356
357				finishReason := message.FinishReasonEndTurn
358				if len(finalResp.Candidates) > 0 {
359					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
360				}
361				if len(toolCalls) > 0 {
362					finishReason = message.FinishReasonToolUse
363				}
364				eventChan <- ProviderEvent{
365					Type: EventComplete,
366					Response: &ProviderResponse{
367						Content:      currentContent,
368						ToolCalls:    toolCalls,
369						Usage:        g.usage(finalResp),
370						FinishReason: finishReason,
371					},
372				}
373				return
374			}
375
376		}
377	}()
378
379	return eventChan
380}
381
382func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
383	// Check if error is a rate limit error
384	if attempts > maxRetries {
385		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
386	}
387
388	// Gemini doesn't have a standard error type we can check against
389	// So we'll check the error message for rate limit indicators
390	if errors.Is(err, io.EOF) {
391		return false, 0, err
392	}
393
394	errMsg := err.Error()
395	isRateLimit := false
396
397	// Check for common rate limit error messages
398	if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
399		isRateLimit = true
400	}
401
402	if !isRateLimit {
403		return false, 0, err
404	}
405
406	// Calculate backoff with jitter
407	backoffMs := 2000 * (1 << (attempts - 1))
408	jitterMs := int(float64(backoffMs) * 0.2)
409	retryMs := backoffMs + jitterMs
410
411	return true, int64(retryMs), nil
412}
413
414func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
415	var toolCalls []message.ToolCall
416
417	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
418		for _, part := range resp.Candidates[0].Content.Parts {
419			if funcCall, ok := part.(genai.FunctionCall); ok {
420				id := "call_" + uuid.New().String()
421				args, _ := json.Marshal(funcCall.Args)
422				toolCalls = append(toolCalls, message.ToolCall{
423					ID:    id,
424					Name:  funcCall.Name,
425					Input: string(args),
426					Type:  "function",
427				})
428			}
429		}
430	}
431
432	return toolCalls
433}
434
435func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
436	if resp == nil || resp.UsageMetadata == nil {
437		return TokenUsage{}
438	}
439
440	return TokenUsage{
441		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
442		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
443		CacheCreationTokens: 0, // Not directly provided by Gemini
444		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
445	}
446}
447
448func WithGeminiDisableCache() GeminiOption {
449	return func(options *geminiOptions) {
450		options.disableCache = true
451	}
452}
453
// Helper functions

// parseJsonToMap decodes a JSON object string into a generic map. On failure
// the returned map is whatever json.Unmarshal left behind (typically nil)
// and the error is non-nil; callers should check err before using the map.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var parsed map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &parsed)
	return parsed, err
}
460
461func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
462	properties := make(map[string]*genai.Schema)
463
464	for name, param := range parameters {
465		properties[name] = convertToSchema(param)
466	}
467
468	return properties
469}
470
471func convertToSchema(param interface{}) *genai.Schema {
472	schema := &genai.Schema{Type: genai.TypeString}
473
474	paramMap, ok := param.(map[string]interface{})
475	if !ok {
476		return schema
477	}
478
479	if desc, ok := paramMap["description"].(string); ok {
480		schema.Description = desc
481	}
482
483	typeVal, hasType := paramMap["type"]
484	if !hasType {
485		return schema
486	}
487
488	typeStr, ok := typeVal.(string)
489	if !ok {
490		return schema
491	}
492
493	schema.Type = mapJSONTypeToGenAI(typeStr)
494
495	switch typeStr {
496	case "array":
497		schema.Items = processArrayItems(paramMap)
498	case "object":
499		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
500			schema.Properties = convertSchemaProperties(props)
501		}
502	}
503
504	return schema
505}
506
507func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
508	items, ok := paramMap["items"].(map[string]interface{})
509	if !ok {
510		return nil
511	}
512
513	return convertToSchema(items)
514}
515
516func mapJSONTypeToGenAI(jsonType string) genai.Type {
517	switch jsonType {
518	case "string":
519		return genai.TypeString
520	case "number":
521		return genai.TypeNumber
522	case "integer":
523		return genai.TypeInteger
524	case "boolean":
525		return genai.TypeBoolean
526	case "array":
527		return genai.TypeArray
528	case "object":
529		return genai.TypeObject
530	default:
531		return genai.TypeString // Default to string for unknown types
532	}
533}
534
// contains reports whether s contains any of substrs, compared
// case-insensitively.
func contains(s string, substrs ...string) bool {
	// Lowercase the haystack once, not on every iteration.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}