gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/llm/tools"
 14	"github.com/charmbracelet/crush/internal/logging"
 15	"github.com/charmbracelet/crush/internal/message"
 16	"github.com/google/uuid"
 17	"google.golang.org/genai"
 18)
 19
 20type geminiClient struct {
 21	providerOptions providerClientOptions
 22	client          *genai.Client
 23}
 24
 25type GeminiClient ProviderClient
 26
 27func newGeminiClient(opts providerClientOptions) GeminiClient {
 28	client, err := createGeminiClient(opts)
 29	if err != nil {
 30		logging.Error("Failed to create Gemini client", "error", err)
 31		return nil
 32	}
 33
 34	return &geminiClient{
 35		providerOptions: opts,
 36		client:          client,
 37	}
 38}
 39
 40func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 41	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
 42	if err != nil {
 43		return nil, err
 44	}
 45	return client, nil
 46}
 47
 48func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 49	var history []*genai.Content
 50	for _, msg := range messages {
 51		switch msg.Role {
 52		case message.User:
 53			var parts []*genai.Part
 54			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 55			for _, binaryContent := range msg.BinaryContent() {
 56				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 57				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 58					MIMEType: imageFormat[1],
 59					Data:     binaryContent.Data,
 60				}})
 61			}
 62			history = append(history, &genai.Content{
 63				Parts: parts,
 64				Role:  "user",
 65			})
 66		case message.Assistant:
 67			var assistantParts []*genai.Part
 68
 69			if msg.Content().String() != "" {
 70				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 71			}
 72
 73			if len(msg.ToolCalls()) > 0 {
 74				for _, call := range msg.ToolCalls() {
 75					args, _ := parseJsonToMap(call.Input)
 76					assistantParts = append(assistantParts, &genai.Part{
 77						FunctionCall: &genai.FunctionCall{
 78							Name: call.Name,
 79							Args: args,
 80						},
 81					})
 82				}
 83			}
 84
 85			if len(assistantParts) > 0 {
 86				history = append(history, &genai.Content{
 87					Role:  "model",
 88					Parts: assistantParts,
 89				})
 90			}
 91
 92		case message.Tool:
 93			for _, result := range msg.ToolResults() {
 94				response := map[string]any{"result": result.Content}
 95				parsed, err := parseJsonToMap(result.Content)
 96				if err == nil {
 97					response = parsed
 98				}
 99
100				var toolCall message.ToolCall
101				for _, m := range messages {
102					if m.Role == message.Assistant {
103						for _, call := range m.ToolCalls() {
104							if call.ID == result.ToolCallID {
105								toolCall = call
106								break
107							}
108						}
109					}
110				}
111
112				history = append(history, &genai.Content{
113					Parts: []*genai.Part{
114						{
115							FunctionResponse: &genai.FunctionResponse{
116								Name:     toolCall.Name,
117								Response: response,
118							},
119						},
120					},
121					Role: "function",
122				})
123			}
124		}
125	}
126
127	return history
128}
129
130func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
131	geminiTool := &genai.Tool{}
132	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
133
134	for _, tool := range tools {
135		info := tool.Info()
136		declaration := &genai.FunctionDeclaration{
137			Name:        info.Name,
138			Description: info.Description,
139			Parameters: &genai.Schema{
140				Type:       genai.TypeObject,
141				Properties: convertSchemaProperties(info.Parameters),
142				Required:   info.Required,
143			},
144		}
145
146		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
147	}
148
149	return []*genai.Tool{geminiTool}
150}
151
152func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
153	switch reason {
154	case genai.FinishReasonStop:
155		return message.FinishReasonEndTurn
156	case genai.FinishReasonMaxTokens:
157		return message.FinishReasonMaxTokens
158	default:
159		return message.FinishReasonUnknown
160	}
161}
162
163func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
164	// Convert messages
165	geminiMessages := g.convertMessages(messages)
166	model := g.providerOptions.model(g.providerOptions.modelType)
167	cfg := config.Get()
168	if cfg.Options.Debug {
169		jsonData, _ := json.Marshal(geminiMessages)
170		logging.Debug("Prepared messages", "messages", string(jsonData))
171	}
172
173	modelConfig := cfg.Models.Large
174	if g.providerOptions.modelType == config.SmallModel {
175		modelConfig = cfg.Models.Small
176	}
177
178	maxTokens := model.DefaultMaxTokens
179	if modelConfig.MaxTokens > 0 {
180		maxTokens = modelConfig.MaxTokens
181	}
182	history := geminiMessages[:len(geminiMessages)-1] // All but last message
183	lastMsg := geminiMessages[len(geminiMessages)-1]
184	config := &genai.GenerateContentConfig{
185		MaxOutputTokens: int32(maxTokens),
186		SystemInstruction: &genai.Content{
187			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
188		},
189	}
190	if len(tools) > 0 {
191		config.Tools = g.convertTools(tools)
192	}
193	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
194
195	attempts := 0
196	for {
197		attempts++
198		var toolCalls []message.ToolCall
199
200		var lastMsgParts []genai.Part
201		for _, part := range lastMsg.Parts {
202			lastMsgParts = append(lastMsgParts, *part)
203		}
204		resp, err := chat.SendMessage(ctx, lastMsgParts...)
205		// If there is an error we are going to see if we can retry the call
206		if err != nil {
207			retry, after, retryErr := g.shouldRetry(attempts, err)
208			if retryErr != nil {
209				return nil, retryErr
210			}
211			if retry {
212				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
213				select {
214				case <-ctx.Done():
215					return nil, ctx.Err()
216				case <-time.After(time.Duration(after) * time.Millisecond):
217					continue
218				}
219			}
220			return nil, retryErr
221		}
222
223		content := ""
224
225		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
226			for _, part := range resp.Candidates[0].Content.Parts {
227				switch {
228				case part.Text != "":
229					content = string(part.Text)
230				case part.FunctionCall != nil:
231					id := "call_" + uuid.New().String()
232					args, _ := json.Marshal(part.FunctionCall.Args)
233					toolCalls = append(toolCalls, message.ToolCall{
234						ID:       id,
235						Name:     part.FunctionCall.Name,
236						Input:    string(args),
237						Type:     "function",
238						Finished: true,
239					})
240				}
241			}
242		}
243		finishReason := message.FinishReasonEndTurn
244		if len(resp.Candidates) > 0 {
245			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
246		}
247		if len(toolCalls) > 0 {
248			finishReason = message.FinishReasonToolUse
249		}
250
251		return &ProviderResponse{
252			Content:      content,
253			ToolCalls:    toolCalls,
254			Usage:        g.usage(resp),
255			FinishReason: finishReason,
256		}, nil
257	}
258}
259
260func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
261	// Convert messages
262	geminiMessages := g.convertMessages(messages)
263
264	model := g.providerOptions.model(g.providerOptions.modelType)
265	cfg := config.Get()
266	if cfg.Options.Debug {
267		jsonData, _ := json.Marshal(geminiMessages)
268		logging.Debug("Prepared messages", "messages", string(jsonData))
269	}
270
271	modelConfig := cfg.Models.Large
272	if g.providerOptions.modelType == config.SmallModel {
273		modelConfig = cfg.Models.Small
274	}
275	maxTokens := model.DefaultMaxTokens
276	if modelConfig.MaxTokens > 0 {
277		maxTokens = modelConfig.MaxTokens
278	}
279
280	// Override max tokens if set in provider options
281	if g.providerOptions.maxTokens > 0 {
282		maxTokens = g.providerOptions.maxTokens
283	}
284	history := geminiMessages[:len(geminiMessages)-1] // All but last message
285	lastMsg := geminiMessages[len(geminiMessages)-1]
286	config := &genai.GenerateContentConfig{
287		MaxOutputTokens: int32(maxTokens),
288		SystemInstruction: &genai.Content{
289			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
290		},
291	}
292	if len(tools) > 0 {
293		config.Tools = g.convertTools(tools)
294	}
295	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
296
297	attempts := 0
298	eventChan := make(chan ProviderEvent)
299
300	go func() {
301		defer close(eventChan)
302
303		for {
304			attempts++
305
306			currentContent := ""
307			toolCalls := []message.ToolCall{}
308			var finalResp *genai.GenerateContentResponse
309
310			eventChan <- ProviderEvent{Type: EventContentStart}
311
312			var lastMsgParts []genai.Part
313
314			for _, part := range lastMsg.Parts {
315				lastMsgParts = append(lastMsgParts, *part)
316			}
317			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
318				if err != nil {
319					retry, after, retryErr := g.shouldRetry(attempts, err)
320					if retryErr != nil {
321						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
322						return
323					}
324					if retry {
325						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
326						select {
327						case <-ctx.Done():
328							if ctx.Err() != nil {
329								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
330							}
331
332							return
333						case <-time.After(time.Duration(after) * time.Millisecond):
334							break
335						}
336					} else {
337						eventChan <- ProviderEvent{Type: EventError, Error: err}
338						return
339					}
340				}
341
342				finalResp = resp
343
344				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
345					for _, part := range resp.Candidates[0].Content.Parts {
346						switch {
347						case part.Text != "":
348							delta := string(part.Text)
349							if delta != "" {
350								eventChan <- ProviderEvent{
351									Type:    EventContentDelta,
352									Content: delta,
353								}
354								currentContent += delta
355							}
356						case part.FunctionCall != nil:
357							id := "call_" + uuid.New().String()
358							args, _ := json.Marshal(part.FunctionCall.Args)
359							newCall := message.ToolCall{
360								ID:       id,
361								Name:     part.FunctionCall.Name,
362								Input:    string(args),
363								Type:     "function",
364								Finished: true,
365							}
366
367							isNew := true
368							for _, existing := range toolCalls {
369								if existing.Name == newCall.Name && existing.Input == newCall.Input {
370									isNew = false
371									break
372								}
373							}
374
375							if isNew {
376								toolCalls = append(toolCalls, newCall)
377							}
378						}
379					}
380				}
381			}
382
383			eventChan <- ProviderEvent{Type: EventContentStop}
384
385			if finalResp != nil {
386				finishReason := message.FinishReasonEndTurn
387				if len(finalResp.Candidates) > 0 {
388					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
389				}
390				if len(toolCalls) > 0 {
391					finishReason = message.FinishReasonToolUse
392				}
393				eventChan <- ProviderEvent{
394					Type: EventComplete,
395					Response: &ProviderResponse{
396						Content:      currentContent,
397						ToolCalls:    toolCalls,
398						Usage:        g.usage(finalResp),
399						FinishReason: finishReason,
400					},
401				}
402				return
403			}
404		}
405	}()
406
407	return eventChan
408}
409
410func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
411	// Check if error is a rate limit error
412	if attempts > maxRetries {
413		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
414	}
415
416	// Gemini doesn't have a standard error type we can check against
417	// So we'll check the error message for rate limit indicators
418	if errors.Is(err, io.EOF) {
419		return false, 0, err
420	}
421
422	errMsg := err.Error()
423	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
424
425	// Check for token expiration (401 Unauthorized)
426	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
427		g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey)
428		if err != nil {
429			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
430		}
431		g.client, err = createGeminiClient(g.providerOptions)
432		if err != nil {
433			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
434		}
435		return true, 0, nil
436	}
437
438	// Check for common rate limit error messages
439
440	if !isRateLimit {
441		return false, 0, err
442	}
443
444	// Calculate backoff with jitter
445	backoffMs := 2000 * (1 << (attempts - 1))
446	jitterMs := int(float64(backoffMs) * 0.2)
447	retryMs := backoffMs + jitterMs
448
449	return true, int64(retryMs), nil
450}
451
452func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
453	if resp == nil || resp.UsageMetadata == nil {
454		return TokenUsage{}
455	}
456
457	return TokenUsage{
458		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
459		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
460		CacheCreationTokens: 0, // Not directly provided by Gemini
461		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
462	}
463}
464
465func (g *geminiClient) Model() config.Model {
466	return g.providerOptions.model(g.providerOptions.modelType)
467}
468
469// Helper functions
// parseJsonToMap decodes jsonStr as a JSON object into a generic map. The
// decode error, if any, is returned together with whatever map was produced.
func parseJsonToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
475
476func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
477	properties := make(map[string]*genai.Schema)
478
479	for name, param := range parameters {
480		properties[name] = convertToSchema(param)
481	}
482
483	return properties
484}
485
486func convertToSchema(param any) *genai.Schema {
487	schema := &genai.Schema{Type: genai.TypeString}
488
489	paramMap, ok := param.(map[string]any)
490	if !ok {
491		return schema
492	}
493
494	if desc, ok := paramMap["description"].(string); ok {
495		schema.Description = desc
496	}
497
498	typeVal, hasType := paramMap["type"]
499	if !hasType {
500		return schema
501	}
502
503	typeStr, ok := typeVal.(string)
504	if !ok {
505		return schema
506	}
507
508	schema.Type = mapJSONTypeToGenAI(typeStr)
509
510	switch typeStr {
511	case "array":
512		schema.Items = processArrayItems(paramMap)
513	case "object":
514		if props, ok := paramMap["properties"].(map[string]any); ok {
515			schema.Properties = convertSchemaProperties(props)
516		}
517	}
518
519	return schema
520}
521
522func processArrayItems(paramMap map[string]any) *genai.Schema {
523	items, ok := paramMap["items"].(map[string]any)
524	if !ok {
525		return nil
526	}
527
528	return convertToSchema(items)
529}
530
531func mapJSONTypeToGenAI(jsonType string) genai.Type {
532	switch jsonType {
533	case "string":
534		return genai.TypeString
535	case "number":
536		return genai.TypeNumber
537	case "integer":
538		return genai.TypeInteger
539	case "boolean":
540		return genai.TypeBoolean
541	case "array":
542		return genai.TypeArray
543	case "object":
544		return genai.TypeObject
545	default:
546		return genai.TypeString // Default to string for unknown types
547	}
548}
549
// contains reports whether s contains at least one of substrs, ignoring case.
// It returns false when no substrings are given; an empty substring always
// matches.
func contains(s string, substrs ...string) bool {
	// Lower-case the haystack once instead of once per candidate.
	haystack := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(haystack, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}