1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/google/uuid"
 19	"google.golang.org/genai"
 20)
 21
// geminiClient talks to the Gemini API through the google.golang.org/genai
// SDK; newGeminiClient hands it out as a GeminiClient.
type geminiClient struct {
	// providerOptions are the options the client was created from. shouldRetry
	// overwrites apiKey when it re-resolves the configured key.
	providerOptions providerClientOptions
	// client is the genai SDK client; shouldRetry replaces it after re-resolving
	// the API key.
	client          *genai.Client
}

// GeminiClient is the type returned by newGeminiClient. It is a defined type
// with ProviderClient as its underlying type.
type GeminiClient ProviderClient
 28
 29func newGeminiClient(opts providerClientOptions) GeminiClient {
 30	client, err := createGeminiClient(opts)
 31	if err != nil {
 32		slog.Error("Failed to create Gemini client", "error", err)
 33		return nil
 34	}
 35
 36	return &geminiClient{
 37		providerOptions: opts,
 38		client:          client,
 39	}
 40}
 41
 42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 43	cc := &genai.ClientConfig{
 44		APIKey:  opts.apiKey,
 45		Backend: genai.BackendGeminiAPI,
 46	}
 47	if config.Get().Options.Debug {
 48		cc.HTTPClient = log.NewHTTPClient()
 49	}
 50	client, err := genai.NewClient(context.Background(), cc)
 51	if err != nil {
 52		return nil, err
 53	}
 54	return client, nil
 55}
 56
 57func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 58	var history []*genai.Content
 59	for _, msg := range messages {
 60		switch msg.Role {
 61		case message.User:
 62			var parts []*genai.Part
 63			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 64			for _, binaryContent := range msg.BinaryContent() {
 65				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 66				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 67					MIMEType: imageFormat[1],
 68					Data:     binaryContent.Data,
 69				}})
 70			}
 71			history = append(history, &genai.Content{
 72				Parts: parts,
 73				Role:  genai.RoleUser,
 74			})
 75		case message.Assistant:
 76			var assistantParts []*genai.Part
 77
 78			if msg.Content().String() != "" {
 79				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 80			}
 81
 82			if len(msg.ToolCalls()) > 0 {
 83				for _, call := range msg.ToolCalls() {
 84					args, _ := parseJSONToMap(call.Input)
 85					assistantParts = append(assistantParts, &genai.Part{
 86						FunctionCall: &genai.FunctionCall{
 87							Name: call.Name,
 88							Args: args,
 89						},
 90					})
 91				}
 92			}
 93
 94			if len(assistantParts) > 0 {
 95				history = append(history, &genai.Content{
 96					Role:  genai.RoleModel,
 97					Parts: assistantParts,
 98				})
 99			}
100
101		case message.Tool:
102			for _, result := range msg.ToolResults() {
103				response := map[string]any{"result": result.Content}
104				parsed, err := parseJSONToMap(result.Content)
105				if err == nil {
106					response = parsed
107				}
108
109				var toolCall message.ToolCall
110				for _, m := range messages {
111					if m.Role == message.Assistant {
112						for _, call := range m.ToolCalls() {
113							if call.ID == result.ToolCallID {
114								toolCall = call
115								break
116							}
117						}
118					}
119				}
120
121				history = append(history, &genai.Content{
122					Parts: []*genai.Part{
123						{
124							FunctionResponse: &genai.FunctionResponse{
125								Name:     toolCall.Name,
126								Response: response,
127							},
128						},
129					},
130					Role: genai.RoleModel,
131				})
132			}
133		}
134	}
135
136	return history
137}
138
139func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
140	geminiTool := &genai.Tool{}
141	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
142
143	for _, tool := range tools {
144		info := tool.Info()
145		declaration := &genai.FunctionDeclaration{
146			Name:        info.Name,
147			Description: info.Description,
148			Parameters: &genai.Schema{
149				Type:       genai.TypeObject,
150				Properties: convertSchemaProperties(info.Parameters),
151				Required:   info.Required,
152			},
153		}
154
155		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
156	}
157
158	return []*genai.Tool{geminiTool}
159}
160
161func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
162	switch reason {
163	case genai.FinishReasonStop:
164		return message.FinishReasonEndTurn
165	case genai.FinishReasonMaxTokens:
166		return message.FinishReasonMaxTokens
167	default:
168		return message.FinishReasonUnknown
169	}
170}
171
172func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
173	// Convert messages
174	geminiMessages := g.convertMessages(messages)
175	model := g.providerOptions.model(g.providerOptions.modelType)
176	cfg := config.Get()
177
178	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
179	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
180		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
181	}
182
183	maxTokens := model.DefaultMaxTokens
184	if modelConfig.MaxTokens > 0 {
185		maxTokens = modelConfig.MaxTokens
186	}
187	systemMessage := g.providerOptions.systemMessage
188	if g.providerOptions.systemPromptPrefix != "" {
189		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
190	}
191	history := geminiMessages[:len(geminiMessages)-1] // All but last message
192	lastMsg := geminiMessages[len(geminiMessages)-1]
193	config := &genai.GenerateContentConfig{
194		MaxOutputTokens: int32(maxTokens),
195		SystemInstruction: &genai.Content{
196			Parts: []*genai.Part{{Text: systemMessage}},
197		},
198	}
199	config.Tools = g.convertTools(tools)
200	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
201
202	attempts := 0
203	for {
204		attempts++
205		var toolCalls []message.ToolCall
206
207		var lastMsgParts []genai.Part
208		for _, part := range lastMsg.Parts {
209			lastMsgParts = append(lastMsgParts, *part)
210		}
211		resp, err := chat.SendMessage(ctx, lastMsgParts...)
212		// If there is an error we are going to see if we can retry the call
213		if err != nil {
214			retry, after, retryErr := g.shouldRetry(attempts, err)
215			if retryErr != nil {
216				return nil, retryErr
217			}
218			if retry {
219				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
220				select {
221				case <-ctx.Done():
222					return nil, ctx.Err()
223				case <-time.After(time.Duration(after) * time.Millisecond):
224					continue
225				}
226			}
227			return nil, retryErr
228		}
229
230		content := ""
231
232		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
233			for _, part := range resp.Candidates[0].Content.Parts {
234				switch {
235				case part.Text != "":
236					content = string(part.Text)
237				case part.FunctionCall != nil:
238					id := "call_" + uuid.New().String()
239					args, _ := json.Marshal(part.FunctionCall.Args)
240					toolCalls = append(toolCalls, message.ToolCall{
241						ID:       id,
242						Name:     part.FunctionCall.Name,
243						Input:    string(args),
244						Type:     "function",
245						Finished: true,
246					})
247				}
248			}
249		}
250		finishReason := message.FinishReasonEndTurn
251		if len(resp.Candidates) > 0 {
252			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
253		}
254		if len(toolCalls) > 0 {
255			finishReason = message.FinishReasonToolUse
256		}
257
258		return &ProviderResponse{
259			Content:      content,
260			ToolCalls:    toolCalls,
261			Usage:        g.usage(resp),
262			FinishReason: finishReason,
263		}, nil
264	}
265}
266
267func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
268	// Convert messages
269	geminiMessages := g.convertMessages(messages)
270
271	model := g.providerOptions.model(g.providerOptions.modelType)
272	cfg := config.Get()
273
274	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
275	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
276		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
277	}
278	maxTokens := model.DefaultMaxTokens
279	if modelConfig.MaxTokens > 0 {
280		maxTokens = modelConfig.MaxTokens
281	}
282
283	// Override max tokens if set in provider options
284	if g.providerOptions.maxTokens > 0 {
285		maxTokens = g.providerOptions.maxTokens
286	}
287	systemMessage := g.providerOptions.systemMessage
288	if g.providerOptions.systemPromptPrefix != "" {
289		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
290	}
291	history := geminiMessages[:len(geminiMessages)-1] // All but last message
292	lastMsg := geminiMessages[len(geminiMessages)-1]
293	config := &genai.GenerateContentConfig{
294		MaxOutputTokens: int32(maxTokens),
295		SystemInstruction: &genai.Content{
296			Parts: []*genai.Part{{Text: systemMessage}},
297		},
298	}
299	config.Tools = g.convertTools(tools)
300	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
301
302	attempts := 0
303	eventChan := make(chan ProviderEvent)
304
305	go func() {
306		defer close(eventChan)
307
308		for {
309			attempts++
310
311			currentContent := ""
312			toolCalls := []message.ToolCall{}
313			var finalResp *genai.GenerateContentResponse
314
315			eventChan <- ProviderEvent{Type: EventContentStart}
316
317			var lastMsgParts []genai.Part
318
319			for _, part := range lastMsg.Parts {
320				lastMsgParts = append(lastMsgParts, *part)
321			}
322			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
323				if err != nil {
324					retry, after, retryErr := g.shouldRetry(attempts, err)
325					if retryErr != nil {
326						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
327						return
328					}
329					if retry {
330						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
331						select {
332						case <-ctx.Done():
333							if ctx.Err() != nil {
334								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
335							}
336
337							return
338						case <-time.After(time.Duration(after) * time.Millisecond):
339							continue
340						}
341					} else {
342						eventChan <- ProviderEvent{Type: EventError, Error: err}
343						return
344					}
345				}
346
347				finalResp = resp
348
349				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
350					for _, part := range resp.Candidates[0].Content.Parts {
351						switch {
352						case part.Text != "":
353							delta := string(part.Text)
354							if delta != "" {
355								eventChan <- ProviderEvent{
356									Type:    EventContentDelta,
357									Content: delta,
358								}
359								currentContent += delta
360							}
361						case part.FunctionCall != nil:
362							id := "call_" + uuid.New().String()
363							args, _ := json.Marshal(part.FunctionCall.Args)
364							newCall := message.ToolCall{
365								ID:       id,
366								Name:     part.FunctionCall.Name,
367								Input:    string(args),
368								Type:     "function",
369								Finished: true,
370							}
371
372							isNew := true
373							for _, existing := range toolCalls {
374								if existing.Name == newCall.Name && existing.Input == newCall.Input {
375									isNew = false
376									break
377								}
378							}
379
380							if isNew {
381								toolCalls = append(toolCalls, newCall)
382							}
383						}
384					}
385				}
386			}
387
388			eventChan <- ProviderEvent{Type: EventContentStop}
389
390			if finalResp != nil {
391				finishReason := message.FinishReasonEndTurn
392				if len(finalResp.Candidates) > 0 {
393					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
394				}
395				if len(toolCalls) > 0 {
396					finishReason = message.FinishReasonToolUse
397				}
398				eventChan <- ProviderEvent{
399					Type: EventComplete,
400					Response: &ProviderResponse{
401						Content:      currentContent,
402						ToolCalls:    toolCalls,
403						Usage:        g.usage(finalResp),
404						FinishReason: finishReason,
405					},
406				}
407				return
408			}
409		}
410	}()
411
412	return eventChan
413}
414
415func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
416	// Check if error is a rate limit error
417	if attempts > maxRetries {
418		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
419	}
420
421	// Gemini doesn't have a standard error type we can check against
422	// So we'll check the error message for rate limit indicators
423	if errors.Is(err, io.EOF) {
424		return false, 0, err
425	}
426
427	errMsg := err.Error()
428	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
429
430	// Check for token expiration (401 Unauthorized)
431	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
432		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
433		if err != nil {
434			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
435		}
436		g.client, err = createGeminiClient(g.providerOptions)
437		if err != nil {
438			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
439		}
440		return true, 0, nil
441	}
442
443	// Check for common rate limit error messages
444
445	if !isRateLimit {
446		return false, 0, err
447	}
448
449	// Calculate backoff with jitter
450	backoffMs := 2000 * (1 << (attempts - 1))
451	jitterMs := int(float64(backoffMs) * 0.2)
452	retryMs := backoffMs + jitterMs
453
454	return true, int64(retryMs), nil
455}
456
457func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
458	if resp == nil || resp.UsageMetadata == nil {
459		return TokenUsage{}
460	}
461
462	return TokenUsage{
463		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
464		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
465		CacheCreationTokens: 0, // Not directly provided by Gemini
466		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
467	}
468}
469
470func (g *geminiClient) Model() catwalk.Model {
471	return g.providerOptions.model(g.providerOptions.modelType)
472}
473
474// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object.
//
// The JSON literal null yields a nil map and a nil error. Invalid JSON, or JSON
// that is not an object, yields a non-nil error; the map returned alongside it
// is whatever json.Unmarshal left behind.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
480
481func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
482	properties := make(map[string]*genai.Schema)
483
484	for name, param := range parameters {
485		properties[name] = convertToSchema(param)
486	}
487
488	return properties
489}
490
491func convertToSchema(param any) *genai.Schema {
492	schema := &genai.Schema{Type: genai.TypeString}
493
494	paramMap, ok := param.(map[string]any)
495	if !ok {
496		return schema
497	}
498
499	if desc, ok := paramMap["description"].(string); ok {
500		schema.Description = desc
501	}
502
503	typeVal, hasType := paramMap["type"]
504	if !hasType {
505		return schema
506	}
507
508	typeStr, ok := typeVal.(string)
509	if !ok {
510		return schema
511	}
512
513	schema.Type = mapJSONTypeToGenAI(typeStr)
514
515	switch typeStr {
516	case "array":
517		schema.Items = processArrayItems(paramMap)
518	case "object":
519		if props, ok := paramMap["properties"].(map[string]any); ok {
520			schema.Properties = convertSchemaProperties(props)
521		}
522	}
523
524	return schema
525}
526
527func processArrayItems(paramMap map[string]any) *genai.Schema {
528	items, ok := paramMap["items"].(map[string]any)
529	if !ok {
530		return nil
531	}
532
533	return convertToSchema(items)
534}
535
536func mapJSONTypeToGenAI(jsonType string) genai.Type {
537	switch jsonType {
538	case "string":
539		return genai.TypeString
540	case "number":
541		return genai.TypeNumber
542	case "integer":
543		return genai.TypeInteger
544	case "boolean":
545		return genai.TypeBoolean
546	case "array":
547		return genai.TypeArray
548	case "object":
549		return genai.TypeObject
550	default:
551		return genai.TypeString // Default to string for unknown types
552	}
553}
554
// contains reports whether s contains any of substrs, ignoring case (both
// sides are compared after strings.ToLower). It returns false when no substrs
// are given.
func contains(s string, substrs ...string) bool {
	// Lowercase the haystack once instead of once per candidate.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}