gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/google/uuid"
 19	"google.golang.org/genai"
 20)
 21
 22type geminiClient struct {
 23	providerOptions providerClientOptions
 24	client          *genai.Client
 25}
 26
 27type GeminiClient ProviderClient
 28
 29func newGeminiClient(opts providerClientOptions) GeminiClient {
 30	client, err := createGeminiClient(opts)
 31	if err != nil {
 32		slog.Error("Failed to create Gemini client", "error", err)
 33		return nil
 34	}
 35
 36	return &geminiClient{
 37		providerOptions: opts,
 38		client:          client,
 39	}
 40}
 41
 42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 43	cc := &genai.ClientConfig{
 44		APIKey:  opts.apiKey,
 45		Backend: genai.BackendGeminiAPI,
 46	}
 47	if config.Get().Options.Debug {
 48		cc.HTTPClient = log.NewHTTPClient()
 49	}
 50	client, err := genai.NewClient(context.Background(), cc)
 51	if err != nil {
 52		return nil, err
 53	}
 54	return client, nil
 55}
 56
 57func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 58	var history []*genai.Content
 59	for _, msg := range messages {
 60		switch msg.Role {
 61		case message.User:
 62			var parts []*genai.Part
 63			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 64			for _, binaryContent := range msg.BinaryContent() {
 65				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 66				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 67					MIMEType: imageFormat[1],
 68					Data:     binaryContent.Data,
 69				}})
 70			}
 71			history = append(history, &genai.Content{
 72				Parts: parts,
 73				Role:  genai.RoleUser,
 74			})
 75		case message.Assistant:
 76			var assistantParts []*genai.Part
 77
 78			if msg.Content().String() != "" {
 79				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 80			}
 81
 82			if len(msg.ToolCalls()) > 0 {
 83				for _, call := range msg.ToolCalls() {
 84					if !call.Finished {
 85						continue
 86					}
 87					args, _ := parseJSONToMap(call.Input)
 88					assistantParts = append(assistantParts, &genai.Part{
 89						FunctionCall: &genai.FunctionCall{
 90							Name: call.Name,
 91							Args: args,
 92						},
 93					})
 94				}
 95			}
 96
 97			if len(assistantParts) > 0 {
 98				history = append(history, &genai.Content{
 99					Role:  genai.RoleModel,
100					Parts: assistantParts,
101				})
102			}
103
104		case message.Tool:
105			for _, result := range msg.ToolResults() {
106				response := map[string]any{"result": result.Content}
107				parsed, err := parseJSONToMap(result.Content)
108				if err == nil {
109					response = parsed
110				}
111
112				var toolCall message.ToolCall
113				for _, m := range messages {
114					if m.Role == message.Assistant {
115						for _, call := range m.ToolCalls() {
116							if call.ID == result.ToolCallID {
117								toolCall = call
118								break
119							}
120						}
121					}
122				}
123
124				history = append(history, &genai.Content{
125					Parts: []*genai.Part{
126						{
127							FunctionResponse: &genai.FunctionResponse{
128								Name:     toolCall.Name,
129								Response: response,
130							},
131						},
132					},
133					Role: genai.RoleModel,
134				})
135			}
136		}
137	}
138
139	return history
140}
141
142func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
143	geminiTool := &genai.Tool{}
144	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
145
146	for _, tool := range tools {
147		info := tool.Info()
148		declaration := &genai.FunctionDeclaration{
149			Name:        info.Name,
150			Description: info.Description,
151			Parameters: &genai.Schema{
152				Type:       genai.TypeObject,
153				Properties: convertSchemaProperties(info.Parameters),
154				Required:   info.Required,
155			},
156		}
157
158		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
159	}
160
161	return []*genai.Tool{geminiTool}
162}
163
164func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
165	switch reason {
166	case genai.FinishReasonStop:
167		return message.FinishReasonEndTurn
168	case genai.FinishReasonMaxTokens:
169		return message.FinishReasonMaxTokens
170	default:
171		return message.FinishReasonUnknown
172	}
173}
174
175func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
176	// Convert messages
177	geminiMessages := g.convertMessages(messages)
178	model := g.providerOptions.model(g.providerOptions.modelType)
179	cfg := config.Get()
180
181	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
182	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
183		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
184	}
185
186	maxTokens := model.DefaultMaxTokens
187	if modelConfig.MaxTokens > 0 {
188		maxTokens = modelConfig.MaxTokens
189	}
190	systemMessage := g.providerOptions.systemMessage
191	if g.providerOptions.systemPromptPrefix != "" {
192		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
193	}
194	history := geminiMessages[:len(geminiMessages)-1] // All but last message
195	lastMsg := geminiMessages[len(geminiMessages)-1]
196	config := &genai.GenerateContentConfig{
197		MaxOutputTokens: int32(maxTokens),
198		SystemInstruction: &genai.Content{
199			Parts: []*genai.Part{{Text: systemMessage}},
200		},
201	}
202	config.Tools = g.convertTools(tools)
203	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
204
205	attempts := 0
206	for {
207		attempts++
208		var toolCalls []message.ToolCall
209
210		var lastMsgParts []genai.Part
211		for _, part := range lastMsg.Parts {
212			lastMsgParts = append(lastMsgParts, *part)
213		}
214		resp, err := chat.SendMessage(ctx, lastMsgParts...)
215		// If there is an error we are going to see if we can retry the call
216		if err != nil {
217			retry, after, retryErr := g.shouldRetry(attempts, err)
218			if retryErr != nil {
219				return nil, retryErr
220			}
221			if retry {
222				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
223				select {
224				case <-ctx.Done():
225					return nil, ctx.Err()
226				case <-time.After(time.Duration(after) * time.Millisecond):
227					continue
228				}
229			}
230			return nil, retryErr
231		}
232
233		content := ""
234
235		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
236			for _, part := range resp.Candidates[0].Content.Parts {
237				switch {
238				case part.Text != "":
239					content = string(part.Text)
240				case part.FunctionCall != nil:
241					id := "call_" + uuid.New().String()
242					args, _ := json.Marshal(part.FunctionCall.Args)
243					toolCalls = append(toolCalls, message.ToolCall{
244						ID:       id,
245						Name:     part.FunctionCall.Name,
246						Input:    string(args),
247						Type:     "function",
248						Finished: true,
249					})
250				}
251			}
252		}
253		finishReason := message.FinishReasonEndTurn
254		if len(resp.Candidates) > 0 {
255			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
256		}
257		if len(toolCalls) > 0 {
258			finishReason = message.FinishReasonToolUse
259		}
260
261		return &ProviderResponse{
262			Content:      content,
263			ToolCalls:    toolCalls,
264			Usage:        g.usage(resp),
265			FinishReason: finishReason,
266		}, nil
267	}
268}
269
270func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
271	// Convert messages
272	geminiMessages := g.convertMessages(messages)
273
274	model := g.providerOptions.model(g.providerOptions.modelType)
275	cfg := config.Get()
276
277	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
278	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
279		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
280	}
281	maxTokens := model.DefaultMaxTokens
282	if modelConfig.MaxTokens > 0 {
283		maxTokens = modelConfig.MaxTokens
284	}
285
286	// Override max tokens if set in provider options
287	if g.providerOptions.maxTokens > 0 {
288		maxTokens = g.providerOptions.maxTokens
289	}
290	systemMessage := g.providerOptions.systemMessage
291	if g.providerOptions.systemPromptPrefix != "" {
292		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
293	}
294	history := geminiMessages[:len(geminiMessages)-1] // All but last message
295	lastMsg := geminiMessages[len(geminiMessages)-1]
296	config := &genai.GenerateContentConfig{
297		MaxOutputTokens: int32(maxTokens),
298		SystemInstruction: &genai.Content{
299			Parts: []*genai.Part{{Text: systemMessage}},
300		},
301	}
302	config.Tools = g.convertTools(tools)
303	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
304
305	attempts := 0
306	eventChan := make(chan ProviderEvent)
307
308	go func() {
309		defer close(eventChan)
310
311		for {
312			attempts++
313
314			currentContent := ""
315			toolCalls := []message.ToolCall{}
316			var finalResp *genai.GenerateContentResponse
317
318			eventChan <- ProviderEvent{Type: EventContentStart}
319
320			var lastMsgParts []genai.Part
321
322			for _, part := range lastMsg.Parts {
323				lastMsgParts = append(lastMsgParts, *part)
324			}
325
326			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
327				if err != nil {
328					retry, after, retryErr := g.shouldRetry(attempts, err)
329					if retryErr != nil {
330						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
331						return
332					}
333					if retry {
334						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
335						select {
336						case <-ctx.Done():
337							if ctx.Err() != nil {
338								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
339							}
340
341							return
342						case <-time.After(time.Duration(after) * time.Millisecond):
343							continue
344						}
345					} else {
346						eventChan <- ProviderEvent{Type: EventError, Error: err}
347						return
348					}
349				}
350
351				finalResp = resp
352
353				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
354					for _, part := range resp.Candidates[0].Content.Parts {
355						switch {
356						case part.Text != "":
357							delta := string(part.Text)
358							if delta != "" {
359								eventChan <- ProviderEvent{
360									Type:    EventContentDelta,
361									Content: delta,
362								}
363								currentContent += delta
364							}
365						case part.FunctionCall != nil:
366							id := "call_" + uuid.New().String()
367							args, _ := json.Marshal(part.FunctionCall.Args)
368							newCall := message.ToolCall{
369								ID:       id,
370								Name:     part.FunctionCall.Name,
371								Input:    string(args),
372								Type:     "function",
373								Finished: true,
374							}
375
376							isNew := true
377							for _, existing := range toolCalls {
378								if existing.Name == newCall.Name && existing.Input == newCall.Input {
379									isNew = false
380									break
381								}
382							}
383
384							if isNew {
385								toolCalls = append(toolCalls, newCall)
386							}
387						}
388					}
389				} else {
390					// no content received
391					break
392				}
393			}
394
395			eventChan <- ProviderEvent{Type: EventContentStop}
396
397			if finalResp != nil {
398				finishReason := message.FinishReasonEndTurn
399				if len(finalResp.Candidates) > 0 {
400					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
401				}
402				if len(toolCalls) > 0 {
403					finishReason = message.FinishReasonToolUse
404				}
405				eventChan <- ProviderEvent{
406					Type: EventComplete,
407					Response: &ProviderResponse{
408						Content:      currentContent,
409						ToolCalls:    toolCalls,
410						Usage:        g.usage(finalResp),
411						FinishReason: finishReason,
412					},
413				}
414				return
415			} else {
416				eventChan <- ProviderEvent{
417					Type:  EventError,
418					Error: errors.New("no content received"),
419				}
420			}
421		}
422	}()
423
424	return eventChan
425}
426
427func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
428	// Check if error is a rate limit error
429	if attempts > maxRetries {
430		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
431	}
432
433	// Gemini doesn't have a standard error type we can check against
434	// So we'll check the error message for rate limit indicators
435	if errors.Is(err, io.EOF) {
436		return false, 0, err
437	}
438
439	errMsg := err.Error()
440	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
441
442	// Check for token expiration (401 Unauthorized)
443	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
444		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
445		if err != nil {
446			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
447		}
448		g.client, err = createGeminiClient(g.providerOptions)
449		if err != nil {
450			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
451		}
452		return true, 0, nil
453	}
454
455	// Check for common rate limit error messages
456
457	if !isRateLimit {
458		return false, 0, err
459	}
460
461	// Calculate backoff with jitter
462	backoffMs := 2000 * (1 << (attempts - 1))
463	jitterMs := int(float64(backoffMs) * 0.2)
464	retryMs := backoffMs + jitterMs
465
466	return true, int64(retryMs), nil
467}
468
469func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
470	if resp == nil || resp.UsageMetadata == nil {
471		return TokenUsage{}
472	}
473
474	return TokenUsage{
475		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
476		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
477		CacheCreationTokens: 0, // Not directly provided by Gemini
478		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
479	}
480}
481
482func (g *geminiClient) Model() catwalk.Model {
483	return g.providerOptions.model(g.providerOptions.modelType)
484}
485
// Helper functions

// parseJSONToMap decodes jsonStr as a JSON object. Invalid JSON or a
// non-object value yields an error; the JSON literal null yields a nil map and
// no error.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
492
493func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
494	properties := make(map[string]*genai.Schema)
495
496	for name, param := range parameters {
497		properties[name] = convertToSchema(param)
498	}
499
500	return properties
501}
502
503func convertToSchema(param any) *genai.Schema {
504	schema := &genai.Schema{Type: genai.TypeString}
505
506	paramMap, ok := param.(map[string]any)
507	if !ok {
508		return schema
509	}
510
511	if desc, ok := paramMap["description"].(string); ok {
512		schema.Description = desc
513	}
514
515	typeVal, hasType := paramMap["type"]
516	if !hasType {
517		return schema
518	}
519
520	typeStr, ok := typeVal.(string)
521	if !ok {
522		return schema
523	}
524
525	schema.Type = mapJSONTypeToGenAI(typeStr)
526
527	switch typeStr {
528	case "array":
529		schema.Items = processArrayItems(paramMap)
530	case "object":
531		if props, ok := paramMap["properties"].(map[string]any); ok {
532			schema.Properties = convertSchemaProperties(props)
533		}
534	}
535
536	return schema
537}
538
539func processArrayItems(paramMap map[string]any) *genai.Schema {
540	items, ok := paramMap["items"].(map[string]any)
541	if !ok {
542		return nil
543	}
544
545	return convertToSchema(items)
546}
547
548func mapJSONTypeToGenAI(jsonType string) genai.Type {
549	switch jsonType {
550	case "string":
551		return genai.TypeString
552	case "number":
553		return genai.TypeNumber
554	case "integer":
555		return genai.TypeInteger
556	case "boolean":
557		return genai.TypeBoolean
558	case "array":
559		return genai.TypeArray
560	case "object":
561		return genai.TypeObject
562	default:
563		return genai.TypeString // Default to string for unknown types
564	}
565}
566
// contains reports whether s contains any of substrs, compared
// case-insensitively via strings.ToLower. With no substrs it returns false;
// an empty substring always matches.
func contains(s string, substrs ...string) bool {
	// Lowercase s once rather than once per candidate substring.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}