1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/log"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/google/uuid"
19 "google.golang.org/genai"
20)
21
22type geminiClient struct {
23 providerOptions providerClientOptions
24 client *genai.Client
25}
26
27type GeminiClient ProviderClient
28
29func newGeminiClient(opts providerClientOptions) GeminiClient {
30 client, err := createGeminiClient(opts)
31 if err != nil {
32 slog.Error("Failed to create Gemini client", "error", err)
33 return nil
34 }
35
36 return &geminiClient{
37 providerOptions: opts,
38 client: client,
39 }
40}
41
42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
43 cc := &genai.ClientConfig{
44 APIKey: opts.apiKey,
45 Backend: genai.BackendGeminiAPI,
46 }
47 if opts.baseURL != "" {
48 resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
49 if err == nil && resolvedBaseURL != "" {
50 cc.HTTPOptions = genai.HTTPOptions{
51 BaseURL: resolvedBaseURL,
52 }
53 }
54 }
55 if config.Get().Options.Debug {
56 cc.HTTPClient = log.NewHTTPClient()
57 }
58 client, err := genai.NewClient(context.Background(), cc)
59 if err != nil {
60 return nil, err
61 }
62 return client, nil
63}
64
65func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
66 var history []*genai.Content
67 for _, msg := range messages {
68 switch msg.Role {
69 case message.User:
70 var parts []*genai.Part
71 parts = append(parts, &genai.Part{Text: msg.Content().String()})
72 for _, binaryContent := range msg.BinaryContent() {
73 imageFormat := strings.Split(binaryContent.MIMEType, "/")
74 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
75 MIMEType: imageFormat[1],
76 Data: binaryContent.Data,
77 }})
78 }
79 history = append(history, &genai.Content{
80 Parts: parts,
81 Role: genai.RoleUser,
82 })
83 case message.Assistant:
84 var assistantParts []*genai.Part
85
86 if msg.Content().String() != "" {
87 assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
88 }
89
90 if len(msg.ToolCalls()) > 0 {
91 for _, call := range msg.ToolCalls() {
92 if !call.Finished {
93 continue
94 }
95 args, _ := parseJSONToMap(call.Input)
96 assistantParts = append(assistantParts, &genai.Part{
97 FunctionCall: &genai.FunctionCall{
98 Name: call.Name,
99 Args: args,
100 },
101 })
102 }
103 }
104
105 if len(assistantParts) > 0 {
106 history = append(history, &genai.Content{
107 Role: genai.RoleModel,
108 Parts: assistantParts,
109 })
110 }
111
112 case message.Tool:
113 var toolParts []*genai.Part
114 for _, result := range msg.ToolResults() {
115 response := map[string]any{"result": result.Content}
116 parsed, err := parseJSONToMap(result.Content)
117 if err == nil {
118 response = parsed
119 }
120
121 var toolCall message.ToolCall
122 for _, m := range messages {
123 if m.Role == message.Assistant {
124 for _, call := range m.ToolCalls() {
125 if call.ID == result.ToolCallID {
126 toolCall = call
127 break
128 }
129 }
130 }
131 }
132
133 toolParts = append(toolParts, &genai.Part{
134 FunctionResponse: &genai.FunctionResponse{
135 Name: toolCall.Name,
136 Response: response,
137 },
138 })
139 }
140 if len(toolParts) > 0 {
141 history = append(history, &genai.Content{
142 Parts: toolParts,
143 Role: genai.RoleUser,
144 })
145 }
146 }
147 }
148
149 return history
150}
151
152func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
153 geminiTool := &genai.Tool{}
154 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
155
156 for _, tool := range tools {
157 info := tool.Info()
158 declaration := &genai.FunctionDeclaration{
159 Name: info.Name,
160 Description: info.Description,
161 Parameters: &genai.Schema{
162 Type: genai.TypeObject,
163 Properties: convertSchemaProperties(info.Parameters),
164 Required: info.Required,
165 },
166 }
167
168 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
169 }
170
171 return []*genai.Tool{geminiTool}
172}
173
174func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
175 switch reason {
176 case genai.FinishReasonStop:
177 return message.FinishReasonEndTurn
178 case genai.FinishReasonMaxTokens:
179 return message.FinishReasonMaxTokens
180 default:
181 return message.FinishReasonUnknown
182 }
183}
184
185func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
186 // Convert messages
187 geminiMessages := g.convertMessages(messages)
188 model := g.providerOptions.model(g.providerOptions.modelType)
189 cfg := config.Get()
190
191 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
192 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
193 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
194 }
195
196 maxTokens := model.DefaultMaxTokens
197 if modelConfig.MaxTokens > 0 {
198 maxTokens = modelConfig.MaxTokens
199 }
200 systemMessage := g.providerOptions.systemMessage
201 if g.providerOptions.systemPromptPrefix != "" {
202 systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
203 }
204 history := geminiMessages[:len(geminiMessages)-1] // All but last message
205 lastMsg := geminiMessages[len(geminiMessages)-1]
206 config := &genai.GenerateContentConfig{
207 MaxOutputTokens: int32(maxTokens),
208 SystemInstruction: &genai.Content{
209 Parts: []*genai.Part{{Text: systemMessage}},
210 },
211 }
212 config.Tools = g.convertTools(tools)
213 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
214
215 attempts := 0
216 for {
217 attempts++
218 var toolCalls []message.ToolCall
219
220 var lastMsgParts []genai.Part
221 for _, part := range lastMsg.Parts {
222 lastMsgParts = append(lastMsgParts, *part)
223 }
224 resp, err := chat.SendMessage(ctx, lastMsgParts...)
225 // If there is an error we are going to see if we can retry the call
226 if err != nil {
227 retry, after, retryErr := g.shouldRetry(attempts, err)
228 if retryErr != nil {
229 return nil, retryErr
230 }
231 if retry {
232 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
233 select {
234 case <-ctx.Done():
235 return nil, ctx.Err()
236 case <-time.After(time.Duration(after) * time.Millisecond):
237 continue
238 }
239 }
240 return nil, retryErr
241 }
242
243 content := ""
244
245 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
246 for _, part := range resp.Candidates[0].Content.Parts {
247 switch {
248 case part.Text != "":
249 content = string(part.Text)
250 case part.FunctionCall != nil:
251 id := "call_" + uuid.New().String()
252 args, _ := json.Marshal(part.FunctionCall.Args)
253 toolCalls = append(toolCalls, message.ToolCall{
254 ID: id,
255 Name: part.FunctionCall.Name,
256 Input: string(args),
257 Type: "function",
258 Finished: true,
259 })
260 }
261 }
262 }
263 finishReason := message.FinishReasonEndTurn
264 if len(resp.Candidates) > 0 {
265 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
266 }
267 if len(toolCalls) > 0 {
268 finishReason = message.FinishReasonToolUse
269 }
270
271 return &ProviderResponse{
272 Content: content,
273 ToolCalls: toolCalls,
274 Usage: g.usage(resp),
275 FinishReason: finishReason,
276 }, nil
277 }
278}
279
280func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
281 // Convert messages
282 geminiMessages := g.convertMessages(messages)
283
284 model := g.providerOptions.model(g.providerOptions.modelType)
285 cfg := config.Get()
286
287 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
288 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
289 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
290 }
291 maxTokens := model.DefaultMaxTokens
292 if modelConfig.MaxTokens > 0 {
293 maxTokens = modelConfig.MaxTokens
294 }
295
296 // Override max tokens if set in provider options
297 if g.providerOptions.maxTokens > 0 {
298 maxTokens = g.providerOptions.maxTokens
299 }
300 systemMessage := g.providerOptions.systemMessage
301 if g.providerOptions.systemPromptPrefix != "" {
302 systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
303 }
304 history := geminiMessages[:len(geminiMessages)-1] // All but last message
305 lastMsg := geminiMessages[len(geminiMessages)-1]
306 config := &genai.GenerateContentConfig{
307 MaxOutputTokens: int32(maxTokens),
308 SystemInstruction: &genai.Content{
309 Parts: []*genai.Part{{Text: systemMessage}},
310 },
311 }
312 config.Tools = g.convertTools(tools)
313 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
314
315 attempts := 0
316 eventChan := make(chan ProviderEvent)
317
318 go func() {
319 defer close(eventChan)
320
321 for {
322 attempts++
323
324 currentContent := ""
325 toolCalls := []message.ToolCall{}
326 var finalResp *genai.GenerateContentResponse
327
328 eventChan <- ProviderEvent{Type: EventContentStart}
329
330 var lastMsgParts []genai.Part
331
332 for _, part := range lastMsg.Parts {
333 lastMsgParts = append(lastMsgParts, *part)
334 }
335
336 for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
337 if err != nil {
338 retry, after, retryErr := g.shouldRetry(attempts, err)
339 if retryErr != nil {
340 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
341 return
342 }
343 if retry {
344 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
345 select {
346 case <-ctx.Done():
347 if ctx.Err() != nil {
348 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
349 }
350
351 return
352 case <-time.After(time.Duration(after) * time.Millisecond):
353 continue
354 }
355 } else {
356 eventChan <- ProviderEvent{Type: EventError, Error: err}
357 return
358 }
359 }
360
361 finalResp = resp
362
363 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
364 for _, part := range resp.Candidates[0].Content.Parts {
365 switch {
366 case part.Text != "":
367 delta := string(part.Text)
368 if delta != "" {
369 eventChan <- ProviderEvent{
370 Type: EventContentDelta,
371 Content: delta,
372 }
373 currentContent += delta
374 }
375 case part.FunctionCall != nil:
376 id := "call_" + uuid.New().String()
377 args, _ := json.Marshal(part.FunctionCall.Args)
378 newCall := message.ToolCall{
379 ID: id,
380 Name: part.FunctionCall.Name,
381 Input: string(args),
382 Type: "function",
383 Finished: true,
384 }
385
386 toolCalls = append(toolCalls, newCall)
387 }
388 }
389 } else {
390 // no content received
391 break
392 }
393 }
394
395 eventChan <- ProviderEvent{Type: EventContentStop}
396
397 if finalResp != nil {
398 finishReason := message.FinishReasonEndTurn
399 if len(finalResp.Candidates) > 0 {
400 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
401 }
402 if len(toolCalls) > 0 {
403 finishReason = message.FinishReasonToolUse
404 }
405 eventChan <- ProviderEvent{
406 Type: EventComplete,
407 Response: &ProviderResponse{
408 Content: currentContent,
409 ToolCalls: toolCalls,
410 Usage: g.usage(finalResp),
411 FinishReason: finishReason,
412 },
413 }
414 return
415 } else {
416 eventChan <- ProviderEvent{
417 Type: EventError,
418 Error: errors.New("no content received"),
419 }
420 }
421 }
422 }()
423
424 return eventChan
425}
426
427func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
428 // Check if error is a rate limit error
429 if attempts > maxRetries {
430 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
431 }
432
433 // Gemini doesn't have a standard error type we can check against
434 // So we'll check the error message for rate limit indicators
435 if errors.Is(err, io.EOF) {
436 return false, 0, err
437 }
438
439 errMsg := err.Error()
440 isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
441
442 // Check for token expiration (401 Unauthorized)
443 if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
444 prev := g.providerOptions.apiKey
445 // in case the key comes from a script, we try to re-evaluate it.
446 g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
447 if err != nil {
448 return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
449 }
450 // if it didn't change, do not retry.
451 if prev == g.providerOptions.apiKey {
452 return false, 0, err
453 }
454 g.client, err = createGeminiClient(g.providerOptions)
455 if err != nil {
456 return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
457 }
458 return true, 0, nil
459 }
460
461 // Check for common rate limit error messages
462
463 if !isRateLimit {
464 return false, 0, err
465 }
466
467 // Calculate backoff with jitter
468 backoffMs := 2000 * (1 << (attempts - 1))
469 jitterMs := int(float64(backoffMs) * 0.2)
470 retryMs := backoffMs + jitterMs
471
472 return true, int64(retryMs), nil
473}
474
475func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
476 if resp == nil || resp.UsageMetadata == nil {
477 return TokenUsage{}
478 }
479
480 return TokenUsage{
481 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
482 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
483 CacheCreationTokens: 0, // Not directly provided by Gemini
484 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
485 }
486}
487
488func (g *geminiClient) Model() catwalk.Model {
489 return g.providerOptions.model(g.providerOptions.modelType)
490}
491
492// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object into a generic map. The
// JSON literal null decodes to a nil map with no error. Any decode error from
// encoding/json is returned as-is.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
498
499func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
500 properties := make(map[string]*genai.Schema)
501
502 for name, param := range parameters {
503 properties[name] = convertToSchema(param)
504 }
505
506 return properties
507}
508
509func convertToSchema(param any) *genai.Schema {
510 schema := &genai.Schema{Type: genai.TypeString}
511
512 paramMap, ok := param.(map[string]any)
513 if !ok {
514 return schema
515 }
516
517 if desc, ok := paramMap["description"].(string); ok {
518 schema.Description = desc
519 }
520
521 typeVal, hasType := paramMap["type"]
522 if !hasType {
523 return schema
524 }
525
526 typeStr, ok := typeVal.(string)
527 if !ok {
528 return schema
529 }
530
531 schema.Type = mapJSONTypeToGenAI(typeStr)
532
533 switch typeStr {
534 case "array":
535 schema.Items = processArrayItems(paramMap)
536 case "object":
537 if props, ok := paramMap["properties"].(map[string]any); ok {
538 schema.Properties = convertSchemaProperties(props)
539 }
540 }
541
542 return schema
543}
544
545func processArrayItems(paramMap map[string]any) *genai.Schema {
546 items, ok := paramMap["items"].(map[string]any)
547 if !ok {
548 return nil
549 }
550
551 return convertToSchema(items)
552}
553
554func mapJSONTypeToGenAI(jsonType string) genai.Type {
555 switch jsonType {
556 case "string":
557 return genai.TypeString
558 case "number":
559 return genai.TypeNumber
560 case "integer":
561 return genai.TypeInteger
562 case "boolean":
563 return genai.TypeBoolean
564 case "array":
565 return genai.TypeArray
566 case "object":
567 return genai.TypeObject
568 default:
569 return genai.TypeString // Default to string for unknown types
570 }
571}
572
// contains reports whether s contains any of substrs, ignoring case. As with
// strings.Contains, an empty substring always matches. With no substrs it
// returns false.
func contains(s string, substrs ...string) bool {
	// Lower-case s once instead of once per candidate substring.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}