package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

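// geminiClient implements ProviderClient on top of the google.golang.org/genai SDK.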
type geminiClient struct {
	providerOptions providerClientOptions
	client          *genai.Client
}

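// GeminiClient is a ProviderClient backed by Google's Gemini models.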
type GeminiClient ProviderClient

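// newGeminiClient constructs a GeminiClient from the given provider options.
// If the underlying genai client cannot be created, the error is logged and
// nil is returned, so callers should treat a nil result as a failed setup.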
func newGeminiClient(opts providerClientOptions) GeminiClient {
	client, err := createGeminiClient(opts)
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		client:          client,
	}
}

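// createGeminiClient creates a genai client that talks to the public Gemini
// API backend using the API key from the provider options.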
func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		return nil, err
	}
	return client, nil
}

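// convertMessages maps internal messages to genai content: user messages
// (including any attachments) become "user" content, assistant text and tool
// calls become "model" content, and tool results become "function" content
// that echoes the name of the call they answer.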
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					// The API expects the full IANA media type (e.g. "image/png"),
					// not just the subtype.
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJSONToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  "model",
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJSONToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

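// convertTools wraps every tool's schema in a single genai.Tool whose
// FunctionDeclarations mirror the tool names, descriptions, and parameters.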
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

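// finishReason translates a genai finish reason into the internal one; only
// normal stops and max-token cutoffs are distinguished, everything else maps
// to unknown.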
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

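// send performs a single (non-streaming) generation call. It replays all but
// the last converted message as chat history, sends the last message, and
// retries with backoff on rate-limit or auth errors before returning the
// aggregated response.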
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	systemMessage := g.providerOptions.systemMessage
	if g.providerOptions.systemPromptPrefix != "" {
		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	genConfig.Tools = g.convertTools(tools)
	chat, err := g.client.Chats.Create(ctx, model.ID, genConfig, history)
	if err != nil {
		return nil, fmt.Errorf("failed to create Gemini chat session: %w", err)
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If the call failed, check whether it can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: return the original error.
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					// Concatenate text parts so none are dropped.
					content += part.Text
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

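// stream performs a streaming generation call. It emits content deltas and
// tool calls on the returned channel as they arrive, retries the stream on
// rate-limit and auth errors, and finishes with an EventComplete carrying the
// aggregated response. A sketch of a typical consumer inside this package
// (variable names illustrative):
//
//	for event := range g.stream(ctx, msgs, tools) {
//		switch event.Type {
//		case EventContentDelta:
//			// append event.Content to the output shown so far
//		case EventError, EventComplete:
//			// stop reading; the producer closes the channel
//		}
//	}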
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if g.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if g.providerOptions.maxTokens > 0 {
		maxTokens = g.providerOptions.maxTokens
	}
	systemMessage := g.providerOptions.systemMessage
	if g.providerOptions.systemPromptPrefix != "" {
		systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	genConfig.Tools = g.convertTools(tools)
	chat, chatErr := g.client.Chats.Create(ctx, model.ID, genConfig, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		// Fail fast if the chat session could not be created.
		if chatErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: chatErr}
			return
		}

	outer:
		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part
			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}
							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// Restart the stream from the outer loop rather than
							// falling through with a nil response.
							continue outer
						}
					}
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := part.Text
							eventChan <- ProviderEvent{
								Type:    EventContentDelta,
								Content: delta,
							}
							currentContent += delta
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}


			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

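// shouldRetry decides whether a failed request should be retried and, if so,
// after how many milliseconds. Auth errors trigger an API key refresh and an
// immediate retry; rate-limit errors back off exponentially with a fixed 20%
// cushion (attempt 1 waits 2400ms, attempt 2 waits 4800ms, attempt 3 waits
// 9600ms, and so on) until maxRetries is exceeded.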
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	// Stop once the retry budget is exhausted.
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// A plain EOF is not retryable.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// Gemini doesn't expose a structured error type we can check against, so
	// inspect the error message for rate-limit and auth indicators instead.
	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	// Check for token expiration (401 Unauthorized): refresh the API key,
	// rebuild the client, and retry immediately.
	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
		g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		g.client, err = createGeminiClient(g.providerOptions)
		if err != nil {
			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
		}
		return true, 0, nil
	}

	// Anything other than a rate-limit error is not retried.
	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff (2s, 4s, 8s, ...) with a fixed 20% cushion.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

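// usage extracts token accounting from a response; Gemini does not report
// cache-creation tokens, so that field is always zero.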
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

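// Model returns the catwalk model currently selected for this client.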
func (g *geminiClient) Model() catwalk.Model {
	return g.providerOptions.model(g.providerOptions.modelType)
}

// Helper functions

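// parseJSONToMap decodes a JSON object string into a map, returning the
// decode error so callers can fall back to treating the input as plain text.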
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

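// convertSchemaProperties converts a JSON-Schema-style "properties" map into
// per-field genai schemas. For example, an entry such as
//
//	"path": map[string]any{"type": "string", "description": "The file path"}
//
// (names illustrative) becomes a *genai.Schema with Type genai.TypeString and
// that description.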
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

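// convertToSchema converts a single JSON-Schema-style parameter definition
// into a genai.Schema, defaulting to a string schema when the definition is
// missing a type, untyped, or not a map. Arrays and objects are converted
// recursively via their "items" and "properties" fields.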
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

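// processArrayItems converts an array definition's "items" entry into the
// element schema, or returns nil when no usable "items" map is present.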
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

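// mapJSONTypeToGenAI maps a JSON Schema type name to the corresponding
// genai.Type, falling back to a string type for anything unrecognized.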
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

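// contains reports whether s contains any of the given substrings,
// case-insensitively.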
func contains(s string, substrs ...string) bool {
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}