gemini.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

type geminiClient struct {
	providerOptions providerClientOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

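// newGeminiClient constructs a Gemini-backed provider client. Note that it
// logs and returns nil when the underlying genai client cannot be created,
// so callers should treat a nil result as a construction failure.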
func newGeminiClient(opts providerClientOptions) GeminiClient {
	client, err := createGeminiClient(opts)
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		client:          client,
	}
}

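// createGeminiClient dials the Gemini API backend (not Vertex AI) using the
// API key from the provider options.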
func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		return nil, err
	}
	return client, nil
}

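// convertMessages maps internal messages onto Gemini chat history: user text
// and attachments become "user" content, assistant text and tool calls become
// "model" content, and tool results are replayed as FunctionResponse parts.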
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					// genai.Blob expects the full IANA MIME type (e.g. "image/png"),
					// not just the subtype.
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJSONToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  "model",
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJSONToMap(result.Content)
				if err == nil {
					response = parsed
				}

				// Find the originating tool call so its name can be echoed back
				// in the FunctionResponse.
				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

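// convertTools groups every tool's function declaration under a single
// genai.Tool entry rather than emitting one Tool per function.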
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

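// send issues a single (non-streaming) chat request, retrying on rate limits
// according to shouldRetry, and converts the response into a ProviderResponse.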
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	genConfig.Tools = g.convertTools(tools)
	chat, err := g.client.Chats.Create(ctx, model.ID, genConfig, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content += part.Text
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

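// stream sends the request with streaming enabled and emits provider events
// (content start, deltas, stop, completion, or errors) on the returned channel,
// which is closed when the request finishes.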
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if g.providerOptions.maxTokens > 0 {
		maxTokens = g.providerOptions.maxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	genConfig.Tools = g.convertTools(tools)
	chat, chatErr := g.client.Chats.Create(ctx, model.ID, genConfig, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		if chatErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: chatErr}
			return
		}

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
		streamLoop:
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}

							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// Leave the stream loop so the outer loop retries the
							// request; a bare break would only exit the select and
							// then dereference the nil response below.
							break streamLoop
						}
					} else {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := part.Text
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							// De-duplicate calls that show up in more than one chunk.
							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

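// shouldRetry decides whether a failed call is worth retrying and, if so, how
// long to wait in milliseconds. Auth-looking errors trigger an API key refresh
// and an immediate retry; rate-limit errors back off exponentially starting at
// 2s with 20% padding; everything else is returned to the caller.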
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// Gemini doesn't expose a structured error type we can check against, so
	// inspect the error message for rate-limit and auth indicators.
	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	// Token/key problems (401 Unauthorized): refresh the API key, rebuild the
	// client, and retry immediately.
	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		g.client, err = createGeminiClient(g.providerOptions)
		if err != nil {
			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
		}
		return true, 0, nil
	}

	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff starting at 2s, doubled per attempt and padded by 20%.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

func (g *geminiClient) Model() catwalk.Model {
	return g.providerOptions.model(g.providerOptions.modelType)
}

// Helper functions
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

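// convertToSchema recursively translates a JSON Schema fragment (as a generic
// map) into a genai.Schema, defaulting to a string type whenever the input is
// missing or unrecognized. For example, a parameter described as
//
//	{"type": "array", "items": {"type": "string"}, "description": "tags"}
//
// becomes a TypeArray schema whose Items schema has TypeString.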
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

func contains(s string, substrs ...string) bool {
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}