package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

type geminiClient struct {
	providerOptions providerClientOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

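// newGeminiClient builds a Gemini-backed ProviderClient. Note that it returns
// nil when the underlying genai client cannot be created, so callers should
// treat a nil result as a construction failure.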
func newGeminiClient(opts providerClientOptions) GeminiClient {
	client, err := createGeminiClient(opts)
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		client:          client,
	}
}

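// createGeminiClient wires up a genai.Client against the Gemini API backend,
// using the configured API key and base URL, and swaps in a logging HTTP
// client when debug mode is enabled.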
func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
	cc := &genai.ClientConfig{
		APIKey:  opts.apiKey,
		Backend: genai.BackendGeminiAPI,
		HTTPOptions: genai.HTTPOptions{
			BaseURL: opts.baseURL,
		},
	}
	if config.Get().Options.Debug {
		cc.HTTPClient = log.NewHTTPClient()
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return client, nil
}

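// convertMessages maps internal messages onto Gemini's content format:
// user messages (text plus any binary attachments) and tool results both
// become RoleUser content, while assistant text and finished tool calls
// become RoleModel content carrying FunctionCall parts.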
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				// genai.Blob expects the full IANA MIME type (e.g. "image/png"),
				// not just the subtype.
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  genai.RoleUser,
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					if !call.Finished {
						continue
					}
					args, _ := parseJSONToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  genai.RoleModel,
					Parts: assistantParts,
				})
			}

		case message.Tool:
			var toolParts []*genai.Part
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJSONToMap(result.Content)
				if err == nil {
					response = parsed
				}

				// Look up the originating tool call so the FunctionResponse
				// can be reported under the same function name.
				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				toolParts = append(toolParts, &genai.Part{
					FunctionResponse: &genai.FunctionResponse{
						Name:     toolCall.Name,
						Response: response,
					},
				})
			}
			if len(toolParts) > 0 {
				history = append(history, &genai.Content{
					Parts: toolParts,
					Role:  genai.RoleUser,
				})
			}
		}
	}

	return history
}

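// convertTools flattens all tools into a single genai.Tool whose
// FunctionDeclarations carry each tool's name, description, and a
// genai.Schema built from its JSON parameters.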
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

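// finishReason translates Gemini finish reasons into the provider-agnostic
// message.FinishReason values; anything beyond stop/max-tokens is reported
// as unknown.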
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

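// send performs a blocking generation call. All converted messages except the
// last are passed to the chat as history; the final message is sent directly,
// and rate-limited calls are retried with backoff via shouldRetry.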
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	systemMessage := g.providerOptions.systemMessage
	if g.providerOptions.systemPromptPrefix != "" {
		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	genConfig.Tools = g.convertTools(tools)
	// Don't ignore chat creation errors; a nil chat would panic on SendMessage.
	chat, err := g.client.Chats.Create(ctx, model.ID, genConfig, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If there is an error, see whether the call can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content = part.Text
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

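// stream mirrors send but emits ProviderEvents on a channel: EventContentStart,
// EventContentDelta for each text chunk, EventContentStop, and finally either
// EventComplete with the accumulated response or EventError.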
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if g.providerOptions.maxTokens > 0 {
		maxTokens = g.providerOptions.maxTokens
	}
	systemMessage := g.providerOptions.systemMessage
	if g.providerOptions.systemPromptPrefix != "" {
		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	genConfig.Tools = g.convertTools(tools)
	chat, createErr := g.client.Chats.Create(ctx, model.ID, genConfig, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		// Surface chat-creation failures on the event channel instead of
		// panicking on a nil chat below.
		if createErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: createErr}
			return
		}

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}

			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}

							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							continue
						}
					} else {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := part.Text
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							toolCalls = append(toolCalls, newCall)
						}
					}
				} else {
					// no content received
					break
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			} else {
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: errors.New("no content received"),
				}
			}
		}
	}()

	return eventChan
}

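// shouldRetry reports whether a failed call is worth retrying. It returns the
// retry decision, a backoff delay in milliseconds, and a terminal error when
// retrying is not appropriate (e.g. the retry budget is exhausted).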
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// Gemini doesn't have a standard error type we can check against,
	// so we check the error message for rate limit indicators.
	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	// Check for token expiration (401 Unauthorized).
	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
		prev := g.providerOptions.apiKey
		// In case the key comes from a script, try to re-evaluate it.
		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		// If it didn't change, do not retry.
		if prev == g.providerOptions.apiKey {
			return false, 0, err
		}
		g.client, err = createGeminiClient(g.providerOptions)
		if err != nil {
			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
		}
		return true, 0, nil
	}

	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff (2s, 4s, 8s, ...) plus a flat 20% cushion.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

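// usage converts Gemini usage metadata into TokenUsage; cache creation tokens
// are reported as zero because Gemini does not provide them directly.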
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

func (g *geminiClient) Model() catwalk.Model {
	return g.providerOptions.model(g.providerOptions.modelType)
}

// Helper functions
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

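// convertSchemaProperties turns a JSON-Schema-style properties map into
// genai.Schema values, recursing into nested objects and array items.
// Illustrative example (hypothetical tool parameter, not from a real tool):
//
//	{"path": {"type": "string", "description": "file to read"}}
//
// yields a single entry whose Schema has Type genai.TypeString and the given
// description; unknown or missing types fall back to TypeString.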
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

// contains reports whether s contains any of the given substrings,
// case-insensitively.
func contains(s string, substrs ...string) bool {
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}