1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/google/uuid"
 19	"google.golang.org/genai"
 20)
 21
// geminiClient is the Gemini-backed provider client returned by
// newGeminiClient.
type geminiClient struct {
	// providerOptions holds the settings the client was built from (API key,
	// base URL, model selection, system prompt, token limits). shouldRetry may
	// update the API key stored here after re-resolving it.
	providerOptions providerClientOptions
	// client is the genai SDK client used for requests; shouldRetry may
	// replace it after an API key refresh.
	client          *genai.Client
}
 26
// GeminiClient is the ProviderClient interface type returned by
// newGeminiClient.
type GeminiClient ProviderClient
 28
 29func newGeminiClient(opts providerClientOptions) GeminiClient {
 30	client, err := createGeminiClient(opts)
 31	if err != nil {
 32		slog.Error("Failed to create Gemini client", "error", err)
 33		return nil
 34	}
 35
 36	return &geminiClient{
 37		providerOptions: opts,
 38		client:          client,
 39	}
 40}
 41
 42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 43	cc := &genai.ClientConfig{
 44		APIKey:  opts.apiKey,
 45		Backend: genai.BackendGeminiAPI,
 46	}
 47	if opts.baseURL != "" {
 48		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
 49		if err == nil && resolvedBaseURL != "" {
 50			cc.HTTPOptions = genai.HTTPOptions{
 51				BaseURL: resolvedBaseURL,
 52			}
 53		}
 54	}
 55	if config.Get().Options.Debug {
 56		cc.HTTPClient = log.NewHTTPClient()
 57	}
 58	client, err := genai.NewClient(context.Background(), cc)
 59	if err != nil {
 60		return nil, err
 61	}
 62	return client, nil
 63}
 64
 65func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 66	var history []*genai.Content
 67	for _, msg := range messages {
 68		switch msg.Role {
 69		case message.User:
 70			var parts []*genai.Part
 71			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 72			for _, binaryContent := range msg.BinaryContent() {
 73				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
 74					MIMEType: binaryContent.MIMEType,
 75					Data:     binaryContent.Data,
 76				}})
 77			}
 78			history = append(history, &genai.Content{
 79				Parts: parts,
 80				Role:  genai.RoleUser,
 81			})
 82		case message.Assistant:
 83			var assistantParts []*genai.Part
 84
 85			if msg.Content().String() != "" {
 86				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
 87			}
 88
 89			if len(msg.ToolCalls()) > 0 {
 90				for _, call := range msg.ToolCalls() {
 91					if !call.Finished {
 92						continue
 93					}
 94					args, _ := parseJSONToMap(call.Input)
 95					assistantParts = append(assistantParts, &genai.Part{
 96						FunctionCall: &genai.FunctionCall{
 97							Name: call.Name,
 98							Args: args,
 99						},
100					})
101				}
102			}
103
104			if len(assistantParts) > 0 {
105				history = append(history, &genai.Content{
106					Role:  genai.RoleModel,
107					Parts: assistantParts,
108				})
109			}
110
111		case message.Tool:
112			var toolParts []*genai.Part
113			for _, result := range msg.ToolResults() {
114				response := map[string]any{"result": result.Content}
115				parsed, err := parseJSONToMap(result.Content)
116				if err == nil {
117					response = parsed
118				}
119
120				var toolCall message.ToolCall
121				for _, m := range messages {
122					if m.Role == message.Assistant {
123						for _, call := range m.ToolCalls() {
124							if call.ID == result.ToolCallID {
125								toolCall = call
126								break
127							}
128						}
129					}
130				}
131
132				toolParts = append(toolParts, &genai.Part{
133					FunctionResponse: &genai.FunctionResponse{
134						Name:     toolCall.Name,
135						Response: response,
136					},
137				})
138			}
139			if len(toolParts) > 0 {
140				history = append(history, &genai.Content{
141					Parts: toolParts,
142					Role:  genai.RoleUser,
143				})
144			}
145		}
146	}
147
148	return history
149}
150
151func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
152	geminiTool := &genai.Tool{}
153	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
154
155	for _, tool := range tools {
156		info := tool.Info()
157		declaration := &genai.FunctionDeclaration{
158			Name:        info.Name,
159			Description: info.Description,
160			Parameters: &genai.Schema{
161				Type:       genai.TypeObject,
162				Properties: convertSchemaProperties(info.Parameters),
163				Required:   info.Required,
164			},
165		}
166
167		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
168	}
169
170	return []*genai.Tool{geminiTool}
171}
172
173func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
174	switch reason {
175	case genai.FinishReasonStop:
176		return message.FinishReasonEndTurn
177	case genai.FinishReasonMaxTokens:
178		return message.FinishReasonMaxTokens
179	default:
180		return message.FinishReasonUnknown
181	}
182}
183
184func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
185	// Convert messages
186	geminiMessages := g.convertMessages(messages)
187	model := g.providerOptions.model(g.providerOptions.modelType)
188	cfg := config.Get()
189
190	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
191	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
192		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
193	}
194
195	maxTokens := model.DefaultMaxTokens
196	if modelConfig.MaxTokens > 0 {
197		maxTokens = modelConfig.MaxTokens
198	}
199	systemMessage := g.providerOptions.systemMessage
200	if g.providerOptions.systemPromptPrefix != "" {
201		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
202	}
203	history := geminiMessages[:len(geminiMessages)-1] // All but last message
204	lastMsg := geminiMessages[len(geminiMessages)-1]
205	config := &genai.GenerateContentConfig{
206		MaxOutputTokens: int32(maxTokens),
207		SystemInstruction: &genai.Content{
208			Parts: []*genai.Part{{Text: systemMessage}},
209		},
210	}
211	config.Tools = g.convertTools(tools)
212	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
213
214	attempts := 0
215	for {
216		attempts++
217		var toolCalls []message.ToolCall
218
219		var lastMsgParts []genai.Part
220		for _, part := range lastMsg.Parts {
221			lastMsgParts = append(lastMsgParts, *part)
222		}
223		resp, err := chat.SendMessage(ctx, lastMsgParts...)
224		// If there is an error we are going to see if we can retry the call
225		if err != nil {
226			retry, after, retryErr := g.shouldRetry(attempts, err)
227			if retryErr != nil {
228				return nil, retryErr
229			}
230			if retry {
231				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
232				select {
233				case <-ctx.Done():
234					return nil, ctx.Err()
235				case <-time.After(time.Duration(after) * time.Millisecond):
236					continue
237				}
238			}
239			return nil, retryErr
240		}
241
242		content := ""
243
244		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
245			for _, part := range resp.Candidates[0].Content.Parts {
246				switch {
247				case part.Text != "":
248					content = string(part.Text)
249				case part.FunctionCall != nil:
250					id := "call_" + uuid.New().String()
251					args, _ := json.Marshal(part.FunctionCall.Args)
252					toolCalls = append(toolCalls, message.ToolCall{
253						ID:       id,
254						Name:     part.FunctionCall.Name,
255						Input:    string(args),
256						Type:     "function",
257						Finished: true,
258					})
259				}
260			}
261		}
262		finishReason := message.FinishReasonEndTurn
263		if len(resp.Candidates) > 0 {
264			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
265		}
266		if len(toolCalls) > 0 {
267			finishReason = message.FinishReasonToolUse
268		}
269
270		return &ProviderResponse{
271			Content:      content,
272			ToolCalls:    toolCalls,
273			Usage:        g.usage(resp),
274			FinishReason: finishReason,
275		}, nil
276	}
277}
278
279func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
280	// Convert messages
281	geminiMessages := g.convertMessages(messages)
282
283	model := g.providerOptions.model(g.providerOptions.modelType)
284	cfg := config.Get()
285
286	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
287	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
288		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
289	}
290	maxTokens := model.DefaultMaxTokens
291	if modelConfig.MaxTokens > 0 {
292		maxTokens = modelConfig.MaxTokens
293	}
294
295	// Override max tokens if set in provider options
296	if g.providerOptions.maxTokens > 0 {
297		maxTokens = g.providerOptions.maxTokens
298	}
299	systemMessage := g.providerOptions.systemMessage
300	if g.providerOptions.systemPromptPrefix != "" {
301		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
302	}
303	history := geminiMessages[:len(geminiMessages)-1] // All but last message
304	lastMsg := geminiMessages[len(geminiMessages)-1]
305	config := &genai.GenerateContentConfig{
306		MaxOutputTokens: int32(maxTokens),
307		SystemInstruction: &genai.Content{
308			Parts: []*genai.Part{{Text: systemMessage}},
309		},
310	}
311	config.Tools = g.convertTools(tools)
312	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
313
314	attempts := 0
315	eventChan := make(chan ProviderEvent)
316
317	go func() {
318		defer close(eventChan)
319
320		for {
321			attempts++
322
323			currentContent := ""
324			toolCalls := []message.ToolCall{}
325			var finalResp *genai.GenerateContentResponse
326
327			eventChan <- ProviderEvent{Type: EventContentStart}
328
329			var lastMsgParts []genai.Part
330
331			for _, part := range lastMsg.Parts {
332				lastMsgParts = append(lastMsgParts, *part)
333			}
334
335			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
336				if err != nil {
337					retry, after, retryErr := g.shouldRetry(attempts, err)
338					if retryErr != nil {
339						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
340						return
341					}
342					if retry {
343						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
344						select {
345						case <-ctx.Done():
346							if ctx.Err() != nil {
347								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
348							}
349
350							return
351						case <-time.After(time.Duration(after) * time.Millisecond):
352							continue
353						}
354					} else {
355						eventChan <- ProviderEvent{Type: EventError, Error: err}
356						return
357					}
358				}
359
360				finalResp = resp
361
362				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
363					for _, part := range resp.Candidates[0].Content.Parts {
364						switch {
365						case part.Text != "":
366							delta := string(part.Text)
367							if delta != "" {
368								eventChan <- ProviderEvent{
369									Type:    EventContentDelta,
370									Content: delta,
371								}
372								currentContent += delta
373							}
374						case part.FunctionCall != nil:
375							id := "call_" + uuid.New().String()
376							args, _ := json.Marshal(part.FunctionCall.Args)
377							newCall := message.ToolCall{
378								ID:       id,
379								Name:     part.FunctionCall.Name,
380								Input:    string(args),
381								Type:     "function",
382								Finished: true,
383							}
384
385							toolCalls = append(toolCalls, newCall)
386						}
387					}
388				} else {
389					// no content received
390					break
391				}
392			}
393
394			eventChan <- ProviderEvent{Type: EventContentStop}
395
396			if finalResp != nil {
397				finishReason := message.FinishReasonEndTurn
398				if len(finalResp.Candidates) > 0 {
399					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
400				}
401				if len(toolCalls) > 0 {
402					finishReason = message.FinishReasonToolUse
403				}
404				eventChan <- ProviderEvent{
405					Type: EventComplete,
406					Response: &ProviderResponse{
407						Content:      currentContent,
408						ToolCalls:    toolCalls,
409						Usage:        g.usage(finalResp),
410						FinishReason: finishReason,
411					},
412				}
413				return
414			} else {
415				eventChan <- ProviderEvent{
416					Type:  EventError,
417					Error: errors.New("no content received"),
418				}
419			}
420		}
421	}()
422
423	return eventChan
424}
425
// shouldRetry decides whether a failed Gemini request should be retried. It
// returns (retry, delayMs, err):
//
//   - attempts > maxRetries: (false, 0, an error wrapping err).
//   - io.EOF: (false, 0, err).
//   - authentication failure: the API key is re-resolved from config. If it
//     changed, the SDK client is rebuilt and (true, 0, nil) is returned. If it
//     is unchanged, (false, 0, err) is returned.
//   - rate limit / quota exhaustion: (true, delay, nil). The delay is
//     2000ms * 2^(attempts-1) plus a fixed 20% pad.
//   - anything else: (false, 0, err).
//
// Errors are classified by case-insensitive message text. It never returns
// (false, 0, nil).
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries: %w", maxRetries, err)
	}

	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	errMsg := err.Error()

	// "API key not valid" is the message Gemini uses for a rejected key.
	if contains(errMsg, "unauthorized", "invalid api key", "api key expired", "api key not valid") {
		prevKey := g.providerOptions.apiKey
		// The key may come from a script; re-evaluate it in case it rotated.
		// A separate variable keeps err intact for the return below.
		newKey, resolveErr := config.Get().Resolve(g.providerOptions.config.APIKey)
		if resolveErr != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", resolveErr)
		}
		// Retrying with the same key would fail the same way.
		if newKey == prevKey {
			return false, 0, err
		}
		opts := g.providerOptions
		opts.apiKey = newKey
		client, clientErr := createGeminiClient(opts)
		if clientErr != nil {
			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", clientErr)
		}
		// Commit only once the new client exists, so a failure never leaves
		// g.client nil.
		g.providerOptions = opts
		g.client = client
		return true, 0, nil
	}

	// Gemini reports quota exhaustion as RESOURCE_EXHAUSTED
	// ("Resource has been exhausted ...").
	isRateLimit := contains(errMsg,
		"rate limit", "quota exceeded", "too many requests",
		"resource_exhausted", "resource has been exhausted")
	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff plus a fixed (not randomized) 20% pad.
	backoffMs := 2000 * (1 << (attempts - 1))
	padMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + padMs

	return true, int64(retryMs), nil
}
473
474func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
475	if resp == nil || resp.UsageMetadata == nil {
476		return TokenUsage{}
477	}
478
479	return TokenUsage{
480		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
481		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
482		CacheCreationTokens: 0, // Not directly provided by Gemini
483		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
484	}
485}
486
487func (g *geminiClient) Model() catwalk.Model {
488	return g.providerOptions.model(g.providerOptions.modelType)
489}
490
491// Helper functions
492func parseJSONToMap(jsonStr string) (map[string]any, error) {
493	var result map[string]any
494	err := json.Unmarshal([]byte(jsonStr), &result)
495	return result, err
496}
497
498func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
499	properties := make(map[string]*genai.Schema)
500
501	for name, param := range parameters {
502		properties[name] = convertToSchema(param)
503	}
504
505	return properties
506}
507
508func convertToSchema(param any) *genai.Schema {
509	schema := &genai.Schema{Type: genai.TypeString}
510
511	paramMap, ok := param.(map[string]any)
512	if !ok {
513		return schema
514	}
515
516	if desc, ok := paramMap["description"].(string); ok {
517		schema.Description = desc
518	}
519
520	typeVal, hasType := paramMap["type"]
521	if !hasType {
522		return schema
523	}
524
525	typeStr, ok := typeVal.(string)
526	if !ok {
527		return schema
528	}
529
530	schema.Type = mapJSONTypeToGenAI(typeStr)
531
532	switch typeStr {
533	case "array":
534		schema.Items = processArrayItems(paramMap)
535	case "object":
536		if props, ok := paramMap["properties"].(map[string]any); ok {
537			schema.Properties = convertSchemaProperties(props)
538		}
539	}
540
541	return schema
542}
543
544func processArrayItems(paramMap map[string]any) *genai.Schema {
545	items, ok := paramMap["items"].(map[string]any)
546	if !ok {
547		return nil
548	}
549
550	return convertToSchema(items)
551}
552
553func mapJSONTypeToGenAI(jsonType string) genai.Type {
554	switch jsonType {
555	case "string":
556		return genai.TypeString
557	case "number":
558		return genai.TypeNumber
559	case "integer":
560		return genai.TypeInteger
561	case "boolean":
562		return genai.TypeBoolean
563	case "array":
564		return genai.TypeArray
565	case "object":
566		return genai.TypeObject
567	default:
568		return genai.TypeString // Default to string for unknown types
569	}
570}
571
// contains reports whether s contains any of substrs, ignoring case. It
// returns false when substrs is empty; an empty substring always matches.
func contains(s string, substrs ...string) bool {
	// Lower-case s once instead of once per candidate substring.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}