package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

type geminiClient struct {
	providerOptions providerClientOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

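// newGeminiClient builds a GeminiClient from the given provider options.
// If the underlying genai client cannot be created, the error is logged
// and nil is returned.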
func newGeminiClient(opts providerClientOptions) GeminiClient {
	client, err := createGeminiClient(opts)
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		client:          client,
	}
}

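// createGeminiClient constructs a genai.Client for the Gemini API backend,
// wiring in a logging HTTP client when debug mode is enabled.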
func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
	cc := &genai.ClientConfig{
		APIKey:  opts.apiKey,
		Backend: genai.BackendGeminiAPI,
	}
	if config.Get().Options.Debug {
		cc.HTTPClient = log.NewHTTPClient()
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return client, nil
}

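// convertMessages maps internal messages to genai.Content history: user text
// and attachments become user parts, assistant text and finished tool calls
// become model parts, and tool results become function responses matched back
// to their originating tool call by ID.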
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				imageFormat := strings.Split(binaryContent.MIMEType, "/")
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: imageFormat[1],
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  genai.RoleUser,
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					if !call.Finished {
						continue
					}
					args, _ := parseJSONToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  genai.RoleModel,
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJSONToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: genai.RoleModel,
				})
			}
		}
	}

	return history
}

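// convertTools wraps the available tools in a single genai.Tool whose function
// declarations mirror each tool's name, description, and parameter schema.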
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

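// finishReason translates a Gemini finish reason into the internal representation.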
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

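// send performs a single (non-streaming) generation request. It replays all
// but the last converted message as chat history, sends the final message,
// retries rate-limited calls, and returns the response content, tool calls,
// usage, and finish reason.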
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	systemMessage := g.providerOptions.systemMessage
	if g.providerOptions.systemPromptPrefix != "" {
		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	config.Tools = g.convertTools(tools)
	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, retryErr
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content = string(part.Text)
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

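// stream performs a streaming generation request and emits provider events
// (content start, deltas, tool calls, completion, or errors) on the returned channel.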
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if g.providerOptions.maxTokens > 0 {
		maxTokens = g.providerOptions.maxTokens
	}
	systemMessage := g.providerOptions.systemMessage
	if g.providerOptions.systemPromptPrefix != "" {
		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	config.Tools = g.convertTools(tools)
	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}

			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}

							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							continue
						}
					} else {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := string(part.Text)
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				} else {
					// no content received
					break
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			} else {
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: errors.New("no content received"),
				}
			}
		}
	}()

	return eventChan
}

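// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. Rate-limit errors use exponential backoff with
// jitter; authorization errors trigger an API key refresh and client rebuild.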
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	// Give up once the maximum number of retry attempts has been reached.
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Gemini doesn't expose a structured error type we can check against,
	// so inspect the error message for rate limit indicators instead.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	// Check for token expiration (401 Unauthorized) and refresh the API key.
	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		g.client, err = createGeminiClient(g.providerOptions)
		if err != nil {
			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
		}
		return true, 0, nil
	}

	// Anything that isn't a recognized rate limit error is not retried.
	if !isRateLimit {
		return false, 0, err
	}

	// Calculate backoff with jitter
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

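// usage extracts token usage from a response, returning zero values when no
// usage metadata is present.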
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

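// Model returns the catwalk model currently selected for this provider.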
func (g *geminiClient) Model() catwalk.Model {
	return g.providerOptions.model(g.providerOptions.modelType)
}

// Helper functions
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

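// convertSchemaProperties converts a JSON-schema-style properties map into
// genai schema definitions.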
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

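// convertToSchema converts a single JSON-schema parameter definition into a
// genai.Schema, recursing into array items and object properties. Unknown or
// missing types fall back to string.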
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

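// mapJSONTypeToGenAI maps a JSON schema type name to the corresponding genai type.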
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

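// contains reports whether s contains any of the given substrings, case-insensitively.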
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}