gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/llm/tools"
 14	"github.com/charmbracelet/crush/internal/message"
 15	"github.com/google/uuid"
 16	"google.golang.org/genai"
 17)
 18
 19type geminiProvider struct {
 20	*baseProvider
 21	client *genai.Client
 22}
 23
 24func NewGeminiProvider(base *baseProvider) Provider {
 25	client, err := createGeminiClient(base)
 26	if err != nil {
 27		slog.Error("Failed to create Gemini client", "error", err)
 28		return nil
 29	}
 30
 31	return &geminiProvider{
 32		baseProvider: base,
 33		client:       client,
 34	}
 35}
 36
 37func createGeminiClient(base *baseProvider) (*genai.Client, error) {
 38	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: base.apiKey, Backend: genai.BackendGeminiAPI})
 39	if err != nil {
 40		return nil, err
 41	}
 42	return client, nil
 43}
 44
 45func (g *geminiProvider) convertMessages(messages []message.Message) []*genai.Content {
 46	var history []*genai.Content
 47	for _, msg := range messages {
 48		switch msg.Role {
 49		case message.User:
 50			var parts []*genai.Part
 51			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 52			for _, binaryContent := range msg.BinaryContent() {
 53				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 54				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 55					MIMEType: imageFormat[1],
 56					Data:     binaryContent.Data,
 57				}})
 58			}
 59			history = append(history, &genai.Content{
 60				Parts: parts,
 61				Role:  "user",
 62			})
 63		case message.Assistant:
 64			var assistantParts []*genai.Part
 65
 66			if msg.Content().String() != "" {
 67				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 68			}
 69
 70			if len(msg.ToolCalls()) > 0 {
 71				for _, call := range msg.ToolCalls() {
 72					args, _ := parseJSONToMap(call.Input)
 73					assistantParts = append(assistantParts, &genai.Part{
 74						FunctionCall: &genai.FunctionCall{
 75							Name: call.Name,
 76							Args: args,
 77						},
 78					})
 79				}
 80			}
 81
 82			if len(assistantParts) > 0 {
 83				history = append(history, &genai.Content{
 84					Role:  "model",
 85					Parts: assistantParts,
 86				})
 87			}
 88
 89		case message.Tool:
 90			for _, result := range msg.ToolResults() {
 91				response := map[string]any{"result": result.Content}
 92				parsed, err := parseJSONToMap(result.Content)
 93				if err == nil {
 94					response = parsed
 95				}
 96
 97				var toolCall message.ToolCall
 98				for _, m := range messages {
 99					if m.Role == message.Assistant {
100						for _, call := range m.ToolCalls() {
101							if call.ID == result.ToolCallID {
102								toolCall = call
103								break
104							}
105						}
106					}
107				}
108
109				history = append(history, &genai.Content{
110					Parts: []*genai.Part{
111						{
112							FunctionResponse: &genai.FunctionResponse{
113								Name:     toolCall.Name,
114								Response: response,
115							},
116						},
117					},
118					Role: "function",
119				})
120			}
121		}
122	}
123
124	return history
125}
126
127func (g *geminiProvider) convertTools(tools []tools.BaseTool) []*genai.Tool {
128	geminiTool := &genai.Tool{}
129	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
130
131	for _, tool := range tools {
132		info := tool.Info()
133		declaration := &genai.FunctionDeclaration{
134			Name:        info.Name,
135			Description: info.Description,
136			Parameters: &genai.Schema{
137				Type:       genai.TypeObject,
138				Properties: convertSchemaProperties(info.Parameters),
139				Required:   info.Required,
140			},
141		}
142
143		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
144	}
145
146	return []*genai.Tool{geminiTool}
147}
148
149func (g *geminiProvider) finishReason(reason genai.FinishReason) message.FinishReason {
150	switch reason {
151	case genai.FinishReasonStop:
152		return message.FinishReasonEndTurn
153	case genai.FinishReasonMaxTokens:
154		return message.FinishReasonMaxTokens
155	default:
156		return message.FinishReasonUnknown
157	}
158}
159
160func (g *geminiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
161	messages = g.cleanMessages(messages)
162	return g.send(ctx, model, messages, tools)
163}
164
165func (g *geminiProvider) send(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
166	// Convert messages
167	geminiMessages := g.convertMessages(messages)
168	if g.debug {
169		jsonData, _ := json.Marshal(geminiMessages)
170		slog.Debug("Prepared messages", "messages", string(jsonData))
171	}
172
173	model := g.Model(modelID)
174	maxTokens := model.DefaultMaxTokens
175	if g.maxTokens > 0 {
176		maxTokens = g.maxTokens
177	}
178	systemMessage := g.systemMessage
179	if g.systemPromptPrefix != "" {
180		systemMessage = g.systemPromptPrefix + "\n" + systemMessage
181	}
182	history := geminiMessages[:len(geminiMessages)-1] // All but last message
183	lastMsg := geminiMessages[len(geminiMessages)-1]
184	config := &genai.GenerateContentConfig{
185		MaxOutputTokens: int32(maxTokens),
186		SystemInstruction: &genai.Content{
187			Parts: []*genai.Part{{Text: systemMessage}},
188		},
189	}
190	config.Tools = g.convertTools(tools)
191	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
192
193	attempts := 0
194	for {
195		attempts++
196		var toolCalls []message.ToolCall
197
198		var lastMsgParts []genai.Part
199		for _, part := range lastMsg.Parts {
200			lastMsgParts = append(lastMsgParts, *part)
201		}
202		resp, err := chat.SendMessage(ctx, lastMsgParts...)
203		// If there is an error we are going to see if we can retry the call
204		if err != nil {
205			retry, after, retryErr := g.shouldRetry(attempts, err)
206			if retryErr != nil {
207				return nil, retryErr
208			}
209			if retry {
210				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
211				select {
212				case <-ctx.Done():
213					return nil, ctx.Err()
214				case <-time.After(time.Duration(after) * time.Millisecond):
215					continue
216				}
217			}
218			return nil, retryErr
219		}
220
221		content := ""
222
223		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
224			for _, part := range resp.Candidates[0].Content.Parts {
225				switch {
226				case part.Text != "":
227					content = string(part.Text)
228				case part.FunctionCall != nil:
229					id := "call_" + uuid.New().String()
230					args, _ := json.Marshal(part.FunctionCall.Args)
231					toolCalls = append(toolCalls, message.ToolCall{
232						ID:       id,
233						Name:     part.FunctionCall.Name,
234						Input:    string(args),
235						Type:     "function",
236						Finished: true,
237					})
238				}
239			}
240		}
241		finishReason := message.FinishReasonEndTurn
242		if len(resp.Candidates) > 0 {
243			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
244		}
245		if len(toolCalls) > 0 {
246			finishReason = message.FinishReasonToolUse
247		}
248
249		return &ProviderResponse{
250			Content:      content,
251			ToolCalls:    toolCalls,
252			Usage:        g.usage(resp),
253			FinishReason: finishReason,
254		}, nil
255	}
256}
257
258func (g *geminiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
259	messages = g.cleanMessages(messages)
260	return g.stream(ctx, model, messages, tools)
261}
262
263func (g *geminiProvider) stream(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
264	// Convert messages
265	geminiMessages := g.convertMessages(messages)
266
267	model := g.Model(modelID)
268	if g.debug {
269		jsonData, _ := json.Marshal(geminiMessages)
270		slog.Debug("Prepared messages", "messages", string(jsonData))
271	}
272
273	maxTokens := model.DefaultMaxTokens
274	if g.maxTokens > 0 {
275		maxTokens = g.maxTokens
276	}
277
278	systemMessage := g.systemMessage
279	if g.systemPromptPrefix != "" {
280		systemMessage = g.systemPromptPrefix + "\n" + systemMessage
281	}
282
283	history := geminiMessages[:len(geminiMessages)-1] // All but last message
284	lastMsg := geminiMessages[len(geminiMessages)-1]
285	config := &genai.GenerateContentConfig{
286		MaxOutputTokens: int32(maxTokens),
287		SystemInstruction: &genai.Content{
288			Parts: []*genai.Part{{Text: systemMessage}},
289		},
290	}
291	config.Tools = g.convertTools(tools)
292	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
293
294	attempts := 0
295	eventChan := make(chan ProviderEvent)
296
297	go func() {
298		defer close(eventChan)
299
300		for {
301			attempts++
302
303			currentContent := ""
304			toolCalls := []message.ToolCall{}
305			var finalResp *genai.GenerateContentResponse
306
307			eventChan <- ProviderEvent{Type: EventContentStart}
308
309			var lastMsgParts []genai.Part
310
311			for _, part := range lastMsg.Parts {
312				lastMsgParts = append(lastMsgParts, *part)
313			}
314			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
315				if err != nil {
316					retry, after, retryErr := g.shouldRetry(attempts, err)
317					if retryErr != nil {
318						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
319						return
320					}
321					if retry {
322						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
323						select {
324						case <-ctx.Done():
325							if ctx.Err() != nil {
326								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
327							}
328
329							return
330						case <-time.After(time.Duration(after) * time.Millisecond):
331							break
332						}
333					} else {
334						eventChan <- ProviderEvent{Type: EventError, Error: err}
335						return
336					}
337				}
338
339				finalResp = resp
340
341				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
342					for _, part := range resp.Candidates[0].Content.Parts {
343						switch {
344						case part.Text != "":
345							delta := string(part.Text)
346							if delta != "" {
347								eventChan <- ProviderEvent{
348									Type:    EventContentDelta,
349									Content: delta,
350								}
351								currentContent += delta
352							}
353						case part.FunctionCall != nil:
354							id := "call_" + uuid.New().String()
355							args, _ := json.Marshal(part.FunctionCall.Args)
356							newCall := message.ToolCall{
357								ID:       id,
358								Name:     part.FunctionCall.Name,
359								Input:    string(args),
360								Type:     "function",
361								Finished: true,
362							}
363
364							isNew := true
365							for _, existing := range toolCalls {
366								if existing.Name == newCall.Name && existing.Input == newCall.Input {
367									isNew = false
368									break
369								}
370							}
371
372							if isNew {
373								toolCalls = append(toolCalls, newCall)
374							}
375						}
376					}
377				}
378			}
379
380			eventChan <- ProviderEvent{Type: EventContentStop}
381
382			if finalResp != nil {
383				finishReason := message.FinishReasonEndTurn
384				if len(finalResp.Candidates) > 0 {
385					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
386				}
387				if len(toolCalls) > 0 {
388					finishReason = message.FinishReasonToolUse
389				}
390				eventChan <- ProviderEvent{
391					Type: EventComplete,
392					Response: &ProviderResponse{
393						Content:      currentContent,
394						ToolCalls:    toolCalls,
395						Usage:        g.usage(finalResp),
396						FinishReason: finishReason,
397					},
398				}
399				return
400			}
401		}
402	}()
403
404	return eventChan
405}
406
407func (g *geminiProvider) shouldRetry(attempts int, err error) (bool, int64, error) {
408	// Check if error is a rate limit error
409	if attempts > maxRetries {
410		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
411	}
412
413	// Gemini doesn't have a standard error type we can check against
414	// So we'll check the error message for rate limit indicators
415	if errors.Is(err, io.EOF) {
416		return false, 0, err
417	}
418
419	errMsg := err.Error()
420	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
421
422	// Check for token expiration (401 Unauthorized)
423	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
424		g.apiKey, err = g.resolver.ResolveValue(g.config.APIKey)
425		if err != nil {
426			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
427		}
428		g.client, err = createGeminiClient(g.baseProvider)
429		if err != nil {
430			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
431		}
432		return true, 0, nil
433	}
434
435	// Check for common rate limit error messages
436
437	if !isRateLimit {
438		return false, 0, err
439	}
440
441	// Calculate backoff with jitter
442	backoffMs := 2000 * (1 << (attempts - 1))
443	jitterMs := int(float64(backoffMs) * 0.2)
444	retryMs := backoffMs + jitterMs
445
446	return true, int64(retryMs), nil
447}
448
449func (g *geminiProvider) usage(resp *genai.GenerateContentResponse) TokenUsage {
450	if resp == nil || resp.UsageMetadata == nil {
451		return TokenUsage{}
452	}
453
454	return TokenUsage{
455		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
456		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
457		CacheCreationTokens: 0, // Not directly provided by Gemini
458		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
459	}
460}
461
// Helper functions

// parseJSONToMap decodes jsonStr as a JSON object. It returns the decoding
// error unchanged. A JSON "null" yields a nil map with no error.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
468
469func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
470	properties := make(map[string]*genai.Schema)
471
472	for name, param := range parameters {
473		properties[name] = convertToSchema(param)
474	}
475
476	return properties
477}
478
479func convertToSchema(param any) *genai.Schema {
480	schema := &genai.Schema{Type: genai.TypeString}
481
482	paramMap, ok := param.(map[string]any)
483	if !ok {
484		return schema
485	}
486
487	if desc, ok := paramMap["description"].(string); ok {
488		schema.Description = desc
489	}
490
491	typeVal, hasType := paramMap["type"]
492	if !hasType {
493		return schema
494	}
495
496	typeStr, ok := typeVal.(string)
497	if !ok {
498		return schema
499	}
500
501	schema.Type = mapJSONTypeToGenAI(typeStr)
502
503	switch typeStr {
504	case "array":
505		schema.Items = processArrayItems(paramMap)
506	case "object":
507		if props, ok := paramMap["properties"].(map[string]any); ok {
508			schema.Properties = convertSchemaProperties(props)
509		}
510	}
511
512	return schema
513}
514
515func processArrayItems(paramMap map[string]any) *genai.Schema {
516	items, ok := paramMap["items"].(map[string]any)
517	if !ok {
518		return nil
519	}
520
521	return convertToSchema(items)
522}
523
524func mapJSONTypeToGenAI(jsonType string) genai.Type {
525	switch jsonType {
526	case "string":
527		return genai.TypeString
528	case "number":
529		return genai.TypeNumber
530	case "integer":
531		return genai.TypeInteger
532	case "boolean":
533		return genai.TypeBoolean
534	case "array":
535		return genai.TypeArray
536	case "object":
537		return genai.TypeObject
538	default:
539		return genai.TypeString // Default to string for unknown types
540	}
541}
542
// contains reports whether s contains any of substrs, ignoring case. It
// returns false when no substrings are given. An empty substring always
// matches.
func contains(s string, substrs ...string) bool {
	// Lower-case s once instead of once per candidate substring.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}