gemini.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/opencode-ai/opencode/internal/config"
	"github.com/opencode-ai/opencode/internal/llm/tools"
	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/opencode-ai/opencode/internal/message"
	"google.golang.org/genai"
)

type geminiOptions struct {
	disableCache bool
}

type GeminiOption func(*geminiOptions)

type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

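// newGeminiClient builds a Gemini-backed provider client from the shared
// provider options, applying any Gemini-specific option functions before
// creating the underlying genai client.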
func newGeminiClient(opts providerClientOptions) GeminiClient {
	geminiOpts := geminiOptions{}
	for _, o := range opts.geminiOptions {
		o(&geminiOpts)
	}

	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		logging.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		options:         geminiOpts,
		client:          client,
	}
}

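// convertMessages translates the internal message history into genai.Content
// entries: user messages (including inline binary attachments) map to the
// "user" role, assistant text and tool calls map to the "model" role, and
// tool results are emitted as FunctionResponse parts.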
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				// genai.Blob expects the full IANA MIME type (e.g. "image/png"),
				// not just the subtype.
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []*genai.Part{},
			}

			if msg.Content().String() != "" {
				content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJsonToMap(call.Input)
					content.Parts = append(content.Parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			history = append(history, content)

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]interface{}{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

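// convertTools wraps every available tool's declaration into a single
// genai.Tool so the model can issue function calls against them.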
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

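// finishReason maps Gemini finish reasons onto the provider-agnostic
// message.FinishReason values used by the rest of the app.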
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

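// send performs a single (non-streaming) generation request, retrying on
// rate-limit errors, and converts the response into a ProviderResponse.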
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If there is an error, see whether the call can be retried
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					// Concatenate text parts rather than keeping only the last one.
					content += part.Text
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

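// stream starts a streaming generation request and returns a channel of
// provider events: content deltas as they arrive, followed by a completion
// event (or an error event if the request ultimately fails).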
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, chatErr := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		// Surface a failed chat-session creation as an error event instead of
		// silently ignoring it and dereferencing a nil chat below.
		if chatErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: chatErr}
			return
		}

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if !retry {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
					logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
					select {
					case <-ctx.Done():
						if ctx.Err() != nil {
							eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
						}
						return
					case <-time.After(time.Duration(after) * time.Millisecond):
					}
					// Stop consuming the failed stream; the outer loop retries the request.
					break
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := part.Text
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

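// shouldRetry reports whether a failed request looks like a transient
// rate-limit error and, if so, how many milliseconds to wait before retrying.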
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Gemini doesn't have a standard error type we can check against,
	// so we inspect the error message for rate limit indicators.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	errMsg := err.Error()
	isRateLimit := false

	// Check for common rate limit error messages
	if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
		isRateLimit = true
	}

	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff with a fixed 20% cushion (not randomized jitter).
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

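// toolCalls extracts any function calls from a response and converts them
// into message.ToolCall values with freshly generated IDs.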
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			if part.FunctionCall != nil {
				id := "call_" + uuid.New().String()
				args, _ := json.Marshal(part.FunctionCall.Args)
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    id,
					Name:  part.FunctionCall.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	return toolCalls
}

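// usage converts Gemini usage metadata into the provider-agnostic TokenUsage.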
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

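// WithGeminiDisableCache sets the disableCache option for the Gemini provider.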
func WithGeminiDisableCache() GeminiOption {
	return func(options *geminiOptions) {
		options.disableCache = true
	}
}

// Helper functions
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

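// convertSchemaProperties converts a JSON-schema "properties" map into the
// equivalent genai.Schema entries.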
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

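// convertToSchema converts a single JSON-schema parameter definition into a
// genai.Schema, defaulting to a string schema when the shape is unrecognized.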
func convertToSchema(param interface{}) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]interface{})
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

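// processArrayItems converts an array parameter's "items" definition, if any.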
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
	items, ok := paramMap["items"].(map[string]interface{})
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

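// mapJSONTypeToGenAI maps JSON-schema type names onto genai schema types.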
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

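// contains reports whether s contains any of the given substrings,
// case-insensitively.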
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}