gemini.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

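// geminiClient is the ProviderClient implementation backed by the official
// google.golang.org/genai SDK.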
type geminiClient struct {
	providerOptions providerClientOptions
	client          *genai.Client
}

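// GeminiClient is the provider client type exposed for the Gemini backend.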
type GeminiClient ProviderClient

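// newGeminiClient builds a Gemini-backed client from the provider options,
// returning nil if the underlying SDK client cannot be created.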
func newGeminiClient(opts providerClientOptions) GeminiClient {
	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		logging.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		client:          client,
	}
}

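// convertMessages maps internal chat messages onto genai.Content history:
// user text and attachments become "user" parts, assistant text and tool
// calls become "model" parts, and tool results become "function" responses.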
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				// genai.Blob expects the full IANA MIME type (e.g. "image/png"),
				// not just the subtype, so pass the MIME type through unchanged.
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJsonToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  "model",
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

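// convertTools flattens the registered tools into a single genai.Tool whose
// FunctionDeclarations describe each tool's name, description, and JSON
// schema parameters.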
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

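// finishReason translates a Gemini finish reason into the provider-neutral
// message.FinishReason; anything unrecognized maps to FinishReasonUnknown.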
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

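// send performs a blocking generate-content call: it replays the converted
// history into a chat session, sends the final message, retries rate-limited
// requests, and returns the aggregated response.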
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models.Large
	if g.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options (kept consistent with stream).
	if g.providerOptions.maxTokens > 0 {
		maxTokens = g.providerOptions.maxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		genConfig.Tools = g.convertTools(tools)
	}
	chat, err := g.client.Chats.Create(ctx, model.ID, genConfig, history)
	if err != nil {
		return nil, fmt.Errorf("failed to create Gemini chat session: %w", err)
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If the call failed, check whether it can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: surface the original error rather than a nil retryErr.
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content = string(part.Text)
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

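// stream issues a streaming generate-content call and emits ProviderEvents
// (content start, deltas, tool calls, completion, or errors) on the returned
// channel, retrying rate-limited requests with backoff.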
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models.Large
	if g.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if g.providerOptions.maxTokens > 0 {
		maxTokens = g.providerOptions.maxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		genConfig.Tools = g.convertTools(tools)
	}
	chat, err := g.client.Chats.Create(ctx, model.ID, genConfig, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		// Report chat-session creation failures instead of using a nil chat.
		if err != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			return
		}

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if !retry {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
					logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
					select {
					case <-ctx.Done():
						if ctx.Err() != nil {
							eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
						}

						return
					case <-time.After(time.Duration(after) * time.Millisecond):
					}
					// The stream yields no usable response after an error, so
					// leave the inner loop and let the outer loop retry the
					// request instead of reading a nil resp.
					finalResp = nil
					break
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := string(part.Text)
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

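// shouldRetry reports whether a failed request looks rate limited and, if so,
// how many milliseconds to wait before the next attempt; non-retryable errors
// are returned to the caller.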
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	// Stop once the retry budget is exhausted.
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Gemini doesn't expose a typed rate-limit error, so check the error
	// message for common rate limit indicators instead.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff with a flat 20% buffer.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

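// usage extracts token accounting from the response's UsageMetadata; Gemini
// does not report cache-creation tokens, so that field is always zero.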
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

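// Model returns the resolved model configuration for this client.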
func (g *geminiClient) Model() config.Model {
	return g.providerOptions.model(g.providerOptions.modelType)
}

// Helper functions

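// parseJsonToMap unmarshals a JSON string into a generic map.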
func parseJsonToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

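// convertSchemaProperties converts a tool's JSON schema properties into the
// equivalent genai.Schema map.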
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

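// convertToSchema converts a single JSON schema fragment into a genai.Schema,
// recursing into array items and object properties; anything unrecognized
// falls back to a string schema.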
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

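// processArrayItems converts the "items" schema of an array parameter, or
// returns nil when no items schema is present.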
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

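// mapJSONTypeToGenAI maps JSON schema type names onto genai.Type constants.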
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

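// contains reports whether s contains any of the given substrings,
// case-insensitively.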
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}