gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/fur/provider"
 14	"github.com/charmbracelet/crush/internal/llm/tools"
 15	"github.com/charmbracelet/crush/internal/logging"
 16	"github.com/charmbracelet/crush/internal/message"
 17	"github.com/google/uuid"
 18	"google.golang.org/genai"
 19)
 20
 21type geminiClient struct {
 22	providerOptions providerClientOptions
 23	client          *genai.Client
 24}
 25
 26type GeminiClient ProviderClient
 27
 28func newGeminiClient(opts providerClientOptions) GeminiClient {
 29	client, err := createGeminiClient(opts)
 30	if err != nil {
 31		logging.Error("Failed to create Gemini client", "error", err)
 32		return nil
 33	}
 34
 35	return &geminiClient{
 36		providerOptions: opts,
 37		client:          client,
 38	}
 39}
 40
 41func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 42	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
 43	if err != nil {
 44		return nil, err
 45	}
 46	return client, nil
 47}
 48
 49func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 50	var history []*genai.Content
 51	for _, msg := range messages {
 52		switch msg.Role {
 53		case message.User:
 54			var parts []*genai.Part
 55			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 56			for _, binaryContent := range msg.BinaryContent() {
 57				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 58				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 59					MIMEType: imageFormat[1],
 60					Data:     binaryContent.Data,
 61				}})
 62			}
 63			history = append(history, &genai.Content{
 64				Parts: parts,
 65				Role:  "user",
 66			})
 67		case message.Assistant:
 68			var assistantParts []*genai.Part
 69
 70			if msg.Content().String() != "" {
 71				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 72			}
 73
 74			if len(msg.ToolCalls()) > 0 {
 75				for _, call := range msg.ToolCalls() {
 76					args, _ := parseJsonToMap(call.Input)
 77					assistantParts = append(assistantParts, &genai.Part{
 78						FunctionCall: &genai.FunctionCall{
 79							Name: call.Name,
 80							Args: args,
 81						},
 82					})
 83				}
 84			}
 85
 86			if len(assistantParts) > 0 {
 87				history = append(history, &genai.Content{
 88					Role:  "model",
 89					Parts: assistantParts,
 90				})
 91			}
 92
 93		case message.Tool:
 94			for _, result := range msg.ToolResults() {
 95				response := map[string]any{"result": result.Content}
 96				parsed, err := parseJsonToMap(result.Content)
 97				if err == nil {
 98					response = parsed
 99				}
100
101				var toolCall message.ToolCall
102				for _, m := range messages {
103					if m.Role == message.Assistant {
104						for _, call := range m.ToolCalls() {
105							if call.ID == result.ToolCallID {
106								toolCall = call
107								break
108							}
109						}
110					}
111				}
112
113				history = append(history, &genai.Content{
114					Parts: []*genai.Part{
115						{
116							FunctionResponse: &genai.FunctionResponse{
117								Name:     toolCall.Name,
118								Response: response,
119							},
120						},
121					},
122					Role: "function",
123				})
124			}
125		}
126	}
127
128	return history
129}
130
131func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
132	geminiTool := &genai.Tool{}
133	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
134
135	for _, tool := range tools {
136		info := tool.Info()
137		declaration := &genai.FunctionDeclaration{
138			Name:        info.Name,
139			Description: info.Description,
140			Parameters: &genai.Schema{
141				Type:       genai.TypeObject,
142				Properties: convertSchemaProperties(info.Parameters),
143				Required:   info.Required,
144			},
145		}
146
147		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
148	}
149
150	return []*genai.Tool{geminiTool}
151}
152
153func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
154	switch reason {
155	case genai.FinishReasonStop:
156		return message.FinishReasonEndTurn
157	case genai.FinishReasonMaxTokens:
158		return message.FinishReasonMaxTokens
159	default:
160		return message.FinishReasonUnknown
161	}
162}
163
164func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
165	// Convert messages
166	geminiMessages := g.convertMessages(messages)
167	model := g.providerOptions.model(g.providerOptions.modelType)
168	cfg := config.Get()
169	if cfg.Options.Debug {
170		jsonData, _ := json.Marshal(geminiMessages)
171		logging.Debug("Prepared messages", "messages", string(jsonData))
172	}
173
174	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
175	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
176		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
177	}
178
179	maxTokens := model.DefaultMaxTokens
180	if modelConfig.MaxTokens > 0 {
181		maxTokens = modelConfig.MaxTokens
182	}
183	history := geminiMessages[:len(geminiMessages)-1] // All but last message
184	lastMsg := geminiMessages[len(geminiMessages)-1]
185	config := &genai.GenerateContentConfig{
186		MaxOutputTokens: int32(maxTokens),
187		SystemInstruction: &genai.Content{
188			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
189		},
190	}
191	if len(tools) > 0 {
192		config.Tools = g.convertTools(tools)
193	}
194	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
195
196	attempts := 0
197	for {
198		attempts++
199		var toolCalls []message.ToolCall
200
201		var lastMsgParts []genai.Part
202		for _, part := range lastMsg.Parts {
203			lastMsgParts = append(lastMsgParts, *part)
204		}
205		resp, err := chat.SendMessage(ctx, lastMsgParts...)
206		// If there is an error we are going to see if we can retry the call
207		if err != nil {
208			retry, after, retryErr := g.shouldRetry(attempts, err)
209			if retryErr != nil {
210				return nil, retryErr
211			}
212			if retry {
213				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
214				select {
215				case <-ctx.Done():
216					return nil, ctx.Err()
217				case <-time.After(time.Duration(after) * time.Millisecond):
218					continue
219				}
220			}
221			return nil, retryErr
222		}
223
224		content := ""
225
226		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
227			for _, part := range resp.Candidates[0].Content.Parts {
228				switch {
229				case part.Text != "":
230					content = string(part.Text)
231				case part.FunctionCall != nil:
232					id := "call_" + uuid.New().String()
233					args, _ := json.Marshal(part.FunctionCall.Args)
234					toolCalls = append(toolCalls, message.ToolCall{
235						ID:       id,
236						Name:     part.FunctionCall.Name,
237						Input:    string(args),
238						Type:     "function",
239						Finished: true,
240					})
241				}
242			}
243		}
244		finishReason := message.FinishReasonEndTurn
245		if len(resp.Candidates) > 0 {
246			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
247		}
248		if len(toolCalls) > 0 {
249			finishReason = message.FinishReasonToolUse
250		}
251
252		return &ProviderResponse{
253			Content:      content,
254			ToolCalls:    toolCalls,
255			Usage:        g.usage(resp),
256			FinishReason: finishReason,
257		}, nil
258	}
259}
260
261func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
262	// Convert messages
263	geminiMessages := g.convertMessages(messages)
264
265	model := g.providerOptions.model(g.providerOptions.modelType)
266	cfg := config.Get()
267	if cfg.Options.Debug {
268		jsonData, _ := json.Marshal(geminiMessages)
269		logging.Debug("Prepared messages", "messages", string(jsonData))
270	}
271
272	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
273	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
274		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
275	}
276	maxTokens := model.DefaultMaxTokens
277	if modelConfig.MaxTokens > 0 {
278		maxTokens = modelConfig.MaxTokens
279	}
280
281	// Override max tokens if set in provider options
282	if g.providerOptions.maxTokens > 0 {
283		maxTokens = g.providerOptions.maxTokens
284	}
285	history := geminiMessages[:len(geminiMessages)-1] // All but last message
286	lastMsg := geminiMessages[len(geminiMessages)-1]
287	config := &genai.GenerateContentConfig{
288		MaxOutputTokens: int32(maxTokens),
289		SystemInstruction: &genai.Content{
290			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
291		},
292	}
293	if len(tools) > 0 {
294		config.Tools = g.convertTools(tools)
295	}
296	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
297
298	attempts := 0
299	eventChan := make(chan ProviderEvent)
300
301	go func() {
302		defer close(eventChan)
303
304		for {
305			attempts++
306
307			currentContent := ""
308			toolCalls := []message.ToolCall{}
309			var finalResp *genai.GenerateContentResponse
310
311			eventChan <- ProviderEvent{Type: EventContentStart}
312
313			var lastMsgParts []genai.Part
314
315			for _, part := range lastMsg.Parts {
316				lastMsgParts = append(lastMsgParts, *part)
317			}
318			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
319				if err != nil {
320					retry, after, retryErr := g.shouldRetry(attempts, err)
321					if retryErr != nil {
322						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
323						return
324					}
325					if retry {
326						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
327						select {
328						case <-ctx.Done():
329							if ctx.Err() != nil {
330								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
331							}
332
333							return
334						case <-time.After(time.Duration(after) * time.Millisecond):
335							break
336						}
337					} else {
338						eventChan <- ProviderEvent{Type: EventError, Error: err}
339						return
340					}
341				}
342
343				finalResp = resp
344
345				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
346					for _, part := range resp.Candidates[0].Content.Parts {
347						switch {
348						case part.Text != "":
349							delta := string(part.Text)
350							if delta != "" {
351								eventChan <- ProviderEvent{
352									Type:    EventContentDelta,
353									Content: delta,
354								}
355								currentContent += delta
356							}
357						case part.FunctionCall != nil:
358							id := "call_" + uuid.New().String()
359							args, _ := json.Marshal(part.FunctionCall.Args)
360							newCall := message.ToolCall{
361								ID:       id,
362								Name:     part.FunctionCall.Name,
363								Input:    string(args),
364								Type:     "function",
365								Finished: true,
366							}
367
368							isNew := true
369							for _, existing := range toolCalls {
370								if existing.Name == newCall.Name && existing.Input == newCall.Input {
371									isNew = false
372									break
373								}
374							}
375
376							if isNew {
377								toolCalls = append(toolCalls, newCall)
378							}
379						}
380					}
381				}
382			}
383
384			eventChan <- ProviderEvent{Type: EventContentStop}
385
386			if finalResp != nil {
387				finishReason := message.FinishReasonEndTurn
388				if len(finalResp.Candidates) > 0 {
389					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
390				}
391				if len(toolCalls) > 0 {
392					finishReason = message.FinishReasonToolUse
393				}
394				eventChan <- ProviderEvent{
395					Type: EventComplete,
396					Response: &ProviderResponse{
397						Content:      currentContent,
398						ToolCalls:    toolCalls,
399						Usage:        g.usage(finalResp),
400						FinishReason: finishReason,
401					},
402				}
403				return
404			}
405		}
406	}()
407
408	return eventChan
409}
410
411func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
412	// Check if error is a rate limit error
413	if attempts > maxRetries {
414		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
415	}
416
417	// Gemini doesn't have a standard error type we can check against
418	// So we'll check the error message for rate limit indicators
419	if errors.Is(err, io.EOF) {
420		return false, 0, err
421	}
422
423	errMsg := err.Error()
424	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
425
426	// Check for token expiration (401 Unauthorized)
427	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
428		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
429		if err != nil {
430			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
431		}
432		g.client, err = createGeminiClient(g.providerOptions)
433		if err != nil {
434			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
435		}
436		return true, 0, nil
437	}
438
439	// Check for common rate limit error messages
440
441	if !isRateLimit {
442		return false, 0, err
443	}
444
445	// Calculate backoff with jitter
446	backoffMs := 2000 * (1 << (attempts - 1))
447	jitterMs := int(float64(backoffMs) * 0.2)
448	retryMs := backoffMs + jitterMs
449
450	return true, int64(retryMs), nil
451}
452
453func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
454	if resp == nil || resp.UsageMetadata == nil {
455		return TokenUsage{}
456	}
457
458	return TokenUsage{
459		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
460		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
461		CacheCreationTokens: 0, // Not directly provided by Gemini
462		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
463	}
464}
465
466func (g *geminiClient) Model() provider.Model {
467	return g.providerOptions.model(g.providerOptions.modelType)
468}
469
470// Helper functions
// parseJsonToMap decodes jsonStr as a JSON object. It returns the decoding
// error for invalid JSON or non-object values (with a nil map in those
// cases), and a nil map with no error for the literal "null".
func parseJsonToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
476
477func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
478	properties := make(map[string]*genai.Schema)
479
480	for name, param := range parameters {
481		properties[name] = convertToSchema(param)
482	}
483
484	return properties
485}
486
487func convertToSchema(param any) *genai.Schema {
488	schema := &genai.Schema{Type: genai.TypeString}
489
490	paramMap, ok := param.(map[string]any)
491	if !ok {
492		return schema
493	}
494
495	if desc, ok := paramMap["description"].(string); ok {
496		schema.Description = desc
497	}
498
499	typeVal, hasType := paramMap["type"]
500	if !hasType {
501		return schema
502	}
503
504	typeStr, ok := typeVal.(string)
505	if !ok {
506		return schema
507	}
508
509	schema.Type = mapJSONTypeToGenAI(typeStr)
510
511	switch typeStr {
512	case "array":
513		schema.Items = processArrayItems(paramMap)
514	case "object":
515		if props, ok := paramMap["properties"].(map[string]any); ok {
516			schema.Properties = convertSchemaProperties(props)
517		}
518	}
519
520	return schema
521}
522
523func processArrayItems(paramMap map[string]any) *genai.Schema {
524	items, ok := paramMap["items"].(map[string]any)
525	if !ok {
526		return nil
527	}
528
529	return convertToSchema(items)
530}
531
532func mapJSONTypeToGenAI(jsonType string) genai.Type {
533	switch jsonType {
534	case "string":
535		return genai.TypeString
536	case "number":
537		return genai.TypeNumber
538	case "integer":
539		return genai.TypeInteger
540	case "boolean":
541		return genai.TypeBoolean
542	case "array":
543		return genai.TypeArray
544	case "object":
545		return genai.TypeObject
546	default:
547		return genai.TypeString // Default to string for unknown types
548	}
549}
550
// contains reports whether s contains any of substrs, ignoring case. It
// returns false when no substrings are given; an empty substring always
// matches.
func contains(s string, substrs ...string) bool {
	// s is the same for every candidate, so lower-case it once instead of on
	// every iteration.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}