package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

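// geminiClient talks to Google's Gemini models through the
// google.golang.org/genai SDK and adapts them to the ProviderClient
// interface used by the rest of the provider package.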
type geminiClient struct {
	providerOptions providerClientOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

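// newGeminiClient builds a GeminiClient from the given options. If the
// underlying genai client cannot be created, the error is logged and nil
// is returned.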
func newGeminiClient(opts providerClientOptions) GeminiClient {
	client, err := createGeminiClient(opts)
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		client:          client,
	}
}

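// createGeminiClient constructs the underlying genai client against the
// Gemini API backend using the configured API key.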
func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		return nil, err
	}
	return client, nil
}

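// convertMessages maps the internal message history to genai Content entries:
// user text and attachments become "user" parts, assistant text and tool calls
// become "model" parts, and tool results become "function" responses paired
// with the call that produced them.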
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					// genai.Blob expects the full IANA MIME type (e.g. "image/png"),
					// not just the subtype.
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJSONToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  "model",
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJSONToMap(result.Content)
				if err == nil {
					response = parsed
				}

				// Find the assistant tool call this result answers so the
				// function response carries the matching function name.
				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

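// convertTools converts the agent tool definitions into a single genai.Tool
// carrying one FunctionDeclaration per tool.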
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

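// finishReason translates a genai finish reason into the provider-neutral
// message.FinishReason.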
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

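// send performs a single, non-streaming generation request. It replays the
// conversation history into a chat session, sends the last message, and
// retries with backoff on rate-limit errors.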
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		genConfig.Tools = g.convertTools(tools)
	}
	chat, err := g.client.Chats.Create(ctx, model.ID, genConfig, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If there is an error, see whether the call can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					// Concatenate text parts; a candidate may contain more than one.
					content += part.Text
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

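// stream performs a streaming generation request. Events (content deltas,
// tool calls, completion, and errors) are delivered on the returned channel,
// which is closed when the request finishes.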
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if g.providerOptions.maxTokens > 0 {
		maxTokens = g.providerOptions.maxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		genConfig.Tools = g.convertTools(tools)
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	chat, err := g.client.Chats.Create(ctx, model.ID, genConfig, history)
	if err != nil {
		// If the chat session cannot be created, report the error and close the channel.
		go func() {
			defer close(eventChan)
			eventChan <- ProviderEvent{Type: EventError, Error: err}
		}()
		return eventChan
	}

	go func() {
		defer close(eventChan)

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part
			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if !retry {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
					slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
					select {
					case <-ctx.Done():
						if ctx.Err() != nil {
							eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
						}
						return
					case <-time.After(time.Duration(after) * time.Millisecond):
					}
					// The response accompanying an error is not usable; skip it
					// and let the outer loop re-send the message.
					continue
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := part.Text
							eventChan <- ProviderEvent{
								Type:    EventContentDelta,
								Content: delta,
							}
							currentContent += delta
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							// Streaming can repeat a function call across chunks;
							// only record it once.
							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

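// shouldRetry decides whether a failed request should be retried. It refreshes
// the API key and rebuilds the client on auth errors, retries rate-limit errors
// with exponential backoff, and gives up after maxRetries attempts.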
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Gemini doesn't expose a standard error type to check against, so the
	// error message is inspected for rate-limit and auth indicators.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	// Check for token expiration (401 Unauthorized)
	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		g.client, err = createGeminiClient(g.providerOptions)
		if err != nil {
			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
		}
		return true, 0, nil
	}

	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff with a fixed 20% cushion on top of each step.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

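// usage extracts token accounting from a Gemini response. Cache creation
// tokens are reported as zero because Gemini does not expose them directly.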
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

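// Model returns the catwalk model currently selected for this client.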
func (g *geminiClient) Model() catwalk.Model {
	return g.providerOptions.model(g.providerOptions.modelType)
}

// Helper functions
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

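// convertSchemaProperties converts a JSON-schema style properties map into
// genai Schema values, one per named parameter.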
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

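// convertToSchema converts a single JSON-schema parameter definition into a
// genai.Schema, defaulting to a string type when the shape is not recognized.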
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

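// processArrayItems builds the item schema for an array parameter, or nil
// when no item definition is present.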
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

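// contains reports whether s contains any of the given substrings,
// case-insensitively.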
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}