gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/google/uuid"
 19	"google.golang.org/genai"
 20)
 21
// geminiClient is the provider client implementation that talks to Gemini
// through the genai SDK.
type geminiClient struct {
	// providerOptions holds the options the client was constructed with
	// (API key, base URL, model selection, system message, ...).
	providerOptions providerClientOptions
	// client is the genai client used to create chat sessions. shouldRetry
	// may replace it after an API key refresh.
	client          *genai.Client
}
 26
// GeminiClient is a named type whose underlying type is ProviderClient; it is
// the type returned by newGeminiClient.
type GeminiClient ProviderClient
 28
 29func newGeminiClient(opts providerClientOptions) GeminiClient {
 30	client, err := createGeminiClient(opts)
 31	if err != nil {
 32		slog.Error("Failed to create Gemini client", "error", err)
 33		return nil
 34	}
 35
 36	return &geminiClient{
 37		providerOptions: opts,
 38		client:          client,
 39	}
 40}
 41
 42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 43	cc := &genai.ClientConfig{
 44		APIKey:  opts.apiKey,
 45		Backend: genai.BackendGeminiAPI,
 46		HTTPOptions: genai.HTTPOptions{
 47			BaseURL: opts.baseURL,
 48		},
 49	}
 50	if config.Get().Options.Debug {
 51		cc.HTTPClient = log.NewHTTPClient()
 52	}
 53	client, err := genai.NewClient(context.Background(), cc)
 54	if err != nil {
 55		return nil, err
 56	}
 57	return client, nil
 58}
 59
 60func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 61	var history []*genai.Content
 62	for _, msg := range messages {
 63		switch msg.Role {
 64		case message.User:
 65			var parts []*genai.Part
 66			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 67			for _, binaryContent := range msg.BinaryContent() {
 68				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 69				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 70					MIMEType: imageFormat[1],
 71					Data:     binaryContent.Data,
 72				}})
 73			}
 74			history = append(history, &genai.Content{
 75				Parts: parts,
 76				Role:  genai.RoleUser,
 77			})
 78		case message.Assistant:
 79			var assistantParts []*genai.Part
 80
 81			if msg.Content().String() != "" {
 82				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 83			}
 84
 85			if len(msg.ToolCalls()) > 0 {
 86				for _, call := range msg.ToolCalls() {
 87					if !call.Finished {
 88						continue
 89					}
 90					args, _ := parseJSONToMap(call.Input)
 91					assistantParts = append(assistantParts, &genai.Part{
 92						FunctionCall: &genai.FunctionCall{
 93							Name: call.Name,
 94							Args: args,
 95						},
 96					})
 97				}
 98			}
 99
100			if len(assistantParts) > 0 {
101				history = append(history, &genai.Content{
102					Role:  genai.RoleModel,
103					Parts: assistantParts,
104				})
105			}
106
107		case message.Tool:
108			var toolParts []*genai.Part
109			for _, result := range msg.ToolResults() {
110				response := map[string]any{"result": result.Content}
111				parsed, err := parseJSONToMap(result.Content)
112				if err == nil {
113					response = parsed
114				}
115
116				var toolCall message.ToolCall
117				for _, m := range messages {
118					if m.Role == message.Assistant {
119						for _, call := range m.ToolCalls() {
120							if call.ID == result.ToolCallID {
121								toolCall = call
122								break
123							}
124						}
125					}
126				}
127
128				toolParts = append(toolParts, &genai.Part{
129					FunctionResponse: &genai.FunctionResponse{
130						Name:     toolCall.Name,
131						Response: response,
132					},
133				})
134			}
135			if len(toolParts) > 0 {
136				history = append(history, &genai.Content{
137					Parts: toolParts,
138					Role:  genai.RoleUser,
139				})
140			}
141		}
142	}
143
144	return history
145}
146
147func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
148	geminiTool := &genai.Tool{}
149	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
150
151	for _, tool := range tools {
152		info := tool.Info()
153		declaration := &genai.FunctionDeclaration{
154			Name:        info.Name,
155			Description: info.Description,
156			Parameters: &genai.Schema{
157				Type:       genai.TypeObject,
158				Properties: convertSchemaProperties(info.Parameters),
159				Required:   info.Required,
160			},
161		}
162
163		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
164	}
165
166	return []*genai.Tool{geminiTool}
167}
168
169func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
170	switch reason {
171	case genai.FinishReasonStop:
172		return message.FinishReasonEndTurn
173	case genai.FinishReasonMaxTokens:
174		return message.FinishReasonMaxTokens
175	default:
176		return message.FinishReasonUnknown
177	}
178}
179
180func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
181	// Convert messages
182	geminiMessages := g.convertMessages(messages)
183	model := g.providerOptions.model(g.providerOptions.modelType)
184	cfg := config.Get()
185
186	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
187	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
188		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
189	}
190
191	maxTokens := model.DefaultMaxTokens
192	if modelConfig.MaxTokens > 0 {
193		maxTokens = modelConfig.MaxTokens
194	}
195	systemMessage := g.providerOptions.systemMessage
196	if g.providerOptions.systemPromptPrefix != "" {
197		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
198	}
199	history := geminiMessages[:len(geminiMessages)-1] // All but last message
200	lastMsg := geminiMessages[len(geminiMessages)-1]
201	config := &genai.GenerateContentConfig{
202		MaxOutputTokens: int32(maxTokens),
203		SystemInstruction: &genai.Content{
204			Parts: []*genai.Part{{Text: systemMessage}},
205		},
206	}
207	config.Tools = g.convertTools(tools)
208	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
209
210	attempts := 0
211	for {
212		attempts++
213		var toolCalls []message.ToolCall
214
215		var lastMsgParts []genai.Part
216		for _, part := range lastMsg.Parts {
217			lastMsgParts = append(lastMsgParts, *part)
218		}
219		resp, err := chat.SendMessage(ctx, lastMsgParts...)
220		// If there is an error we are going to see if we can retry the call
221		if err != nil {
222			retry, after, retryErr := g.shouldRetry(attempts, err)
223			if retryErr != nil {
224				return nil, retryErr
225			}
226			if retry {
227				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
228				select {
229				case <-ctx.Done():
230					return nil, ctx.Err()
231				case <-time.After(time.Duration(after) * time.Millisecond):
232					continue
233				}
234			}
235			return nil, retryErr
236		}
237
238		content := ""
239
240		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
241			for _, part := range resp.Candidates[0].Content.Parts {
242				switch {
243				case part.Text != "":
244					content = string(part.Text)
245				case part.FunctionCall != nil:
246					id := "call_" + uuid.New().String()
247					args, _ := json.Marshal(part.FunctionCall.Args)
248					toolCalls = append(toolCalls, message.ToolCall{
249						ID:       id,
250						Name:     part.FunctionCall.Name,
251						Input:    string(args),
252						Type:     "function",
253						Finished: true,
254					})
255				}
256			}
257		}
258		finishReason := message.FinishReasonEndTurn
259		if len(resp.Candidates) > 0 {
260			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
261		}
262		if len(toolCalls) > 0 {
263			finishReason = message.FinishReasonToolUse
264		}
265
266		return &ProviderResponse{
267			Content:      content,
268			ToolCalls:    toolCalls,
269			Usage:        g.usage(resp),
270			FinishReason: finishReason,
271		}, nil
272	}
273}
274
275func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
276	// Convert messages
277	geminiMessages := g.convertMessages(messages)
278
279	model := g.providerOptions.model(g.providerOptions.modelType)
280	cfg := config.Get()
281
282	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
283	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
284		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
285	}
286	maxTokens := model.DefaultMaxTokens
287	if modelConfig.MaxTokens > 0 {
288		maxTokens = modelConfig.MaxTokens
289	}
290
291	// Override max tokens if set in provider options
292	if g.providerOptions.maxTokens > 0 {
293		maxTokens = g.providerOptions.maxTokens
294	}
295	systemMessage := g.providerOptions.systemMessage
296	if g.providerOptions.systemPromptPrefix != "" {
297		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
298	}
299	history := geminiMessages[:len(geminiMessages)-1] // All but last message
300	lastMsg := geminiMessages[len(geminiMessages)-1]
301	config := &genai.GenerateContentConfig{
302		MaxOutputTokens: int32(maxTokens),
303		SystemInstruction: &genai.Content{
304			Parts: []*genai.Part{{Text: systemMessage}},
305		},
306	}
307	config.Tools = g.convertTools(tools)
308	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
309
310	attempts := 0
311	eventChan := make(chan ProviderEvent)
312
313	go func() {
314		defer close(eventChan)
315
316		for {
317			attempts++
318
319			currentContent := ""
320			toolCalls := []message.ToolCall{}
321			var finalResp *genai.GenerateContentResponse
322
323			eventChan <- ProviderEvent{Type: EventContentStart}
324
325			var lastMsgParts []genai.Part
326
327			for _, part := range lastMsg.Parts {
328				lastMsgParts = append(lastMsgParts, *part)
329			}
330
331			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
332				if err != nil {
333					retry, after, retryErr := g.shouldRetry(attempts, err)
334					if retryErr != nil {
335						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
336						return
337					}
338					if retry {
339						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
340						select {
341						case <-ctx.Done():
342							if ctx.Err() != nil {
343								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
344							}
345
346							return
347						case <-time.After(time.Duration(after) * time.Millisecond):
348							continue
349						}
350					} else {
351						eventChan <- ProviderEvent{Type: EventError, Error: err}
352						return
353					}
354				}
355
356				finalResp = resp
357
358				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
359					for _, part := range resp.Candidates[0].Content.Parts {
360						switch {
361						case part.Text != "":
362							delta := string(part.Text)
363							if delta != "" {
364								eventChan <- ProviderEvent{
365									Type:    EventContentDelta,
366									Content: delta,
367								}
368								currentContent += delta
369							}
370						case part.FunctionCall != nil:
371							id := "call_" + uuid.New().String()
372							args, _ := json.Marshal(part.FunctionCall.Args)
373							newCall := message.ToolCall{
374								ID:       id,
375								Name:     part.FunctionCall.Name,
376								Input:    string(args),
377								Type:     "function",
378								Finished: true,
379							}
380
381							toolCalls = append(toolCalls, newCall)
382						}
383					}
384				} else {
385					// no content received
386					break
387				}
388			}
389
390			eventChan <- ProviderEvent{Type: EventContentStop}
391
392			if finalResp != nil {
393				finishReason := message.FinishReasonEndTurn
394				if len(finalResp.Candidates) > 0 {
395					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
396				}
397				if len(toolCalls) > 0 {
398					finishReason = message.FinishReasonToolUse
399				}
400				eventChan <- ProviderEvent{
401					Type: EventComplete,
402					Response: &ProviderResponse{
403						Content:      currentContent,
404						ToolCalls:    toolCalls,
405						Usage:        g.usage(finalResp),
406						FinishReason: finishReason,
407					},
408				}
409				return
410			} else {
411				eventChan <- ProviderEvent{
412					Type:  EventError,
413					Error: errors.New("no content received"),
414				}
415			}
416		}
417	}()
418
419	return eventChan
420}
421
422func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
423	// Check if error is a rate limit error
424	if attempts > maxRetries {
425		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
426	}
427
428	// Gemini doesn't have a standard error type we can check against
429	// So we'll check the error message for rate limit indicators
430	if errors.Is(err, io.EOF) {
431		return false, 0, err
432	}
433
434	errMsg := err.Error()
435	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
436
437	// Check for token expiration (401 Unauthorized)
438	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
439		prev := g.providerOptions.apiKey
440		// in case the key comes from a script, we try to re-evaluate it.
441		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
442		if err != nil {
443			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
444		}
445		// if it didn't change, do not retry.
446		if prev == g.providerOptions.apiKey {
447			return false, 0, err
448		}
449		g.client, err = createGeminiClient(g.providerOptions)
450		if err != nil {
451			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
452		}
453		return true, 0, nil
454	}
455
456	// Check for common rate limit error messages
457
458	if !isRateLimit {
459		return false, 0, err
460	}
461
462	// Calculate backoff with jitter
463	backoffMs := 2000 * (1 << (attempts - 1))
464	jitterMs := int(float64(backoffMs) * 0.2)
465	retryMs := backoffMs + jitterMs
466
467	return true, int64(retryMs), nil
468}
469
470func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
471	if resp == nil || resp.UsageMetadata == nil {
472		return TokenUsage{}
473	}
474
475	return TokenUsage{
476		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
477		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
478		CacheCreationTokens: 0, // Not directly provided by Gemini
479		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
480	}
481}
482
483func (g *geminiClient) Model() catwalk.Model {
484	return g.providerOptions.model(g.providerOptions.modelType)
485}
486
487// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object. It returns the decoding
// error when jsonStr is not valid JSON or is not an object; the JSON literal
// null decodes to a nil map without error.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
493
494func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
495	properties := make(map[string]*genai.Schema)
496
497	for name, param := range parameters {
498		properties[name] = convertToSchema(param)
499	}
500
501	return properties
502}
503
504func convertToSchema(param any) *genai.Schema {
505	schema := &genai.Schema{Type: genai.TypeString}
506
507	paramMap, ok := param.(map[string]any)
508	if !ok {
509		return schema
510	}
511
512	if desc, ok := paramMap["description"].(string); ok {
513		schema.Description = desc
514	}
515
516	typeVal, hasType := paramMap["type"]
517	if !hasType {
518		return schema
519	}
520
521	typeStr, ok := typeVal.(string)
522	if !ok {
523		return schema
524	}
525
526	schema.Type = mapJSONTypeToGenAI(typeStr)
527
528	switch typeStr {
529	case "array":
530		schema.Items = processArrayItems(paramMap)
531	case "object":
532		if props, ok := paramMap["properties"].(map[string]any); ok {
533			schema.Properties = convertSchemaProperties(props)
534		}
535	}
536
537	return schema
538}
539
540func processArrayItems(paramMap map[string]any) *genai.Schema {
541	items, ok := paramMap["items"].(map[string]any)
542	if !ok {
543		return nil
544	}
545
546	return convertToSchema(items)
547}
548
549func mapJSONTypeToGenAI(jsonType string) genai.Type {
550	switch jsonType {
551	case "string":
552		return genai.TypeString
553	case "number":
554		return genai.TypeNumber
555	case "integer":
556		return genai.TypeInteger
557	case "boolean":
558		return genai.TypeBoolean
559	case "array":
560		return genai.TypeArray
561	case "object":
562		return genai.TypeObject
563	default:
564		return genai.TypeString // Default to string for unknown types
565	}
566}
567
// contains reports whether s contains at least one of substrs, ignoring case.
// It returns false when no substrings are given; an empty substring matches
// any s.
func contains(s string, substrs ...string) bool {
	// Lower-case s once instead of once per candidate substring.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}