gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/google/uuid"
 19	"google.golang.org/genai"
 20)
 21
// geminiClient is the Gemini-backed provider implementation. It pairs the
// options it was constructed with and the genai SDK client used to talk to
// the API.
type geminiClient struct {
	// providerOptions holds the settings this client was built from (API key,
	// model selection, system message, ...).
	providerOptions providerClientOptions
	// client is the genai SDK handle; shouldRetry replaces it after an API key
	// refresh.
	client *genai.Client
}
 26
// GeminiClient is the type returned by newGeminiClient. It is declared from
// ProviderClient, so it shares that type's underlying definition.
type GeminiClient ProviderClient
 28
 29func newGeminiClient(opts providerClientOptions) GeminiClient {
 30	client, err := createGeminiClient(opts)
 31	if err != nil {
 32		slog.Error("Failed to create Gemini client", "error", err)
 33		return nil
 34	}
 35
 36	return &geminiClient{
 37		providerOptions: opts,
 38		client:          client,
 39	}
 40}
 41
 42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 43	cc := &genai.ClientConfig{
 44		APIKey:  opts.apiKey,
 45		Backend: genai.BackendGeminiAPI,
 46	}
 47	if opts.cfg.Options.Debug {
 48		cc.HTTPClient = log.NewHTTPClient()
 49	}
 50	client, err := genai.NewClient(context.Background(), cc)
 51	if err != nil {
 52		return nil, err
 53	}
 54	return client, nil
 55}
 56
 57func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 58	var history []*genai.Content
 59	for _, msg := range messages {
 60		switch msg.Role {
 61		case message.User:
 62			var parts []*genai.Part
 63			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 64			for _, binaryContent := range msg.BinaryContent() {
 65				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 66				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 67					MIMEType: imageFormat[1],
 68					Data:     binaryContent.Data,
 69				}})
 70			}
 71			history = append(history, &genai.Content{
 72				Parts: parts,
 73				Role:  genai.RoleUser,
 74			})
 75		case message.Assistant:
 76			var assistantParts []*genai.Part
 77
 78			if msg.Content().String() != "" {
 79				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 80			}
 81
 82			if len(msg.ToolCalls()) > 0 {
 83				for _, call := range msg.ToolCalls() {
 84					if !call.Finished {
 85						continue
 86					}
 87					args, _ := parseJSONToMap(call.Input)
 88					assistantParts = append(assistantParts, &genai.Part{
 89						FunctionCall: &genai.FunctionCall{
 90							Name: call.Name,
 91							Args: args,
 92						},
 93					})
 94				}
 95			}
 96
 97			if len(assistantParts) > 0 {
 98				history = append(history, &genai.Content{
 99					Role:  genai.RoleModel,
100					Parts: assistantParts,
101				})
102			}
103
104		case message.Tool:
105			var toolParts []*genai.Part
106			for _, result := range msg.ToolResults() {
107				response := map[string]any{"result": result.Content}
108				parsed, err := parseJSONToMap(result.Content)
109				if err == nil {
110					response = parsed
111				}
112
113				var toolCall message.ToolCall
114				for _, m := range messages {
115					if m.Role == message.Assistant {
116						for _, call := range m.ToolCalls() {
117							if call.ID == result.ToolCallID {
118								toolCall = call
119								break
120							}
121						}
122					}
123				}
124
125				toolParts = append(toolParts, &genai.Part{
126					FunctionResponse: &genai.FunctionResponse{
127						Name:     toolCall.Name,
128						Response: response,
129					},
130				})
131			}
132			if len(toolParts) > 0 {
133				history = append(history, &genai.Content{
134					Parts: toolParts,
135					Role:  genai.RoleUser,
136				})
137			}
138		}
139	}
140
141	return history
142}
143
144func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
145	geminiTool := &genai.Tool{}
146	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
147
148	for _, tool := range tools {
149		info := tool.Info()
150		declaration := &genai.FunctionDeclaration{
151			Name:        info.Name,
152			Description: info.Description,
153			Parameters: &genai.Schema{
154				Type:       genai.TypeObject,
155				Properties: convertSchemaProperties(info.Parameters),
156				Required:   info.Required,
157			},
158		}
159
160		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
161	}
162
163	return []*genai.Tool{geminiTool}
164}
165
166func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
167	switch reason {
168	case genai.FinishReasonStop:
169		return message.FinishReasonEndTurn
170	case genai.FinishReasonMaxTokens:
171		return message.FinishReasonMaxTokens
172	default:
173		return message.FinishReasonUnknown
174	}
175}
176
177func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
178	// Convert messages
179	geminiMessages := g.convertMessages(messages)
180	model := g.providerOptions.model(g.providerOptions.modelType)
181	cfg := g.providerOptions.cfg
182
183	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
184	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
185		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
186	}
187
188	maxTokens := model.DefaultMaxTokens
189	if modelConfig.MaxTokens > 0 {
190		maxTokens = modelConfig.MaxTokens
191	}
192	systemMessage := g.providerOptions.systemMessage
193	if g.providerOptions.systemPromptPrefix != "" {
194		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
195	}
196	history := geminiMessages[:len(geminiMessages)-1] // All but last message
197	lastMsg := geminiMessages[len(geminiMessages)-1]
198	config := &genai.GenerateContentConfig{
199		MaxOutputTokens: int32(maxTokens),
200		SystemInstruction: &genai.Content{
201			Parts: []*genai.Part{{Text: systemMessage}},
202		},
203	}
204	config.Tools = g.convertTools(tools)
205	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
206
207	attempts := 0
208	for {
209		attempts++
210		var toolCalls []message.ToolCall
211
212		var lastMsgParts []genai.Part
213		for _, part := range lastMsg.Parts {
214			lastMsgParts = append(lastMsgParts, *part)
215		}
216		resp, err := chat.SendMessage(ctx, lastMsgParts...)
217		// If there is an error we are going to see if we can retry the call
218		if err != nil {
219			retry, after, retryErr := g.shouldRetry(attempts, err)
220			if retryErr != nil {
221				return nil, retryErr
222			}
223			if retry {
224				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
225				select {
226				case <-ctx.Done():
227					return nil, ctx.Err()
228				case <-time.After(time.Duration(after) * time.Millisecond):
229					continue
230				}
231			}
232			return nil, retryErr
233		}
234
235		content := ""
236
237		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
238			for _, part := range resp.Candidates[0].Content.Parts {
239				switch {
240				case part.Text != "":
241					content = string(part.Text)
242				case part.FunctionCall != nil:
243					id := "call_" + uuid.New().String()
244					args, _ := json.Marshal(part.FunctionCall.Args)
245					toolCalls = append(toolCalls, message.ToolCall{
246						ID:       id,
247						Name:     part.FunctionCall.Name,
248						Input:    string(args),
249						Type:     "function",
250						Finished: true,
251					})
252				}
253			}
254		}
255		finishReason := message.FinishReasonEndTurn
256		if len(resp.Candidates) > 0 {
257			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
258		}
259		if len(toolCalls) > 0 {
260			finishReason = message.FinishReasonToolUse
261		}
262
263		return &ProviderResponse{
264			Content:      content,
265			ToolCalls:    toolCalls,
266			Usage:        g.usage(resp),
267			FinishReason: finishReason,
268		}, nil
269	}
270}
271
272func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
273	// Convert messages
274	geminiMessages := g.convertMessages(messages)
275
276	model := g.providerOptions.model(g.providerOptions.modelType)
277	cfg := g.providerOptions.cfg
278
279	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
280	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
281		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
282	}
283	maxTokens := model.DefaultMaxTokens
284	if modelConfig.MaxTokens > 0 {
285		maxTokens = modelConfig.MaxTokens
286	}
287
288	// Override max tokens if set in provider options
289	if g.providerOptions.maxTokens > 0 {
290		maxTokens = g.providerOptions.maxTokens
291	}
292	systemMessage := g.providerOptions.systemMessage
293	if g.providerOptions.systemPromptPrefix != "" {
294		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
295	}
296	history := geminiMessages[:len(geminiMessages)-1] // All but last message
297	lastMsg := geminiMessages[len(geminiMessages)-1]
298	config := &genai.GenerateContentConfig{
299		MaxOutputTokens: int32(maxTokens),
300		SystemInstruction: &genai.Content{
301			Parts: []*genai.Part{{Text: systemMessage}},
302		},
303	}
304	config.Tools = g.convertTools(tools)
305	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
306
307	attempts := 0
308	eventChan := make(chan ProviderEvent)
309
310	go func() {
311		defer close(eventChan)
312
313		for {
314			attempts++
315
316			currentContent := ""
317			toolCalls := []message.ToolCall{}
318			var finalResp *genai.GenerateContentResponse
319
320			eventChan <- ProviderEvent{Type: EventContentStart}
321
322			var lastMsgParts []genai.Part
323
324			for _, part := range lastMsg.Parts {
325				lastMsgParts = append(lastMsgParts, *part)
326			}
327
328			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
329				if err != nil {
330					retry, after, retryErr := g.shouldRetry(attempts, err)
331					if retryErr != nil {
332						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
333						return
334					}
335					if retry {
336						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
337						select {
338						case <-ctx.Done():
339							if ctx.Err() != nil {
340								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
341							}
342
343							return
344						case <-time.After(time.Duration(after) * time.Millisecond):
345							continue
346						}
347					} else {
348						eventChan <- ProviderEvent{Type: EventError, Error: err}
349						return
350					}
351				}
352
353				finalResp = resp
354
355				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
356					for _, part := range resp.Candidates[0].Content.Parts {
357						switch {
358						case part.Text != "":
359							delta := string(part.Text)
360							if delta != "" {
361								eventChan <- ProviderEvent{
362									Type:    EventContentDelta,
363									Content: delta,
364								}
365								currentContent += delta
366							}
367						case part.FunctionCall != nil:
368							id := "call_" + uuid.New().String()
369							args, _ := json.Marshal(part.FunctionCall.Args)
370							newCall := message.ToolCall{
371								ID:       id,
372								Name:     part.FunctionCall.Name,
373								Input:    string(args),
374								Type:     "function",
375								Finished: true,
376							}
377
378							toolCalls = append(toolCalls, newCall)
379						}
380					}
381				} else {
382					// no content received
383					break
384				}
385			}
386
387			eventChan <- ProviderEvent{Type: EventContentStop}
388
389			if finalResp != nil {
390				finishReason := message.FinishReasonEndTurn
391				if len(finalResp.Candidates) > 0 {
392					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
393				}
394				if len(toolCalls) > 0 {
395					finishReason = message.FinishReasonToolUse
396				}
397				eventChan <- ProviderEvent{
398					Type: EventComplete,
399					Response: &ProviderResponse{
400						Content:      currentContent,
401						ToolCalls:    toolCalls,
402						Usage:        g.usage(finalResp),
403						FinishReason: finishReason,
404					},
405				}
406				return
407			} else {
408				eventChan <- ProviderEvent{
409					Type:  EventError,
410					Error: errors.New("no content received"),
411				}
412			}
413		}
414	}()
415
416	return eventChan
417}
418
419func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
420	// Check if error is a rate limit error
421	if attempts > maxRetries {
422		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
423	}
424
425	// Gemini doesn't have a standard error type we can check against
426	// So we'll check the error message for rate limit indicators
427	if errors.Is(err, io.EOF) {
428		return false, 0, err
429	}
430
431	errMsg := err.Error()
432	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
433
434	// Check for token expiration (401 Unauthorized)
435	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
436		g.providerOptions.apiKey, err = g.providerOptions.cfg.Resolve(g.providerOptions.config.APIKey)
437		if err != nil {
438			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
439		}
440		g.client, err = createGeminiClient(g.providerOptions)
441		if err != nil {
442			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
443		}
444		return true, 0, nil
445	}
446
447	// Check for common rate limit error messages
448
449	if !isRateLimit {
450		return false, 0, err
451	}
452
453	// Calculate backoff with jitter
454	backoffMs := 2000 * (1 << (attempts - 1))
455	jitterMs := int(float64(backoffMs) * 0.2)
456	retryMs := backoffMs + jitterMs
457
458	return true, int64(retryMs), nil
459}
460
461func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
462	if resp == nil || resp.UsageMetadata == nil {
463		return TokenUsage{}
464	}
465
466	return TokenUsage{
467		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
468		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
469		CacheCreationTokens: 0, // Not directly provided by Gemini
470		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
471	}
472}
473
474func (g *geminiClient) Model() catwalk.Model {
475	return g.providerOptions.model(g.providerOptions.modelType)
476}
477
478// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object into a generic map. The
// error from json.Unmarshal is returned as-is; the map is nil when the input
// is not a JSON object.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
484
485func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
486	properties := make(map[string]*genai.Schema)
487
488	for name, param := range parameters {
489		properties[name] = convertToSchema(param)
490	}
491
492	return properties
493}
494
495func convertToSchema(param any) *genai.Schema {
496	schema := &genai.Schema{Type: genai.TypeString}
497
498	paramMap, ok := param.(map[string]any)
499	if !ok {
500		return schema
501	}
502
503	if desc, ok := paramMap["description"].(string); ok {
504		schema.Description = desc
505	}
506
507	typeVal, hasType := paramMap["type"]
508	if !hasType {
509		return schema
510	}
511
512	typeStr, ok := typeVal.(string)
513	if !ok {
514		return schema
515	}
516
517	schema.Type = mapJSONTypeToGenAI(typeStr)
518
519	switch typeStr {
520	case "array":
521		schema.Items = processArrayItems(paramMap)
522	case "object":
523		if props, ok := paramMap["properties"].(map[string]any); ok {
524			schema.Properties = convertSchemaProperties(props)
525		}
526	}
527
528	return schema
529}
530
531func processArrayItems(paramMap map[string]any) *genai.Schema {
532	items, ok := paramMap["items"].(map[string]any)
533	if !ok {
534		return nil
535	}
536
537	return convertToSchema(items)
538}
539
540func mapJSONTypeToGenAI(jsonType string) genai.Type {
541	switch jsonType {
542	case "string":
543		return genai.TypeString
544	case "number":
545		return genai.TypeNumber
546	case "integer":
547		return genai.TypeInteger
548	case "boolean":
549		return genai.TypeBoolean
550	case "array":
551		return genai.TypeArray
552	case "object":
553		return genai.TypeObject
554	default:
555		return genai.TypeString // Default to string for unknown types
556	}
557}
558
// contains reports whether s contains at least one of substrs, ignoring case.
// It returns false when no substrings are given.
func contains(s string, substrs ...string) bool {
	// Lower-case the haystack once instead of once per candidate.
	haystack := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(haystack, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}