package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

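// geminiProvider implements the Provider interface on top of the
// google.golang.org/genai client, using the Gemini API backend.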
type geminiProvider struct {
	*baseProvider
	client *genai.Client
}

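// NewGeminiProvider builds a Gemini-backed Provider from the shared base
// provider configuration. If the genai client cannot be created, the error is
// only logged and nil is returned, so callers should nil-check the result.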
func NewGeminiProvider(base *baseProvider) Provider {
	client, err := createGeminiClient(base)
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiProvider{
		baseProvider: base,
		client:       client,
	}
}

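// createGeminiClient creates a genai client for the Gemini API backend using
// the provider's API key.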
func createGeminiClient(base *baseProvider) (*genai.Client, error) {
	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: base.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		return nil, err
	}
	return client, nil
}

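// convertMessages maps the internal message history onto genai contents:
// user text and binary attachments become role "user", assistant text and
// tool calls become role "model", and tool results are emitted as function
// responses under role "function", matched back to their originating call by
// tool call ID.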
func (g *geminiProvider) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				// Pass the full MIME type (e.g. "image/png") through to the API;
				// the bare subtype is not a valid MIME type.
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJSONToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  "model",
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJSONToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

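// convertTools bundles every tool's function declaration into a single
// genai.Tool, converting each tool's JSON-schema parameters into the
// corresponding genai.Schema.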
func (g *geminiProvider) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

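// finishReason translates a genai finish reason into the internal
// message.FinishReason.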
func (g *geminiProvider) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

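// Send performs a single, non-streaming generation request for the given
// model, message history, and tools.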
func (g *geminiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	messages = g.cleanMessages(messages)
	return g.send(ctx, model, messages, tools)
}

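// send builds a chat from every message except the last, sends the last
// message, and retries on rate-limit errors before returning the response.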
func (g *geminiProvider) send(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	if g.debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	model := g.Model(modelID)
	maxTokens := model.DefaultMaxTokens
	if g.maxTokens > 0 {
		maxTokens = g.maxTokens
	}
	systemMessage := g.systemMessage
	if g.systemPromptPrefix != "" {
		systemMessage = g.systemPromptPrefix + "\n" + systemMessage
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	config.Tools = g.convertTools(tools)
	chat, err := g.client.Chats.Create(ctx, model.ID, config, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// On error, check whether the call can be retried (rate limits, expired
		// API key); otherwise return the error.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					// Accumulate text across parts instead of keeping only the last one.
					content += part.Text
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

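// Stream starts a streaming generation request and returns a channel of
// provider events (content deltas, tool calls, completion, errors). A minimal
// consumption sketch (the variable p and the model ID are only illustrative
// placeholders):
//
//	for ev := range p.Stream(ctx, "gemini-2.0-flash", msgs, tools) {
//		switch ev.Type {
//		case EventContentDelta:
//			fmt.Print(ev.Content)
//		case EventError:
//			return ev.Error
//		}
//	}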
func (g *geminiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	messages = g.cleanMessages(messages)
	return g.stream(ctx, model, messages, tools)
}

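// stream mirrors send but emits content deltas, tool calls, and the final
// response as ProviderEvents on the returned channel, which is closed when
// the request finishes or fails.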
func (g *geminiProvider) stream(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.Model(modelID)
	if g.debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	maxTokens := model.DefaultMaxTokens
	if g.maxTokens > 0 {
		maxTokens = g.maxTokens
	}

	systemMessage := g.systemMessage
	if g.systemPromptPrefix != "" {
		systemMessage = g.systemPromptPrefix + "\n" + systemMessage
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: systemMessage}},
		},
	}
	config.Tools = g.convertTools(tools)
	chat, chatErr := g.client.Chats.Create(ctx, model.ID, config, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		// Report chat-creation failures as an error event rather than using a
		// nil chat below.
		if chatErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: chatErr}
			return
		}

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}

							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// The response that accompanied the error is not usable, so
							// skip it; if nothing was received, the enclosing loop starts
							// a fresh attempt once this stream ends. (A plain break here
							// would only exit the select.)
							continue
						}
					} else {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := part.Text
							eventChan <- ProviderEvent{
								Type:    EventContentDelta,
								Content: delta,
							}
							currentContent += delta
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

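// shouldRetry reports whether a failed call should be retried and, if so,
// after how many milliseconds. Unauthorized/expired-key errors refresh the
// API key and client and retry immediately; rate-limit errors back off
// exponentially with a 20% padding (attempt 1 waits 2400ms, attempt 2 waits
// 4800ms, attempt 3 waits 9600ms, and so on); any other error is returned
// unchanged as terminal.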
func (g *geminiProvider) shouldRetry(attempts int, err error) (bool, int64, error) {
	// Give up once the retry budget is exhausted.
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// Gemini doesn't expose a typed error we can check against, so inspect the
	// error message for rate-limit and auth indicators.
	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	// Check for token expiration (401 Unauthorized) and refresh the API key.
	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
		g.apiKey, err = g.resolver.ResolveValue(g.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		g.client, err = createGeminiClient(g.baseProvider)
		if err != nil {
			return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
		}
		return true, 0, nil
	}

	// Only rate-limit style errors are retried with backoff.
	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff with a fixed 20% padding (not randomized jitter).
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

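// usage extracts token usage from the response metadata. Gemini reports
// prompt, candidate, and cached-content token counts; cache-creation tokens
// are not reported separately.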
func (g *geminiProvider) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

// Helper functions

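// parseJSONToMap unmarshals a JSON object string into a generic map.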
func parseJSONToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

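// convertSchemaProperties converts a map of JSON-schema parameter definitions
// into genai.Schema properties.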
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

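// convertToSchema converts a single JSON-schema parameter definition into a
// genai.Schema, recursing into array items and object properties. Missing or
// unknown types default to string.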
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

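// processArrayItems converts an array parameter's "items" definition, if
// present, into the schema for the array's element type.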
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

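// mapJSONTypeToGenAI maps a JSON-schema type name to the corresponding
// genai.Type.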
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

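// contains reports whether s contains any of the given substrings,
// case-insensitively.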
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}