1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/log"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/google/uuid"
19 "google.golang.org/genai"
20)
21
22type geminiClient struct {
23 providerOptions providerClientOptions
24 client *genai.Client
25}
26
27type GeminiClient ProviderClient
28
29func newGeminiClient(opts providerClientOptions) GeminiClient {
30 client, err := createGeminiClient(opts)
31 if err != nil {
32 slog.Error("Failed to create Gemini client", "error", err)
33 return nil
34 }
35
36 return &geminiClient{
37 providerOptions: opts,
38 client: client,
39 }
40}
41
42func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
43 cc := &genai.ClientConfig{
44 APIKey: opts.apiKey,
45 Backend: genai.BackendGeminiAPI,
46 }
47 if opts.cfg.Options.Debug {
48 cc.HTTPClient = log.NewHTTPClient()
49 }
50 client, err := genai.NewClient(context.Background(), cc)
51 if err != nil {
52 return nil, err
53 }
54 return client, nil
55}
56
57func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
58 var history []*genai.Content
59 for _, msg := range messages {
60 switch msg.Role {
61 case message.User:
62 var parts []*genai.Part
63 parts = append(parts, &genai.Part{Text: msg.Content().String()})
64 for _, binaryContent := range msg.BinaryContent() {
65 imageFormat := strings.Split(binaryContent.MIMEType, "/")
66 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
67 MIMEType: imageFormat[1],
68 Data: binaryContent.Data,
69 }})
70 }
71 history = append(history, &genai.Content{
72 Parts: parts,
73 Role: genai.RoleUser,
74 })
75 case message.Assistant:
76 var assistantParts []*genai.Part
77
78 if msg.Content().String() != "" {
79 assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
80 }
81
82 if len(msg.ToolCalls()) > 0 {
83 for _, call := range msg.ToolCalls() {
84 if !call.Finished {
85 continue
86 }
87 args, _ := parseJSONToMap(call.Input)
88 assistantParts = append(assistantParts, &genai.Part{
89 FunctionCall: &genai.FunctionCall{
90 Name: call.Name,
91 Args: args,
92 },
93 })
94 }
95 }
96
97 if len(assistantParts) > 0 {
98 history = append(history, &genai.Content{
99 Role: genai.RoleModel,
100 Parts: assistantParts,
101 })
102 }
103
104 case message.Tool:
105 var toolParts []*genai.Part
106 for _, result := range msg.ToolResults() {
107 response := map[string]any{"result": result.Content}
108 parsed, err := parseJSONToMap(result.Content)
109 if err == nil {
110 response = parsed
111 }
112
113 var toolCall message.ToolCall
114 for _, m := range messages {
115 if m.Role == message.Assistant {
116 for _, call := range m.ToolCalls() {
117 if call.ID == result.ToolCallID {
118 toolCall = call
119 break
120 }
121 }
122 }
123 }
124
125 toolParts = append(toolParts, &genai.Part{
126 FunctionResponse: &genai.FunctionResponse{
127 Name: toolCall.Name,
128 Response: response,
129 },
130 })
131 }
132 if len(toolParts) > 0 {
133 history = append(history, &genai.Content{
134 Parts: toolParts,
135 Role: genai.RoleUser,
136 })
137 }
138 }
139 }
140
141 return history
142}
143
144func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
145 geminiTool := &genai.Tool{}
146 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
147
148 for _, tool := range tools {
149 info := tool.Info()
150 declaration := &genai.FunctionDeclaration{
151 Name: info.Name,
152 Description: info.Description,
153 Parameters: &genai.Schema{
154 Type: genai.TypeObject,
155 Properties: convertSchemaProperties(info.Parameters),
156 Required: info.Required,
157 },
158 }
159
160 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
161 }
162
163 return []*genai.Tool{geminiTool}
164}
165
166func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
167 switch reason {
168 case genai.FinishReasonStop:
169 return message.FinishReasonEndTurn
170 case genai.FinishReasonMaxTokens:
171 return message.FinishReasonMaxTokens
172 default:
173 return message.FinishReasonUnknown
174 }
175}
176
177func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
178 // Convert messages
179 geminiMessages := g.convertMessages(messages)
180 model := g.providerOptions.model(g.providerOptions.modelType)
181 cfg := g.providerOptions.cfg
182
183 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
184 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
185 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
186 }
187
188 maxTokens := model.DefaultMaxTokens
189 if modelConfig.MaxTokens > 0 {
190 maxTokens = modelConfig.MaxTokens
191 }
192 systemMessage := g.providerOptions.systemMessage
193 if g.providerOptions.systemPromptPrefix != "" {
194 systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
195 }
196 history := geminiMessages[:len(geminiMessages)-1] // All but last message
197 lastMsg := geminiMessages[len(geminiMessages)-1]
198 config := &genai.GenerateContentConfig{
199 MaxOutputTokens: int32(maxTokens),
200 SystemInstruction: &genai.Content{
201 Parts: []*genai.Part{{Text: systemMessage}},
202 },
203 }
204 config.Tools = g.convertTools(tools)
205 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
206
207 attempts := 0
208 for {
209 attempts++
210 var toolCalls []message.ToolCall
211
212 var lastMsgParts []genai.Part
213 for _, part := range lastMsg.Parts {
214 lastMsgParts = append(lastMsgParts, *part)
215 }
216 resp, err := chat.SendMessage(ctx, lastMsgParts...)
217 // If there is an error we are going to see if we can retry the call
218 if err != nil {
219 retry, after, retryErr := g.shouldRetry(attempts, err)
220 if retryErr != nil {
221 return nil, retryErr
222 }
223 if retry {
224 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
225 select {
226 case <-ctx.Done():
227 return nil, ctx.Err()
228 case <-time.After(time.Duration(after) * time.Millisecond):
229 continue
230 }
231 }
232 return nil, retryErr
233 }
234
235 content := ""
236
237 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
238 for _, part := range resp.Candidates[0].Content.Parts {
239 switch {
240 case part.Text != "":
241 content = string(part.Text)
242 case part.FunctionCall != nil:
243 id := "call_" + uuid.New().String()
244 args, _ := json.Marshal(part.FunctionCall.Args)
245 toolCalls = append(toolCalls, message.ToolCall{
246 ID: id,
247 Name: part.FunctionCall.Name,
248 Input: string(args),
249 Type: "function",
250 Finished: true,
251 })
252 }
253 }
254 }
255 finishReason := message.FinishReasonEndTurn
256 if len(resp.Candidates) > 0 {
257 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
258 }
259 if len(toolCalls) > 0 {
260 finishReason = message.FinishReasonToolUse
261 }
262
263 return &ProviderResponse{
264 Content: content,
265 ToolCalls: toolCalls,
266 Usage: g.usage(resp),
267 FinishReason: finishReason,
268 }, nil
269 }
270}
271
272func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
273 // Convert messages
274 geminiMessages := g.convertMessages(messages)
275
276 model := g.providerOptions.model(g.providerOptions.modelType)
277 cfg := g.providerOptions.cfg
278
279 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
280 if g.providerOptions.modelType == config.SelectedModelTypeSmall {
281 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
282 }
283 maxTokens := model.DefaultMaxTokens
284 if modelConfig.MaxTokens > 0 {
285 maxTokens = modelConfig.MaxTokens
286 }
287
288 // Override max tokens if set in provider options
289 if g.providerOptions.maxTokens > 0 {
290 maxTokens = g.providerOptions.maxTokens
291 }
292 systemMessage := g.providerOptions.systemMessage
293 if g.providerOptions.systemPromptPrefix != "" {
294 systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
295 }
296 history := geminiMessages[:len(geminiMessages)-1] // All but last message
297 lastMsg := geminiMessages[len(geminiMessages)-1]
298 config := &genai.GenerateContentConfig{
299 MaxOutputTokens: int32(maxTokens),
300 SystemInstruction: &genai.Content{
301 Parts: []*genai.Part{{Text: systemMessage}},
302 },
303 }
304 config.Tools = g.convertTools(tools)
305 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
306
307 attempts := 0
308 eventChan := make(chan ProviderEvent)
309
310 go func() {
311 defer close(eventChan)
312
313 for {
314 attempts++
315
316 currentContent := ""
317 toolCalls := []message.ToolCall{}
318 var finalResp *genai.GenerateContentResponse
319
320 eventChan <- ProviderEvent{Type: EventContentStart}
321
322 var lastMsgParts []genai.Part
323
324 for _, part := range lastMsg.Parts {
325 lastMsgParts = append(lastMsgParts, *part)
326 }
327
328 for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
329 if err != nil {
330 retry, after, retryErr := g.shouldRetry(attempts, err)
331 if retryErr != nil {
332 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
333 return
334 }
335 if retry {
336 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err)
337 select {
338 case <-ctx.Done():
339 if ctx.Err() != nil {
340 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
341 }
342
343 return
344 case <-time.After(time.Duration(after) * time.Millisecond):
345 continue
346 }
347 } else {
348 eventChan <- ProviderEvent{Type: EventError, Error: err}
349 return
350 }
351 }
352
353 finalResp = resp
354
355 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
356 for _, part := range resp.Candidates[0].Content.Parts {
357 switch {
358 case part.Text != "":
359 delta := string(part.Text)
360 if delta != "" {
361 eventChan <- ProviderEvent{
362 Type: EventContentDelta,
363 Content: delta,
364 }
365 currentContent += delta
366 }
367 case part.FunctionCall != nil:
368 id := "call_" + uuid.New().String()
369 args, _ := json.Marshal(part.FunctionCall.Args)
370 newCall := message.ToolCall{
371 ID: id,
372 Name: part.FunctionCall.Name,
373 Input: string(args),
374 Type: "function",
375 Finished: true,
376 }
377
378 toolCalls = append(toolCalls, newCall)
379 }
380 }
381 } else {
382 // no content received
383 break
384 }
385 }
386
387 eventChan <- ProviderEvent{Type: EventContentStop}
388
389 if finalResp != nil {
390 finishReason := message.FinishReasonEndTurn
391 if len(finalResp.Candidates) > 0 {
392 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
393 }
394 if len(toolCalls) > 0 {
395 finishReason = message.FinishReasonToolUse
396 }
397 eventChan <- ProviderEvent{
398 Type: EventComplete,
399 Response: &ProviderResponse{
400 Content: currentContent,
401 ToolCalls: toolCalls,
402 Usage: g.usage(finalResp),
403 FinishReason: finishReason,
404 },
405 }
406 return
407 } else {
408 eventChan <- ProviderEvent{
409 Type: EventError,
410 Error: errors.New("no content received"),
411 }
412 }
413 }
414 }()
415
416 return eventChan
417}
418
419func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
420 // Check if error is a rate limit error
421 if attempts > maxRetries {
422 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
423 }
424
425 // Gemini doesn't have a standard error type we can check against
426 // So we'll check the error message for rate limit indicators
427 if errors.Is(err, io.EOF) {
428 return false, 0, err
429 }
430
431 errMsg := err.Error()
432 isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
433
434 // Check for token expiration (401 Unauthorized)
435 if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
436 g.providerOptions.apiKey, err = g.providerOptions.cfg.Resolve(g.providerOptions.config.APIKey)
437 if err != nil {
438 return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
439 }
440 g.client, err = createGeminiClient(g.providerOptions)
441 if err != nil {
442 return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
443 }
444 return true, 0, nil
445 }
446
447 // Check for common rate limit error messages
448
449 if !isRateLimit {
450 return false, 0, err
451 }
452
453 // Calculate backoff with jitter
454 backoffMs := 2000 * (1 << (attempts - 1))
455 jitterMs := int(float64(backoffMs) * 0.2)
456 retryMs := backoffMs + jitterMs
457
458 return true, int64(retryMs), nil
459}
460
461func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
462 if resp == nil || resp.UsageMetadata == nil {
463 return TokenUsage{}
464 }
465
466 return TokenUsage{
467 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
468 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
469 CacheCreationTokens: 0, // Not directly provided by Gemini
470 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
471 }
472}
473
474func (g *geminiClient) Model() catwalk.Model {
475 return g.providerOptions.model(g.providerOptions.modelType)
476}
477
478// Helper functions
// parseJSONToMap decodes jsonStr as a JSON object.
//
// The returned map is nil when decoding fails or when the input is the JSON
// literal null. The decoding error, if any, is returned unchanged.
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
484
485func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
486 properties := make(map[string]*genai.Schema)
487
488 for name, param := range parameters {
489 properties[name] = convertToSchema(param)
490 }
491
492 return properties
493}
494
495func convertToSchema(param any) *genai.Schema {
496 schema := &genai.Schema{Type: genai.TypeString}
497
498 paramMap, ok := param.(map[string]any)
499 if !ok {
500 return schema
501 }
502
503 if desc, ok := paramMap["description"].(string); ok {
504 schema.Description = desc
505 }
506
507 typeVal, hasType := paramMap["type"]
508 if !hasType {
509 return schema
510 }
511
512 typeStr, ok := typeVal.(string)
513 if !ok {
514 return schema
515 }
516
517 schema.Type = mapJSONTypeToGenAI(typeStr)
518
519 switch typeStr {
520 case "array":
521 schema.Items = processArrayItems(paramMap)
522 case "object":
523 if props, ok := paramMap["properties"].(map[string]any); ok {
524 schema.Properties = convertSchemaProperties(props)
525 }
526 }
527
528 return schema
529}
530
531func processArrayItems(paramMap map[string]any) *genai.Schema {
532 items, ok := paramMap["items"].(map[string]any)
533 if !ok {
534 return nil
535 }
536
537 return convertToSchema(items)
538}
539
540func mapJSONTypeToGenAI(jsonType string) genai.Type {
541 switch jsonType {
542 case "string":
543 return genai.TypeString
544 case "number":
545 return genai.TypeNumber
546 case "integer":
547 return genai.TypeInteger
548 case "boolean":
549 return genai.TypeBoolean
550 case "array":
551 return genai.TypeArray
552 case "object":
553 return genai.TypeObject
554 default:
555 return genai.TypeString // Default to string for unknown types
556 }
557}
558
// contains reports whether s contains any of substrs, ignoring case.
// An empty substring always matches, as with strings.Contains; an empty
// substrs list never matches.
func contains(s string, substrs ...string) bool {
	// Lower-case s once rather than once per candidate substring.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}