gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/google/uuid"
 19	"google.golang.org/genai"
 20)
 21
 22type geminiClient struct {
 23	providerOptions providerClientOptions
 24	client          *genai.Client
 25}
 26
 27type GeminiClient ProviderClient
 28
 29func newGeminiClient(opts providerClientOptions) GeminiClient {
 30	client, err := createGeminiClient(opts)
 31	if err != nil {
 32		slog.Error("Failed to create Gemini client", "error", err)
 33		return nil
 34	}
 35
 36	return &geminiClient{
 37		providerOptions: opts,
 38		client:          client,
 39	}
 40}
 41
 42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 43	cc := &genai.ClientConfig{
 44		APIKey:  opts.apiKey,
 45		Backend: genai.BackendGeminiAPI,
 46	}
 47	if opts.baseURL != "" {
 48		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
 49		if err == nil && resolvedBaseURL != "" {
 50			cc.HTTPOptions = genai.HTTPOptions{
 51				BaseURL: resolvedBaseURL,
 52			}
 53		}
 54	}
 55	if config.Get().Options.Debug {
 56		cc.HTTPClient = log.NewHTTPClient()
 57	}
 58	client, err := genai.NewClient(context.Background(), cc)
 59	if err != nil {
 60		return nil, err
 61	}
 62	return client, nil
 63}
 64
 65func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 66	var history []*genai.Content
 67	for _, msg := range messages {
 68		switch msg.Role {
 69		case message.User:
 70			var parts []*genai.Part
 71			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 72			for _, binaryContent := range msg.BinaryContent() {
 73				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 74				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 75					MIMEType: imageFormat[1],
 76					Data:     binaryContent.Data,
 77				}})
 78			}
 79			history = append(history, &genai.Content{
 80				Parts: parts,
 81				Role:  genai.RoleUser,
 82			})
 83		case message.Assistant:
 84			var assistantParts []*genai.Part
 85
 86			if msg.Content().String() != "" {
 87				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 88			}
 89
 90			if len(msg.ToolCalls()) > 0 {
 91				for _, call := range msg.ToolCalls() {
 92					if !call.Finished {
 93						continue
 94					}
 95					args, _ := parseJSONToMap(call.Input)
 96					assistantParts = append(assistantParts, &genai.Part{
 97						FunctionCall: &genai.FunctionCall{
 98							Name: call.Name,
 99							Args: args,
100						},
101					})
102				}
103			}
104
105			if len(assistantParts) > 0 {
106				history = append(history, &genai.Content{
107					Role:  genai.RoleModel,
108					Parts: assistantParts,
109				})
110			}
111
112		case message.Tool:
113			var toolParts []*genai.Part
114			for _, result := range msg.ToolResults() {
115				response := map[string]any{"result": result.Content}
116				parsed, err := parseJSONToMap(result.Content)
117				if err == nil {
118					response = parsed
119				}
120
121				var toolCall message.ToolCall
122				for _, m := range messages {
123					if m.Role == message.Assistant {
124						for _, call := range m.ToolCalls() {
125							if call.ID == result.ToolCallID {
126								toolCall = call
127								break
128							}
129						}
130					}
131				}
132
133				toolParts = append(toolParts, &genai.Part{
134					FunctionResponse: &genai.FunctionResponse{
135						Name:     toolCall.Name,
136						Response: response,
137					},
138				})
139			}
140			if len(toolParts) > 0 {
141				history = append(history, &genai.Content{
142					Parts: toolParts,
143					Role:  genai.RoleUser,
144				})
145			}
146		}
147	}
148
149	return history
150}
151
152func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
153	geminiTool := &genai.Tool{}
154	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
155
156	for _, tool := range tools {
157		info := tool.Info()
158		declaration := &genai.FunctionDeclaration{
159			Name:        info.Name,
160			Description: info.Description,
161			Parameters: &genai.Schema{
162				Type:       genai.TypeObject,
163				Properties: convertSchemaProperties(info.Parameters),
164				Required:   info.Required,
165			},
166		}
167
168		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
169	}
170
171	return []*genai.Tool{geminiTool}
172}
173
174func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
175	switch reason {
176	case genai.FinishReasonStop:
177		return message.FinishReasonEndTurn
178	case genai.FinishReasonMaxTokens:
179		return message.FinishReasonMaxTokens
180	default:
181		return message.FinishReasonUnknown
182	}
183}
184
185func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
186	// Convert messages
187	geminiMessages := g.convertMessages(messages)
188	model := g.providerOptions.model(g.providerOptions.modelType)
189	cfg := config.Get()
190
191	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
192	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
193		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
194	}
195
196	maxTokens := model.DefaultMaxTokens
197	if modelConfig.MaxTokens > 0 {
198		maxTokens = modelConfig.MaxTokens
199	}
200	systemMessage := g.providerOptions.systemMessage
201	if g.providerOptions.systemPromptPrefix != "" {
202		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
203	}
204	history := geminiMessages[:len(geminiMessages)-1] // All but last message
205	lastMsg := geminiMessages[len(geminiMessages)-1]
206	config := &genai.GenerateContentConfig{
207		MaxOutputTokens: int32(maxTokens),
208		SystemInstruction: &genai.Content{
209			Parts: []*genai.Part{{Text: systemMessage}},
210		},
211	}
212	config.Tools = g.convertTools(tools)
213	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
214
215	attempts := 0
216	for {
217		attempts++
218		var toolCalls []message.ToolCall
219
220		var lastMsgParts []genai.Part
221		for _, part := range lastMsg.Parts {
222			lastMsgParts = append(lastMsgParts, *part)
223		}
224		resp, err := chat.SendMessage(ctx, lastMsgParts...)
225		// If there is an error we are going to see if we can retry the call
226		if err != nil {
227			retry, after, retryErr := g.shouldRetry(attempts, err)
228			if retryErr != nil {
229				return nil, retryErr
230			}
231			if retry {
232				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
233				select {
234				case <-ctx.Done():
235					return nil, ctx.Err()
236				case <-time.After(time.Duration(after) * time.Millisecond):
237					continue
238				}
239			}
240			return nil, retryErr
241		}
242
243		content := ""
244
245		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
246			for _, part := range resp.Candidates[0].Content.Parts {
247				switch {
248				case part.Text != "":
249					content = string(part.Text)
250				case part.FunctionCall != nil:
251					id := "call_" + uuid.New().String()
252					args, _ := json.Marshal(part.FunctionCall.Args)
253					toolCalls = append(toolCalls, message.ToolCall{
254						ID:       id,
255						Name:     part.FunctionCall.Name,
256						Input:    string(args),
257						Type:     "function",
258						Finished: true,
259					})
260				}
261			}
262		}
263		finishReason := message.FinishReasonEndTurn
264		if len(resp.Candidates) > 0 {
265			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
266		}
267		if len(toolCalls) > 0 {
268			finishReason = message.FinishReasonToolUse
269		}
270
271		return &ProviderResponse{
272			Content:      content,
273			ToolCalls:    toolCalls,
274			Usage:        g.usage(resp),
275			FinishReason: finishReason,
276		}, nil
277	}
278}
279
280func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
281	// Convert messages
282	geminiMessages := g.convertMessages(messages)
283
284	model := g.providerOptions.model(g.providerOptions.modelType)
285	cfg := config.Get()
286
287	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
288	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
289		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
290	}
291	maxTokens := model.DefaultMaxTokens
292	if modelConfig.MaxTokens > 0 {
293		maxTokens = modelConfig.MaxTokens
294	}
295
296	// Override max tokens if set in provider options
297	if g.providerOptions.maxTokens > 0 {
298		maxTokens = g.providerOptions.maxTokens
299	}
300	systemMessage := g.providerOptions.systemMessage
301	if g.providerOptions.systemPromptPrefix != "" {
302		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
303	}
304	history := geminiMessages[:len(geminiMessages)-1] // All but last message
305	lastMsg := geminiMessages[len(geminiMessages)-1]
306	config := &genai.GenerateContentConfig{
307		MaxOutputTokens: int32(maxTokens),
308		SystemInstruction: &genai.Content{
309			Parts: []*genai.Part{{Text: systemMessage}},
310		},
311	}
312	config.Tools = g.convertTools(tools)
313	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
314
315	attempts := 0
316	eventChan := make(chan ProviderEvent)
317
318	go func() {
319		defer close(eventChan)
320
321		for {
322			attempts++
323
324			currentContent := ""
325			toolCalls := []message.ToolCall{}
326			var finalResp *genai.GenerateContentResponse
327
328			eventChan <- ProviderEvent{Type: EventContentStart}
329
330			var lastMsgParts []genai.Part
331
332			for _, part := range lastMsg.Parts {
333				lastMsgParts = append(lastMsgParts, *part)
334			}
335
336			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
337				if err != nil {
338					retry, after, retryErr := g.shouldRetry(attempts, err)
339					if retryErr != nil {
340						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
341						return
342					}
343					if retry {
344						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
345						select {
346						case <-ctx.Done():
347							if ctx.Err() != nil {
348								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
349							}
350
351							return
352						case <-time.After(time.Duration(after) * time.Millisecond):
353							continue
354						}
355					} else {
356						eventChan <- ProviderEvent{Type: EventError, Error: err}
357						return
358					}
359				}
360
361				finalResp = resp
362
363				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
364					for _, part := range resp.Candidates[0].Content.Parts {
365						switch {
366						case part.Text != "":
367							delta := string(part.Text)
368							if delta != "" {
369								eventChan <- ProviderEvent{
370									Type:    EventContentDelta,
371									Content: delta,
372								}
373								currentContent += delta
374							}
375						case part.FunctionCall != nil:
376							id := "call_" + uuid.New().String()
377							args, _ := json.Marshal(part.FunctionCall.Args)
378							newCall := message.ToolCall{
379								ID:       id,
380								Name:     part.FunctionCall.Name,
381								Input:    string(args),
382								Type:     "function",
383								Finished: true,
384							}
385
386							toolCalls = append(toolCalls, newCall)
387						}
388					}
389				} else {
390					// no content received
391					break
392				}
393			}
394
395			eventChan <- ProviderEvent{Type: EventContentStop}
396
397			if finalResp != nil {
398				finishReason := message.FinishReasonEndTurn
399				if len(finalResp.Candidates) > 0 {
400					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
401				}
402				if len(toolCalls) > 0 {
403					finishReason = message.FinishReasonToolUse
404				}
405				eventChan <- ProviderEvent{
406					Type: EventComplete,
407					Response: &ProviderResponse{
408						Content:      currentContent,
409						ToolCalls:    toolCalls,
410						Usage:        g.usage(finalResp),
411						FinishReason: finishReason,
412					},
413				}
414				return
415			} else {
416				eventChan <- ProviderEvent{
417					Type:  EventError,
418					Error: errors.New("no content received"),
419				}
420			}
421		}
422	}()
423
424	return eventChan
425}
426
427func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
428	// Check if error is a rate limit error
429	if attempts > maxRetries {
430		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
431	}
432
433	// Gemini doesn't have a standard error type we can check against
434	// So we'll check the error message for rate limit indicators
435	if errors.Is(err, io.EOF) {
436		return false, 0, err
437	}
438
439	errMsg := err.Error()
440	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
441
442	// Check for token expiration (401 Unauthorized)
443	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
444		prev := g.providerOptions.apiKey
445		// in case the key comes from a script, we try to re-evaluate it.
446		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
447		if err != nil {
448			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
449		}
450		// if it didn't change, do not retry.
451		if prev == g.providerOptions.apiKey {
452			return false, 0, err
453		}
454		g.client, err = createGeminiClient(g.providerOptions)
455		if err != nil {
456			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
457		}
458		return true, 0, nil
459	}
460
461	// Check for common rate limit error messages
462
463	if !isRateLimit {
464		return false, 0, err
465	}
466
467	// Calculate backoff with jitter
468	backoffMs := 2000 * (1 << (attempts - 1))
469	jitterMs := int(float64(backoffMs) * 0.2)
470	retryMs := backoffMs + jitterMs
471
472	return true, int64(retryMs), nil
473}
474
475func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
476	if resp == nil || resp.UsageMetadata == nil {
477		return TokenUsage{}
478	}
479
480	return TokenUsage{
481		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
482		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
483		CacheCreationTokens: 0, // Not directly provided by Gemini
484		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
485	}
486}
487
488func (g *geminiClient) Model() catwalk.Model {
489	return g.providerOptions.model(g.providerOptions.modelType)
490}
491
492// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object into a generic map.
//
// It returns the decoding error when jsonStr is not valid JSON or is not a
// JSON object. The JSON literal null decodes to a nil map with a nil error.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
498
499func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
500	properties := make(map[string]*genai.Schema)
501
502	for name, param := range parameters {
503		properties[name] = convertToSchema(param)
504	}
505
506	return properties
507}
508
509func convertToSchema(param any) *genai.Schema {
510	schema := &genai.Schema{Type: genai.TypeString}
511
512	paramMap, ok := param.(map[string]any)
513	if !ok {
514		return schema
515	}
516
517	if desc, ok := paramMap["description"].(string); ok {
518		schema.Description = desc
519	}
520
521	typeVal, hasType := paramMap["type"]
522	if !hasType {
523		return schema
524	}
525
526	typeStr, ok := typeVal.(string)
527	if !ok {
528		return schema
529	}
530
531	schema.Type = mapJSONTypeToGenAI(typeStr)
532
533	switch typeStr {
534	case "array":
535		schema.Items = processArrayItems(paramMap)
536	case "object":
537		if props, ok := paramMap["properties"].(map[string]any); ok {
538			schema.Properties = convertSchemaProperties(props)
539		}
540	}
541
542	return schema
543}
544
545func processArrayItems(paramMap map[string]any) *genai.Schema {
546	items, ok := paramMap["items"].(map[string]any)
547	if !ok {
548		return nil
549	}
550
551	return convertToSchema(items)
552}
553
554func mapJSONTypeToGenAI(jsonType string) genai.Type {
555	switch jsonType {
556	case "string":
557		return genai.TypeString
558	case "number":
559		return genai.TypeNumber
560	case "integer":
561		return genai.TypeInteger
562	case "boolean":
563		return genai.TypeBoolean
564	case "array":
565		return genai.TypeArray
566	case "object":
567		return genai.TypeObject
568	default:
569		return genai.TypeString // Default to string for unknown types
570	}
571}
572
// contains reports whether s contains at least one of substrs, ignoring
// case. It returns false when no substrings are given; an empty substring
// always matches.
func contains(s string, substrs ...string) bool {
	// Lower-case s once instead of once per candidate.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}