1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/log"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/google/uuid"
19 "google.golang.org/genai"
20)
21
22type geminiClient struct {
23 providerOptions providerClientOptions
24 client *genai.Client
25}
26
27type GeminiClient ProviderClient
28
29func newGeminiClient(opts providerClientOptions) GeminiClient {
30 client, err := createGeminiClient(opts)
31 if err != nil {
32 slog.Error("Failed to create Gemini client", "error", err)
33 return nil
34 }
35
36 return &geminiClient{
37 providerOptions: opts,
38 client: client,
39 }
40}
41
42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
43 cc := &genai.ClientConfig{
44 APIKey: opts.apiKey,
45 Backend: genai.BackendGeminiAPI,
46 }
47 if opts.baseURL != "" {
48 resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
49 if err == nil && resolvedBaseURL != "" {
50 cc.HTTPOptions = genai.HTTPOptions{
51 BaseURL: resolvedBaseURL,
52 }
53 }
54 }
55 if config.Get().Options.Debug {
56 cc.HTTPClient = log.NewHTTPClient()
57 }
58 client, err := genai.NewClient(context.Background(), cc)
59 if err != nil {
60 return nil, err
61 }
62 return client, nil
63}
64
65func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
66 var history []*genai.Content
67 for _, msg := range messages {
68 switch msg.Role {
69 case message.User:
70 var parts []*genai.Part
71 parts = append(parts, &genai.Part{Text: msg.Content().String()})
72 for _, binaryContent := range msg.BinaryContent() {
73 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
74 MIMEType: binaryContent.MIMEType,
75 Data: binaryContent.Data,
76 }})
77 }
78 history = append(history, &genai.Content{
79 Parts: parts,
80 Role: genai.RoleUser,
81 })
82 case message.Assistant:
83 var assistantParts []*genai.Part
84
85 if msg.Content().String() != "" {
86 assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
87 }
88
89 if len(msg.ToolCalls()) > 0 {
90 for _, call := range msg.ToolCalls() {
91 if !call.Finished {
92 continue
93 }
94 args, _ := parseJSONToMap(call.Input)
95 assistantParts = append(assistantParts, &genai.Part{
96 FunctionCall: &genai.FunctionCall{
97 Name: call.Name,
98 Args: args,
99 },
100 })
101 }
102 }
103
104 if len(assistantParts) > 0 {
105 history = append(history, &genai.Content{
106 Role: genai.RoleModel,
107 Parts: assistantParts,
108 })
109 }
110
111 case message.Tool:
112 var toolParts []*genai.Part
113 for _, result := range msg.ToolResults() {
114 response := map[string]any{"result": result.Content}
115 parsed, err := parseJSONToMap(result.Content)
116 if err == nil {
117 response = parsed
118 }
119
120 var toolCall message.ToolCall
121 for _, m := range messages {
122 if m.Role == message.Assistant {
123 for _, call := range m.ToolCalls() {
124 if call.ID == result.ToolCallID {
125 toolCall = call
126 break
127 }
128 }
129 }
130 }
131
132 toolParts = append(toolParts, &genai.Part{
133 FunctionResponse: &genai.FunctionResponse{
134 Name: toolCall.Name,
135 Response: response,
136 },
137 })
138 }
139 if len(toolParts) > 0 {
140 history = append(history, &genai.Content{
141 Parts: toolParts,
142 Role: genai.RoleUser,
143 })
144 }
145 }
146 }
147
148 return history
149}
150
151func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
152 geminiTool := &genai.Tool{}
153 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
154
155 for _, tool := range tools {
156 info := tool.Info()
157 declaration := &genai.FunctionDeclaration{
158 Name: info.Name,
159 Description: info.Description,
160 Parameters: &genai.Schema{
161 Type: genai.TypeObject,
162 Properties: convertSchemaProperties(info.Parameters),
163 Required: info.Required,
164 },
165 }
166
167 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
168 }
169
170 return []*genai.Tool{geminiTool}
171}
172
173func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
174 switch reason {
175 case genai.FinishReasonStop:
176 return message.FinishReasonEndTurn
177 case genai.FinishReasonMaxTokens:
178 return message.FinishReasonMaxTokens
179 default:
180 return message.FinishReasonUnknown
181 }
182}
183
184func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
185 // Convert messages
186 geminiMessages := g.convertMessages(messages)
187 model := g.providerOptions.model(g.providerOptions.modelType)
188 cfg := config.Get()
189
190 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
191 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
192 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
193 }
194
195 maxTokens := model.DefaultMaxTokens
196 if modelConfig.MaxTokens > 0 {
197 maxTokens = modelConfig.MaxTokens
198 }
199 systemMessage := g.providerOptions.systemMessage
200 if g.providerOptions.systemPromptPrefix != "" {
201 systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
202 }
203 history := geminiMessages[:len(geminiMessages)-1] // All but last message
204 lastMsg := geminiMessages[len(geminiMessages)-1]
205 config := &genai.GenerateContentConfig{
206 MaxOutputTokens: int32(maxTokens),
207 SystemInstruction: &genai.Content{
208 Parts: []*genai.Part{{Text: systemMessage}},
209 },
210 }
211 config.Tools = g.convertTools(tools)
212 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
213
214 attempts := 0
215 for {
216 attempts++
217 var toolCalls []message.ToolCall
218
219 var lastMsgParts []genai.Part
220 for _, part := range lastMsg.Parts {
221 lastMsgParts = append(lastMsgParts, *part)
222 }
223 resp, err := chat.SendMessage(ctx, lastMsgParts...)
224 // If there is an error we are going to see if we can retry the call
225 if err != nil {
226 retry, after, retryErr := g.shouldRetry(attempts, err)
227 if retryErr != nil {
228 return nil, retryErr
229 }
230 if retry {
231 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
232 select {
233 case <-ctx.Done():
234 return nil, ctx.Err()
235 case <-time.After(time.Duration(after) * time.Millisecond):
236 continue
237 }
238 }
239 return nil, retryErr
240 }
241
242 content := ""
243
244 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
245 for _, part := range resp.Candidates[0].Content.Parts {
246 switch {
247 case part.Text != "":
248 content = string(part.Text)
249 case part.FunctionCall != nil:
250 id := "call_" + uuid.New().String()
251 args, _ := json.Marshal(part.FunctionCall.Args)
252 toolCalls = append(toolCalls, message.ToolCall{
253 ID: id,
254 Name: part.FunctionCall.Name,
255 Input: string(args),
256 Type: "function",
257 Finished: true,
258 })
259 }
260 }
261 }
262 finishReason := message.FinishReasonEndTurn
263 if len(resp.Candidates) > 0 {
264 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
265 }
266 if len(toolCalls) > 0 {
267 finishReason = message.FinishReasonToolUse
268 }
269
270 return &ProviderResponse{
271 Content: content,
272 ToolCalls: toolCalls,
273 Usage: g.usage(resp),
274 FinishReason: finishReason,
275 }, nil
276 }
277}
278
279func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
280 // Convert messages
281 geminiMessages := g.convertMessages(messages)
282
283 model := g.providerOptions.model(g.providerOptions.modelType)
284 cfg := config.Get()
285
286 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
287 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
288 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
289 }
290 maxTokens := model.DefaultMaxTokens
291 if modelConfig.MaxTokens > 0 {
292 maxTokens = modelConfig.MaxTokens
293 }
294
295 // Override max tokens if set in provider options
296 if g.providerOptions.maxTokens > 0 {
297 maxTokens = g.providerOptions.maxTokens
298 }
299 systemMessage := g.providerOptions.systemMessage
300 if g.providerOptions.systemPromptPrefix != "" {
301 systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
302 }
303 history := geminiMessages[:len(geminiMessages)-1] // All but last message
304 lastMsg := geminiMessages[len(geminiMessages)-1]
305 config := &genai.GenerateContentConfig{
306 MaxOutputTokens: int32(maxTokens),
307 SystemInstruction: &genai.Content{
308 Parts: []*genai.Part{{Text: systemMessage}},
309 },
310 }
311 config.Tools = g.convertTools(tools)
312 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
313
314 attempts := 0
315 eventChan := make(chan ProviderEvent)
316
317 go func() {
318 defer close(eventChan)
319
320 for {
321 attempts++
322
323 currentContent := ""
324 toolCalls := []message.ToolCall{}
325 var finalResp *genai.GenerateContentResponse
326
327 eventChan <- ProviderEvent{Type: EventContentStart}
328
329 var lastMsgParts []genai.Part
330
331 for _, part := range lastMsg.Parts {
332 lastMsgParts = append(lastMsgParts, *part)
333 }
334
335 for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
336 if err != nil {
337 retry, after, retryErr := g.shouldRetry(attempts, err)
338 if retryErr != nil {
339 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
340 return
341 }
342 if retry {
343 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
344 select {
345 case <-ctx.Done():
346 if ctx.Err() != nil {
347 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
348 }
349
350 return
351 case <-time.After(time.Duration(after) * time.Millisecond):
352 continue
353 }
354 } else {
355 eventChan <- ProviderEvent{Type: EventError, Error: err}
356 return
357 }
358 }
359
360 finalResp = resp
361
362 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
363 for _, part := range resp.Candidates[0].Content.Parts {
364 switch {
365 case part.Text != "":
366 delta := string(part.Text)
367 if delta != "" {
368 eventChan <- ProviderEvent{
369 Type: EventContentDelta,
370 Content: delta,
371 }
372 currentContent += delta
373 }
374 case part.FunctionCall != nil:
375 id := "call_" + uuid.New().String()
376 args, _ := json.Marshal(part.FunctionCall.Args)
377 newCall := message.ToolCall{
378 ID: id,
379 Name: part.FunctionCall.Name,
380 Input: string(args),
381 Type: "function",
382 Finished: true,
383 }
384
385 toolCalls = append(toolCalls, newCall)
386 }
387 }
388 } else {
389 // no content received
390 break
391 }
392 }
393
394 eventChan <- ProviderEvent{Type: EventContentStop}
395
396 if finalResp != nil {
397 finishReason := message.FinishReasonEndTurn
398 if len(finalResp.Candidates) > 0 {
399 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
400 }
401 if len(toolCalls) > 0 {
402 finishReason = message.FinishReasonToolUse
403 }
404 eventChan <- ProviderEvent{
405 Type: EventComplete,
406 Response: &ProviderResponse{
407 Content: currentContent,
408 ToolCalls: toolCalls,
409 Usage: g.usage(finalResp),
410 FinishReason: finishReason,
411 },
412 }
413 return
414 } else {
415 eventChan <- ProviderEvent{
416 Type: EventError,
417 Error: errors.New("no content received"),
418 }
419 }
420 }
421 }()
422
423 return eventChan
424}
425
426func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
427 // Check if error is a rate limit error
428 if attempts > maxRetries {
429 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
430 }
431
432 // Gemini doesn't have a standard error type we can check against
433 // So we'll check the error message for rate limit indicators
434 if errors.Is(err, io.EOF) {
435 return false, 0, err
436 }
437
438 errMsg := err.Error()
439 isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
440
441 // Check for token expiration (401 Unauthorized)
442 if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
443 prev := g.providerOptions.apiKey
444 // in case the key comes from a script, we try to re-evaluate it.
445 g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
446 if err != nil {
447 return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
448 }
449 // if it didn't change, do not retry.
450 if prev == g.providerOptions.apiKey {
451 return false, 0, err
452 }
453 g.client, err = createGeminiClient(g.providerOptions)
454 if err != nil {
455 return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
456 }
457 return true, 0, nil
458 }
459
460 // Check for common rate limit error messages
461
462 if !isRateLimit {
463 return false, 0, err
464 }
465
466 // Calculate backoff with jitter
467 backoffMs := 2000 * (1 << (attempts - 1))
468 jitterMs := int(float64(backoffMs) * 0.2)
469 retryMs := backoffMs + jitterMs
470
471 return true, int64(retryMs), nil
472}
473
474func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
475 if resp == nil || resp.UsageMetadata == nil {
476 return TokenUsage{}
477 }
478
479 return TokenUsage{
480 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
481 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
482 CacheCreationTokens: 0, // Not directly provided by Gemini
483 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
484 }
485}
486
487func (g *geminiClient) Model() catwalk.Model {
488 return g.providerOptions.model(g.providerOptions.modelType)
489}
490
491// Helper functions
// parseJSONToMap decodes jsonStr into a map. The decode error, if any, is
// returned together with whatever was decoded before the failure. Note that
// the JSON literal null decodes to a nil map without an error.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
497
498func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
499 properties := make(map[string]*genai.Schema)
500
501 for name, param := range parameters {
502 properties[name] = convertToSchema(param)
503 }
504
505 return properties
506}
507
508func convertToSchema(param any) *genai.Schema {
509 schema := &genai.Schema{Type: genai.TypeString}
510
511 paramMap, ok := param.(map[string]any)
512 if !ok {
513 return schema
514 }
515
516 if desc, ok := paramMap["description"].(string); ok {
517 schema.Description = desc
518 }
519
520 typeVal, hasType := paramMap["type"]
521 if !hasType {
522 return schema
523 }
524
525 typeStr, ok := typeVal.(string)
526 if !ok {
527 return schema
528 }
529
530 schema.Type = mapJSONTypeToGenAI(typeStr)
531
532 switch typeStr {
533 case "array":
534 schema.Items = processArrayItems(paramMap)
535 case "object":
536 if props, ok := paramMap["properties"].(map[string]any); ok {
537 schema.Properties = convertSchemaProperties(props)
538 }
539 }
540
541 return schema
542}
543
544func processArrayItems(paramMap map[string]any) *genai.Schema {
545 items, ok := paramMap["items"].(map[string]any)
546 if !ok {
547 return nil
548 }
549
550 return convertToSchema(items)
551}
552
553func mapJSONTypeToGenAI(jsonType string) genai.Type {
554 switch jsonType {
555 case "string":
556 return genai.TypeString
557 case "number":
558 return genai.TypeNumber
559 case "integer":
560 return genai.TypeInteger
561 case "boolean":
562 return genai.TypeBoolean
563 case "array":
564 return genai.TypeArray
565 case "object":
566 return genai.TypeObject
567 default:
568 return genai.TypeString // Default to string for unknown types
569 }
570}
571
// contains reports whether s contains at least one of substrs, ignoring case.
// It returns false when no substrings are given.
func contains(s string, substrs ...string) bool {
	// Lower-case the haystack once instead of once per candidate.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}