gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/message"
 17	"github.com/google/uuid"
 18	"google.golang.org/genai"
 19)
 20
 21type geminiClient struct {
 22	providerOptions providerClientOptions
 23	client          *genai.Client
 24}
 25
 26type GeminiClient ProviderClient
 27
 28func newGeminiClient(opts providerClientOptions) GeminiClient {
 29	client, err := createGeminiClient(opts)
 30	if err != nil {
 31		slog.Error("Failed to create Gemini client", "error", err)
 32		return nil
 33	}
 34
 35	return &geminiClient{
 36		providerOptions: opts,
 37		client:          client,
 38	}
 39}
 40
 41func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 42	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
 43	if err != nil {
 44		return nil, err
 45	}
 46	return client, nil
 47}
 48
 49func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 50	var history []*genai.Content
 51	for _, msg := range messages {
 52		switch msg.Role {
 53		case message.User:
 54			var parts []*genai.Part
 55			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 56			for _, binaryContent := range msg.BinaryContent() {
 57				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 58				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 59					MIMEType: imageFormat[1],
 60					Data:     binaryContent.Data,
 61				}})
 62			}
 63			history = append(history, &genai.Content{
 64				Parts: parts,
 65				Role:  "user",
 66			})
 67		case message.Assistant:
 68			var assistantParts []*genai.Part
 69
 70			if msg.Content().String() != "" {
 71				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 72			}
 73
 74			if len(msg.ToolCalls()) > 0 {
 75				for _, call := range msg.ToolCalls() {
 76					args, _ := parseJSONToMap(call.Input)
 77					assistantParts = append(assistantParts, &genai.Part{
 78						FunctionCall: &genai.FunctionCall{
 79							Name: call.Name,
 80							Args: args,
 81						},
 82					})
 83				}
 84			}
 85
 86			if len(assistantParts) > 0 {
 87				history = append(history, &genai.Content{
 88					Role:  "model",
 89					Parts: assistantParts,
 90				})
 91			}
 92
 93		case message.Tool:
 94			for _, result := range msg.ToolResults() {
 95				response := map[string]any{"result": result.Content}
 96				parsed, err := parseJSONToMap(result.Content)
 97				if err == nil {
 98					response = parsed
 99				}
100
101				var toolCall message.ToolCall
102				for _, m := range messages {
103					if m.Role == message.Assistant {
104						for _, call := range m.ToolCalls() {
105							if call.ID == result.ToolCallID {
106								toolCall = call
107								break
108							}
109						}
110					}
111				}
112
113				history = append(history, &genai.Content{
114					Parts: []*genai.Part{
115						{
116							FunctionResponse: &genai.FunctionResponse{
117								Name:     toolCall.Name,
118								Response: response,
119							},
120						},
121					},
122					Role: "function",
123				})
124			}
125		}
126	}
127
128	return history
129}
130
131func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
132	geminiTool := &genai.Tool{}
133	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
134
135	for _, tool := range tools {
136		info := tool.Info()
137		declaration := &genai.FunctionDeclaration{
138			Name:        info.Name,
139			Description: info.Description,
140			Parameters: &genai.Schema{
141				Type:       genai.TypeObject,
142				Properties: convertSchemaProperties(info.Parameters),
143				Required:   info.Required,
144			},
145		}
146
147		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
148	}
149
150	return []*genai.Tool{geminiTool}
151}
152
153func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
154	switch reason {
155	case genai.FinishReasonStop:
156		return message.FinishReasonEndTurn
157	case genai.FinishReasonMaxTokens:
158		return message.FinishReasonMaxTokens
159	default:
160		return message.FinishReasonUnknown
161	}
162}
163
164func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
165	// Convert messages
166	geminiMessages := g.convertMessages(messages)
167	model := g.providerOptions.model(g.providerOptions.modelType)
168	cfg := config.Get()
169	if cfg.Options.Debug {
170		jsonData, _ := json.Marshal(geminiMessages)
171		slog.Debug("Prepared messages", "messages", string(jsonData))
172	}
173
174	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
175	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
176		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
177	}
178
179	maxTokens := model.DefaultMaxTokens
180	if modelConfig.MaxTokens > 0 {
181		maxTokens = modelConfig.MaxTokens
182	}
183	systemMessage := g.providerOptions.systemMessage
184	if g.providerOptions.systemPromptPrefix != "" {
185		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
186	}
187	history := geminiMessages[:len(geminiMessages)-1] // All but last message
188	lastMsg := geminiMessages[len(geminiMessages)-1]
189	config := &genai.GenerateContentConfig{
190		MaxOutputTokens: int32(maxTokens),
191		SystemInstruction: &genai.Content{
192			Parts: []*genai.Part{{Text: systemMessage}},
193		},
194	}
195	config.Tools = g.convertTools(tools)
196	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
197
198	attempts := 0
199	for {
200		attempts++
201		var toolCalls []message.ToolCall
202
203		var lastMsgParts []genai.Part
204		for _, part := range lastMsg.Parts {
205			lastMsgParts = append(lastMsgParts, *part)
206		}
207		resp, err := chat.SendMessage(ctx, lastMsgParts...)
208		// If there is an error we are going to see if we can retry the call
209		if err != nil {
210			retry, after, retryErr := g.shouldRetry(attempts, err)
211			if retryErr != nil {
212				return nil, retryErr
213			}
214			if retry {
215				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
216				select {
217				case <-ctx.Done():
218					return nil, ctx.Err()
219				case <-time.After(time.Duration(after) * time.Millisecond):
220					continue
221				}
222			}
223			return nil, retryErr
224		}
225
226		content := ""
227
228		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
229			for _, part := range resp.Candidates[0].Content.Parts {
230				switch {
231				case part.Text != "":
232					content = string(part.Text)
233				case part.FunctionCall != nil:
234					id := "call_" + uuid.New().String()
235					args, _ := json.Marshal(part.FunctionCall.Args)
236					toolCalls = append(toolCalls, message.ToolCall{
237						ID:       id,
238						Name:     part.FunctionCall.Name,
239						Input:    string(args),
240						Type:     "function",
241						Finished: true,
242					})
243				}
244			}
245		}
246		finishReason := message.FinishReasonEndTurn
247		if len(resp.Candidates) > 0 {
248			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
249		}
250		if len(toolCalls) > 0 {
251			finishReason = message.FinishReasonToolUse
252		}
253
254		return &ProviderResponse{
255			Content:      content,
256			ToolCalls:    toolCalls,
257			Usage:        g.usage(resp),
258			FinishReason: finishReason,
259		}, nil
260	}
261}
262
263func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
264	// Convert messages
265	geminiMessages := g.convertMessages(messages)
266
267	model := g.providerOptions.model(g.providerOptions.modelType)
268	cfg := config.Get()
269	if cfg.Options.Debug {
270		jsonData, _ := json.Marshal(geminiMessages)
271		slog.Debug("Prepared messages", "messages", string(jsonData))
272	}
273
274	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
275	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
276		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
277	}
278	maxTokens := model.DefaultMaxTokens
279	if modelConfig.MaxTokens > 0 {
280		maxTokens = modelConfig.MaxTokens
281	}
282
283	// Override max tokens if set in provider options
284	if g.providerOptions.maxTokens > 0 {
285		maxTokens = g.providerOptions.maxTokens
286	}
287	systemMessage := g.providerOptions.systemMessage
288	if g.providerOptions.systemPromptPrefix != "" {
289		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
290	}
291	history := geminiMessages[:len(geminiMessages)-1] // All but last message
292	lastMsg := geminiMessages[len(geminiMessages)-1]
293	config := &genai.GenerateContentConfig{
294		MaxOutputTokens: int32(maxTokens),
295		SystemInstruction: &genai.Content{
296			Parts: []*genai.Part{{Text: systemMessage}},
297		},
298	}
299	config.Tools = g.convertTools(tools)
300	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
301
302	attempts := 0
303	eventChan := make(chan ProviderEvent)
304
305	go func() {
306		defer close(eventChan)
307
308		for {
309			attempts++
310
311			currentContent := ""
312			toolCalls := []message.ToolCall{}
313			var finalResp *genai.GenerateContentResponse
314
315			eventChan <- ProviderEvent{Type: EventContentStart}
316
317			var lastMsgParts []genai.Part
318
319			for _, part := range lastMsg.Parts {
320				lastMsgParts = append(lastMsgParts, *part)
321			}
322			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
323				if err != nil {
324					retry, after, retryErr := g.shouldRetry(attempts, err)
325					if retryErr != nil {
326						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
327						return
328					}
329					if retry {
330						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
331						select {
332						case <-ctx.Done():
333							if ctx.Err() != nil {
334								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
335							}
336
337							return
338						case <-time.After(time.Duration(after) * time.Millisecond):
339							break
340						}
341					} else {
342						eventChan <- ProviderEvent{Type: EventError, Error: err}
343						return
344					}
345				}
346
347				finalResp = resp
348
349				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
350					for _, part := range resp.Candidates[0].Content.Parts {
351						switch {
352						case part.Text != "":
353							delta := string(part.Text)
354							if delta != "" {
355								eventChan <- ProviderEvent{
356									Type:    EventContentDelta,
357									Content: delta,
358								}
359								currentContent += delta
360							}
361						case part.FunctionCall != nil:
362							id := "call_" + uuid.New().String()
363							args, _ := json.Marshal(part.FunctionCall.Args)
364							newCall := message.ToolCall{
365								ID:       id,
366								Name:     part.FunctionCall.Name,
367								Input:    string(args),
368								Type:     "function",
369								Finished: true,
370							}
371
372							isNew := true
373							for _, existing := range toolCalls {
374								if existing.Name == newCall.Name && existing.Input == newCall.Input {
375									isNew = false
376									break
377								}
378							}
379
380							if isNew {
381								toolCalls = append(toolCalls, newCall)
382							}
383						}
384					}
385				}
386			}
387
388			eventChan <- ProviderEvent{Type: EventContentStop}
389
390			if finalResp != nil {
391				finishReason := message.FinishReasonEndTurn
392				if len(finalResp.Candidates) > 0 {
393					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
394				}
395				if len(toolCalls) > 0 {
396					finishReason = message.FinishReasonToolUse
397				}
398				eventChan <- ProviderEvent{
399					Type: EventComplete,
400					Response: &ProviderResponse{
401						Content:      currentContent,
402						ToolCalls:    toolCalls,
403						Usage:        g.usage(finalResp),
404						FinishReason: finishReason,
405					},
406				}
407				return
408			}
409		}
410	}()
411
412	return eventChan
413}
414
415func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
416	// Check if error is a rate limit error
417	if attempts > maxRetries {
418		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
419	}
420
421	// Gemini doesn't have a standard error type we can check against
422	// So we'll check the error message for rate limit indicators
423	if errors.Is(err, io.EOF) {
424		return false, 0, err
425	}
426
427	errMsg := err.Error()
428	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
429
430	// Check for token expiration (401 Unauthorized)
431	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
432		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
433		if err != nil {
434			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
435		}
436		g.client, err = createGeminiClient(g.providerOptions)
437		if err != nil {
438			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
439		}
440		return true, 0, nil
441	}
442
443	// Check for common rate limit error messages
444
445	if !isRateLimit {
446		return false, 0, err
447	}
448
449	// Calculate backoff with jitter
450	backoffMs := 2000 * (1 << (attempts - 1))
451	jitterMs := int(float64(backoffMs) * 0.2)
452	retryMs := backoffMs + jitterMs
453
454	return true, int64(retryMs), nil
455}
456
457func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
458	if resp == nil || resp.UsageMetadata == nil {
459		return TokenUsage{}
460	}
461
462	return TokenUsage{
463		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
464		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
465		CacheCreationTokens: 0, // Not directly provided by Gemini
466		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
467	}
468}
469
470func (g *geminiClient) Model() catwalk.Model {
471	return g.providerOptions.model(g.providerOptions.modelType)
472}
473
474// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object into a generic map. The
// error from json.Unmarshal is returned unchanged alongside whatever was
// decoded; JSON null decodes to a nil map without an error.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
480
481func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
482	properties := make(map[string]*genai.Schema)
483
484	for name, param := range parameters {
485		properties[name] = convertToSchema(param)
486	}
487
488	return properties
489}
490
491func convertToSchema(param any) *genai.Schema {
492	schema := &genai.Schema{Type: genai.TypeString}
493
494	paramMap, ok := param.(map[string]any)
495	if !ok {
496		return schema
497	}
498
499	if desc, ok := paramMap["description"].(string); ok {
500		schema.Description = desc
501	}
502
503	typeVal, hasType := paramMap["type"]
504	if !hasType {
505		return schema
506	}
507
508	typeStr, ok := typeVal.(string)
509	if !ok {
510		return schema
511	}
512
513	schema.Type = mapJSONTypeToGenAI(typeStr)
514
515	switch typeStr {
516	case "array":
517		schema.Items = processArrayItems(paramMap)
518	case "object":
519		if props, ok := paramMap["properties"].(map[string]any); ok {
520			schema.Properties = convertSchemaProperties(props)
521		}
522	}
523
524	return schema
525}
526
527func processArrayItems(paramMap map[string]any) *genai.Schema {
528	items, ok := paramMap["items"].(map[string]any)
529	if !ok {
530		return nil
531	}
532
533	return convertToSchema(items)
534}
535
536func mapJSONTypeToGenAI(jsonType string) genai.Type {
537	switch jsonType {
538	case "string":
539		return genai.TypeString
540	case "number":
541		return genai.TypeNumber
542	case "integer":
543		return genai.TypeInteger
544	case "boolean":
545		return genai.TypeBoolean
546	case "array":
547		return genai.TypeArray
548	case "object":
549		return genai.TypeObject
550	default:
551		return genai.TypeString // Default to string for unknown types
552	}
553}
554
// contains reports whether s contains at least one of substrs, ignoring
// case. It returns false when no substrings are given.
func contains(s string, substrs ...string) bool {
	// Lower-case the haystack once instead of once per candidate.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}