1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/message"
17 "github.com/google/uuid"
18 "google.golang.org/genai"
19)
20
21type geminiClient struct {
22 providerOptions providerClientOptions
23 client *genai.Client
24}
25
26type GeminiClient ProviderClient
27
28func newGeminiClient(opts providerClientOptions) GeminiClient {
29 client, err := createGeminiClient(opts)
30 if err != nil {
31 slog.Error("Failed to create Gemini client", "error", err)
32 return nil
33 }
34
35 return &geminiClient{
36 providerOptions: opts,
37 client: client,
38 }
39}
40
41func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
42 client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
43 if err != nil {
44 return nil, err
45 }
46 return client, nil
47}
48
49func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
50 var history []*genai.Content
51 for _, msg := range messages {
52 switch msg.Role {
53 case message.User:
54 var parts []*genai.Part
55 parts = append(parts, &genai.Part{Text: msg.Content().String()})
56 for _, binaryContent := range msg.BinaryContent() {
57 imageFormat := strings.Split(binaryContent.MIMEType, "/")
58 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
59 MIMEType: imageFormat[1],
60 Data: binaryContent.Data,
61 }})
62 }
63 history = append(history, &genai.Content{
64 Parts: parts,
65 Role: "user",
66 })
67 case message.Assistant:
68 var assistantParts []*genai.Part
69
70 if msg.Content().String() != "" {
71 assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
72 }
73
74 if len(msg.ToolCalls()) > 0 {
75 for _, call := range msg.ToolCalls() {
76 args, _ := parseJSONToMap(call.Input)
77 assistantParts = append(assistantParts, &genai.Part{
78 FunctionCall: &genai.FunctionCall{
79 Name: call.Name,
80 Args: args,
81 },
82 })
83 }
84 }
85
86 if len(assistantParts) > 0 {
87 history = append(history, &genai.Content{
88 Role: "model",
89 Parts: assistantParts,
90 })
91 }
92
93 case message.Tool:
94 for _, result := range msg.ToolResults() {
95 response := map[string]any{"result": result.Content}
96 parsed, err := parseJSONToMap(result.Content)
97 if err == nil {
98 response = parsed
99 }
100
101 var toolCall message.ToolCall
102 for _, m := range messages {
103 if m.Role == message.Assistant {
104 for _, call := range m.ToolCalls() {
105 if call.ID == result.ToolCallID {
106 toolCall = call
107 break
108 }
109 }
110 }
111 }
112
113 history = append(history, &genai.Content{
114 Parts: []*genai.Part{
115 {
116 FunctionResponse: &genai.FunctionResponse{
117 Name: toolCall.Name,
118 Response: response,
119 },
120 },
121 },
122 Role: "function",
123 })
124 }
125 }
126 }
127
128 return history
129}
130
131func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
132 geminiTool := &genai.Tool{}
133 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
134
135 for _, tool := range tools {
136 info := tool.Info()
137 declaration := &genai.FunctionDeclaration{
138 Name: info.Name,
139 Description: info.Description,
140 Parameters: &genai.Schema{
141 Type: genai.TypeObject,
142 Properties: convertSchemaProperties(info.Parameters),
143 Required: info.Required,
144 },
145 }
146
147 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
148 }
149
150 return []*genai.Tool{geminiTool}
151}
152
153func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
154 switch reason {
155 case genai.FinishReasonStop:
156 return message.FinishReasonEndTurn
157 case genai.FinishReasonMaxTokens:
158 return message.FinishReasonMaxTokens
159 default:
160 return message.FinishReasonUnknown
161 }
162}
163
164func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
165 // Convert messages
166 geminiMessages := g.convertMessages(messages)
167 model := g.providerOptions.model(g.providerOptions.modelType)
168 cfg := config.Get()
169 if cfg.Options.Debug {
170 jsonData, _ := json.Marshal(geminiMessages)
171 slog.Debug("Prepared messages", "messages", string(jsonData))
172 }
173
174 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
175 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
176 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
177 }
178
179 maxTokens := model.DefaultMaxTokens
180 if modelConfig.MaxTokens > 0 {
181 maxTokens = modelConfig.MaxTokens
182 }
183 history := geminiMessages[:len(geminiMessages)-1] // All but last message
184 lastMsg := geminiMessages[len(geminiMessages)-1]
185 config := &genai.GenerateContentConfig{
186 MaxOutputTokens: int32(maxTokens),
187 SystemInstruction: &genai.Content{
188 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
189 },
190 }
191 config.Tools = g.convertTools(tools)
192 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
193
194 attempts := 0
195 for {
196 attempts++
197 var toolCalls []message.ToolCall
198
199 var lastMsgParts []genai.Part
200 for _, part := range lastMsg.Parts {
201 lastMsgParts = append(lastMsgParts, *part)
202 }
203 resp, err := chat.SendMessage(ctx, lastMsgParts...)
204 // If there is an error we are going to see if we can retry the call
205 if err != nil {
206 retry, after, retryErr := g.shouldRetry(attempts, err)
207 if retryErr != nil {
208 return nil, retryErr
209 }
210 if retry {
211 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
212 select {
213 case <-ctx.Done():
214 return nil, ctx.Err()
215 case <-time.After(time.Duration(after) * time.Millisecond):
216 continue
217 }
218 }
219 return nil, retryErr
220 }
221
222 content := ""
223
224 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
225 for _, part := range resp.Candidates[0].Content.Parts {
226 switch {
227 case part.Text != "":
228 content = string(part.Text)
229 case part.FunctionCall != nil:
230 id := "call_" + uuid.New().String()
231 args, _ := json.Marshal(part.FunctionCall.Args)
232 toolCalls = append(toolCalls, message.ToolCall{
233 ID: id,
234 Name: part.FunctionCall.Name,
235 Input: string(args),
236 Type: "function",
237 Finished: true,
238 })
239 }
240 }
241 }
242 finishReason := message.FinishReasonEndTurn
243 if len(resp.Candidates) > 0 {
244 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
245 }
246 if len(toolCalls) > 0 {
247 finishReason = message.FinishReasonToolUse
248 }
249
250 return &ProviderResponse{
251 Content: content,
252 ToolCalls: toolCalls,
253 Usage: g.usage(resp),
254 FinishReason: finishReason,
255 }, nil
256 }
257}
258
259func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
260 // Convert messages
261 geminiMessages := g.convertMessages(messages)
262
263 model := g.providerOptions.model(g.providerOptions.modelType)
264 cfg := config.Get()
265 if cfg.Options.Debug {
266 jsonData, _ := json.Marshal(geminiMessages)
267 slog.Debug("Prepared messages", "messages", string(jsonData))
268 }
269
270 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
271 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
272 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
273 }
274 maxTokens := model.DefaultMaxTokens
275 if modelConfig.MaxTokens > 0 {
276 maxTokens = modelConfig.MaxTokens
277 }
278
279 // Override max tokens if set in provider options
280 if g.providerOptions.maxTokens > 0 {
281 maxTokens = g.providerOptions.maxTokens
282 }
283 history := geminiMessages[:len(geminiMessages)-1] // All but last message
284 lastMsg := geminiMessages[len(geminiMessages)-1]
285 config := &genai.GenerateContentConfig{
286 MaxOutputTokens: int32(maxTokens),
287 SystemInstruction: &genai.Content{
288 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
289 },
290 }
291 config.Tools = g.convertTools(tools)
292 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
293
294 attempts := 0
295 eventChan := make(chan ProviderEvent)
296
297 go func() {
298 defer close(eventChan)
299
300 for {
301 attempts++
302
303 currentContent := ""
304 toolCalls := []message.ToolCall{}
305 var finalResp *genai.GenerateContentResponse
306
307 eventChan <- ProviderEvent{Type: EventContentStart}
308
309 var lastMsgParts []genai.Part
310
311 for _, part := range lastMsg.Parts {
312 lastMsgParts = append(lastMsgParts, *part)
313 }
314 for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
315 if err != nil {
316 retry, after, retryErr := g.shouldRetry(attempts, err)
317 if retryErr != nil {
318 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
319 return
320 }
321 if retry {
322 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
323 select {
324 case <-ctx.Done():
325 if ctx.Err() != nil {
326 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
327 }
328
329 return
330 case <-time.After(time.Duration(after) * time.Millisecond):
331 break
332 }
333 } else {
334 eventChan <- ProviderEvent{Type: EventError, Error: err}
335 return
336 }
337 }
338
339 finalResp = resp
340
341 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
342 for _, part := range resp.Candidates[0].Content.Parts {
343 switch {
344 case part.Text != "":
345 delta := string(part.Text)
346 if delta != "" {
347 eventChan <- ProviderEvent{
348 Type: EventContentDelta,
349 Content: delta,
350 }
351 currentContent += delta
352 }
353 case part.FunctionCall != nil:
354 id := "call_" + uuid.New().String()
355 args, _ := json.Marshal(part.FunctionCall.Args)
356 newCall := message.ToolCall{
357 ID: id,
358 Name: part.FunctionCall.Name,
359 Input: string(args),
360 Type: "function",
361 Finished: true,
362 }
363
364 isNew := true
365 for _, existing := range toolCalls {
366 if existing.Name == newCall.Name && existing.Input == newCall.Input {
367 isNew = false
368 break
369 }
370 }
371
372 if isNew {
373 toolCalls = append(toolCalls, newCall)
374 }
375 }
376 }
377 }
378 }
379
380 eventChan <- ProviderEvent{Type: EventContentStop}
381
382 if finalResp != nil {
383 finishReason := message.FinishReasonEndTurn
384 if len(finalResp.Candidates) > 0 {
385 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
386 }
387 if len(toolCalls) > 0 {
388 finishReason = message.FinishReasonToolUse
389 }
390 eventChan <- ProviderEvent{
391 Type: EventComplete,
392 Response: &ProviderResponse{
393 Content: currentContent,
394 ToolCalls: toolCalls,
395 Usage: g.usage(finalResp),
396 FinishReason: finishReason,
397 },
398 }
399 return
400 }
401 }
402 }()
403
404 return eventChan
405}
406
407func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
408 // Check if error is a rate limit error
409 if attempts > maxRetries {
410 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
411 }
412
413 // Gemini doesn't have a standard error type we can check against
414 // So we'll check the error message for rate limit indicators
415 if errors.Is(err, io.EOF) {
416 return false, 0, err
417 }
418
419 errMsg := err.Error()
420 isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
421
422 // Check for token expiration (401 Unauthorized)
423 if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
424 g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
425 if err != nil {
426 return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
427 }
428 g.client, err = createGeminiClient(g.providerOptions)
429 if err != nil {
430 return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
431 }
432 return true, 0, nil
433 }
434
435 // Check for common rate limit error messages
436
437 if !isRateLimit {
438 return false, 0, err
439 }
440
441 // Calculate backoff with jitter
442 backoffMs := 2000 * (1 << (attempts - 1))
443 jitterMs := int(float64(backoffMs) * 0.2)
444 retryMs := backoffMs + jitterMs
445
446 return true, int64(retryMs), nil
447}
448
449func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
450 if resp == nil || resp.UsageMetadata == nil {
451 return TokenUsage{}
452 }
453
454 return TokenUsage{
455 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
456 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
457 CacheCreationTokens: 0, // Not directly provided by Gemini
458 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
459 }
460}
461
462func (g *geminiClient) Model() catwalk.Model {
463 return g.providerOptions.model(g.providerOptions.modelType)
464}
465
466// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object into a generic map.
// The json.Unmarshal error is returned unchanged, together with whatever map
// value decoding produced. A JSON "null" yields a nil map and a nil error.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
472
473func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
474 properties := make(map[string]*genai.Schema)
475
476 for name, param := range parameters {
477 properties[name] = convertToSchema(param)
478 }
479
480 return properties
481}
482
483func convertToSchema(param any) *genai.Schema {
484 schema := &genai.Schema{Type: genai.TypeString}
485
486 paramMap, ok := param.(map[string]any)
487 if !ok {
488 return schema
489 }
490
491 if desc, ok := paramMap["description"].(string); ok {
492 schema.Description = desc
493 }
494
495 typeVal, hasType := paramMap["type"]
496 if !hasType {
497 return schema
498 }
499
500 typeStr, ok := typeVal.(string)
501 if !ok {
502 return schema
503 }
504
505 schema.Type = mapJSONTypeToGenAI(typeStr)
506
507 switch typeStr {
508 case "array":
509 schema.Items = processArrayItems(paramMap)
510 case "object":
511 if props, ok := paramMap["properties"].(map[string]any); ok {
512 schema.Properties = convertSchemaProperties(props)
513 }
514 }
515
516 return schema
517}
518
519func processArrayItems(paramMap map[string]any) *genai.Schema {
520 items, ok := paramMap["items"].(map[string]any)
521 if !ok {
522 return nil
523 }
524
525 return convertToSchema(items)
526}
527
528func mapJSONTypeToGenAI(jsonType string) genai.Type {
529 switch jsonType {
530 case "string":
531 return genai.TypeString
532 case "number":
533 return genai.TypeNumber
534 case "integer":
535 return genai.TypeInteger
536 case "boolean":
537 return genai.TypeBoolean
538 case "array":
539 return genai.TypeArray
540 case "object":
541 return genai.TypeObject
542 default:
543 return genai.TypeString // Default to string for unknown types
544 }
545}
546
// contains reports whether s contains any of substrs, comparing both sides
// after strings.ToLower. An empty substring matches any s. With no substrs
// the result is false.
func contains(s string, substrs ...string) bool {
	// Lower-case s once instead of once per candidate substring.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}