1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/llm/tools"
14 "github.com/charmbracelet/crush/internal/logging"
15 "github.com/charmbracelet/crush/internal/message"
16 "github.com/google/uuid"
17 "google.golang.org/genai"
18)
19
20type geminiClient struct {
21 providerOptions providerClientOptions
22 client *genai.Client
23}
24
25type GeminiClient ProviderClient
26
27func newGeminiClient(opts providerClientOptions) GeminiClient {
28 client, err := createGeminiClient(opts)
29 if err != nil {
30 logging.Error("Failed to create Gemini client", "error", err)
31 return nil
32 }
33
34 return &geminiClient{
35 providerOptions: opts,
36 client: client,
37 }
38}
39
40func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
41 client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
42 if err != nil {
43 return nil, err
44 }
45 return client, nil
46}
47
48func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
49 var history []*genai.Content
50 for _, msg := range messages {
51 switch msg.Role {
52 case message.User:
53 var parts []*genai.Part
54 parts = append(parts, &genai.Part{Text: msg.Content().String()})
55 for _, binaryContent := range msg.BinaryContent() {
56 imageFormat := strings.Split(binaryContent.MIMEType, "/")
57 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
58 MIMEType: imageFormat[1],
59 Data: binaryContent.Data,
60 }})
61 }
62 history = append(history, &genai.Content{
63 Parts: parts,
64 Role: "user",
65 })
66 case message.Assistant:
67 var assistantParts []*genai.Part
68
69 if msg.Content().String() != "" {
70 assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
71 }
72
73 if len(msg.ToolCalls()) > 0 {
74 for _, call := range msg.ToolCalls() {
75 args, _ := parseJsonToMap(call.Input)
76 assistantParts = append(assistantParts, &genai.Part{
77 FunctionCall: &genai.FunctionCall{
78 Name: call.Name,
79 Args: args,
80 },
81 })
82 }
83 }
84
85 if len(assistantParts) > 0 {
86 history = append(history, &genai.Content{
87 Role: "model",
88 Parts: assistantParts,
89 })
90 }
91
92 case message.Tool:
93 for _, result := range msg.ToolResults() {
94 response := map[string]any{"result": result.Content}
95 parsed, err := parseJsonToMap(result.Content)
96 if err == nil {
97 response = parsed
98 }
99
100 var toolCall message.ToolCall
101 for _, m := range messages {
102 if m.Role == message.Assistant {
103 for _, call := range m.ToolCalls() {
104 if call.ID == result.ToolCallID {
105 toolCall = call
106 break
107 }
108 }
109 }
110 }
111
112 history = append(history, &genai.Content{
113 Parts: []*genai.Part{
114 {
115 FunctionResponse: &genai.FunctionResponse{
116 Name: toolCall.Name,
117 Response: response,
118 },
119 },
120 },
121 Role: "function",
122 })
123 }
124 }
125 }
126
127 return history
128}
129
130func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
131 geminiTool := &genai.Tool{}
132 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
133
134 for _, tool := range tools {
135 info := tool.Info()
136 declaration := &genai.FunctionDeclaration{
137 Name: info.Name,
138 Description: info.Description,
139 Parameters: &genai.Schema{
140 Type: genai.TypeObject,
141 Properties: convertSchemaProperties(info.Parameters),
142 Required: info.Required,
143 },
144 }
145
146 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
147 }
148
149 return []*genai.Tool{geminiTool}
150}
151
152func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
153 switch reason {
154 case genai.FinishReasonStop:
155 return message.FinishReasonEndTurn
156 case genai.FinishReasonMaxTokens:
157 return message.FinishReasonMaxTokens
158 default:
159 return message.FinishReasonUnknown
160 }
161}
162
163func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
164 // Convert messages
165 geminiMessages := g.convertMessages(messages)
166 model := g.providerOptions.model(g.providerOptions.modelType)
167 cfg := config.Get()
168 if cfg.Options.Debug {
169 jsonData, _ := json.Marshal(geminiMessages)
170 logging.Debug("Prepared messages", "messages", string(jsonData))
171 }
172
173 modelConfig := cfg.Models.Large
174 if g.providerOptions.modelType == config.SmallModel {
175 modelConfig = cfg.Models.Small
176 }
177
178 maxTokens := model.DefaultMaxTokens
179 if modelConfig.MaxTokens > 0 {
180 maxTokens = modelConfig.MaxTokens
181 }
182 history := geminiMessages[:len(geminiMessages)-1] // All but last message
183 lastMsg := geminiMessages[len(geminiMessages)-1]
184 config := &genai.GenerateContentConfig{
185 MaxOutputTokens: int32(maxTokens),
186 SystemInstruction: &genai.Content{
187 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
188 },
189 }
190 if len(tools) > 0 {
191 config.Tools = g.convertTools(tools)
192 }
193 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
194
195 attempts := 0
196 for {
197 attempts++
198 var toolCalls []message.ToolCall
199
200 var lastMsgParts []genai.Part
201 for _, part := range lastMsg.Parts {
202 lastMsgParts = append(lastMsgParts, *part)
203 }
204 resp, err := chat.SendMessage(ctx, lastMsgParts...)
205 // If there is an error we are going to see if we can retry the call
206 if err != nil {
207 retry, after, retryErr := g.shouldRetry(attempts, err)
208 if retryErr != nil {
209 return nil, retryErr
210 }
211 if retry {
212 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
213 select {
214 case <-ctx.Done():
215 return nil, ctx.Err()
216 case <-time.After(time.Duration(after) * time.Millisecond):
217 continue
218 }
219 }
220 return nil, retryErr
221 }
222
223 content := ""
224
225 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
226 for _, part := range resp.Candidates[0].Content.Parts {
227 switch {
228 case part.Text != "":
229 content = string(part.Text)
230 case part.FunctionCall != nil:
231 id := "call_" + uuid.New().String()
232 args, _ := json.Marshal(part.FunctionCall.Args)
233 toolCalls = append(toolCalls, message.ToolCall{
234 ID: id,
235 Name: part.FunctionCall.Name,
236 Input: string(args),
237 Type: "function",
238 Finished: true,
239 })
240 }
241 }
242 }
243 finishReason := message.FinishReasonEndTurn
244 if len(resp.Candidates) > 0 {
245 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
246 }
247 if len(toolCalls) > 0 {
248 finishReason = message.FinishReasonToolUse
249 }
250
251 return &ProviderResponse{
252 Content: content,
253 ToolCalls: toolCalls,
254 Usage: g.usage(resp),
255 FinishReason: finishReason,
256 }, nil
257 }
258}
259
260func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
261 // Convert messages
262 geminiMessages := g.convertMessages(messages)
263
264 model := g.providerOptions.model(g.providerOptions.modelType)
265 cfg := config.Get()
266 if cfg.Options.Debug {
267 jsonData, _ := json.Marshal(geminiMessages)
268 logging.Debug("Prepared messages", "messages", string(jsonData))
269 }
270
271 modelConfig := cfg.Models.Large
272 if g.providerOptions.modelType == config.SmallModel {
273 modelConfig = cfg.Models.Small
274 }
275 maxTokens := model.DefaultMaxTokens
276 if modelConfig.MaxTokens > 0 {
277 maxTokens = modelConfig.MaxTokens
278 }
279
280 // Override max tokens if set in provider options
281 if g.providerOptions.maxTokens > 0 {
282 maxTokens = g.providerOptions.maxTokens
283 }
284 history := geminiMessages[:len(geminiMessages)-1] // All but last message
285 lastMsg := geminiMessages[len(geminiMessages)-1]
286 config := &genai.GenerateContentConfig{
287 MaxOutputTokens: int32(maxTokens),
288 SystemInstruction: &genai.Content{
289 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
290 },
291 }
292 if len(tools) > 0 {
293 config.Tools = g.convertTools(tools)
294 }
295 chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
296
297 attempts := 0
298 eventChan := make(chan ProviderEvent)
299
300 go func() {
301 defer close(eventChan)
302
303 for {
304 attempts++
305
306 currentContent := ""
307 toolCalls := []message.ToolCall{}
308 var finalResp *genai.GenerateContentResponse
309
310 eventChan <- ProviderEvent{Type: EventContentStart}
311
312 var lastMsgParts []genai.Part
313
314 for _, part := range lastMsg.Parts {
315 lastMsgParts = append(lastMsgParts, *part)
316 }
317 for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
318 if err != nil {
319 retry, after, retryErr := g.shouldRetry(attempts, err)
320 if retryErr != nil {
321 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
322 return
323 }
324 if retry {
325 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
326 select {
327 case <-ctx.Done():
328 if ctx.Err() != nil {
329 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
330 }
331
332 return
333 case <-time.After(time.Duration(after) * time.Millisecond):
334 break
335 }
336 } else {
337 eventChan <- ProviderEvent{Type: EventError, Error: err}
338 return
339 }
340 }
341
342 finalResp = resp
343
344 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
345 for _, part := range resp.Candidates[0].Content.Parts {
346 switch {
347 case part.Text != "":
348 delta := string(part.Text)
349 if delta != "" {
350 eventChan <- ProviderEvent{
351 Type: EventContentDelta,
352 Content: delta,
353 }
354 currentContent += delta
355 }
356 case part.FunctionCall != nil:
357 id := "call_" + uuid.New().String()
358 args, _ := json.Marshal(part.FunctionCall.Args)
359 newCall := message.ToolCall{
360 ID: id,
361 Name: part.FunctionCall.Name,
362 Input: string(args),
363 Type: "function",
364 Finished: true,
365 }
366
367 isNew := true
368 for _, existing := range toolCalls {
369 if existing.Name == newCall.Name && existing.Input == newCall.Input {
370 isNew = false
371 break
372 }
373 }
374
375 if isNew {
376 toolCalls = append(toolCalls, newCall)
377 }
378 }
379 }
380 }
381 }
382
383 eventChan <- ProviderEvent{Type: EventContentStop}
384
385 if finalResp != nil {
386 finishReason := message.FinishReasonEndTurn
387 if len(finalResp.Candidates) > 0 {
388 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
389 }
390 if len(toolCalls) > 0 {
391 finishReason = message.FinishReasonToolUse
392 }
393 eventChan <- ProviderEvent{
394 Type: EventComplete,
395 Response: &ProviderResponse{
396 Content: currentContent,
397 ToolCalls: toolCalls,
398 Usage: g.usage(finalResp),
399 FinishReason: finishReason,
400 },
401 }
402 return
403 }
404 }
405 }()
406
407 return eventChan
408}
409
410func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
411 // Check if error is a rate limit error
412 if attempts > maxRetries {
413 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
414 }
415
416 // Gemini doesn't have a standard error type we can check against
417 // So we'll check the error message for rate limit indicators
418 if errors.Is(err, io.EOF) {
419 return false, 0, err
420 }
421
422 errMsg := err.Error()
423 isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
424
425 // Check for token expiration (401 Unauthorized)
426 if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
427 g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey)
428 if err != nil {
429 return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
430 }
431 g.client, err = createGeminiClient(g.providerOptions)
432 if err != nil {
433 return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
434 }
435 return true, 0, nil
436 }
437
438 // Check for common rate limit error messages
439
440 if !isRateLimit {
441 return false, 0, err
442 }
443
444 // Calculate backoff with jitter
445 backoffMs := 2000 * (1 << (attempts - 1))
446 jitterMs := int(float64(backoffMs) * 0.2)
447 retryMs := backoffMs + jitterMs
448
449 return true, int64(retryMs), nil
450}
451
452func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
453 if resp == nil || resp.UsageMetadata == nil {
454 return TokenUsage{}
455 }
456
457 return TokenUsage{
458 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
459 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
460 CacheCreationTokens: 0, // Not directly provided by Gemini
461 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
462 }
463}
464
465func (g *geminiClient) Model() config.Model {
466 return g.providerOptions.model(g.providerOptions.modelType)
467}
468
469// Helper functions
// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
// It returns the decode error when jsonStr is not valid JSON or its top-level
// value cannot be stored in a map; the JSON literal null decodes to a nil map
// without error.
func parseJsonToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	raw := []byte(jsonStr)
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
475
476func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
477 properties := make(map[string]*genai.Schema)
478
479 for name, param := range parameters {
480 properties[name] = convertToSchema(param)
481 }
482
483 return properties
484}
485
486func convertToSchema(param any) *genai.Schema {
487 schema := &genai.Schema{Type: genai.TypeString}
488
489 paramMap, ok := param.(map[string]any)
490 if !ok {
491 return schema
492 }
493
494 if desc, ok := paramMap["description"].(string); ok {
495 schema.Description = desc
496 }
497
498 typeVal, hasType := paramMap["type"]
499 if !hasType {
500 return schema
501 }
502
503 typeStr, ok := typeVal.(string)
504 if !ok {
505 return schema
506 }
507
508 schema.Type = mapJSONTypeToGenAI(typeStr)
509
510 switch typeStr {
511 case "array":
512 schema.Items = processArrayItems(paramMap)
513 case "object":
514 if props, ok := paramMap["properties"].(map[string]any); ok {
515 schema.Properties = convertSchemaProperties(props)
516 }
517 }
518
519 return schema
520}
521
522func processArrayItems(paramMap map[string]any) *genai.Schema {
523 items, ok := paramMap["items"].(map[string]any)
524 if !ok {
525 return nil
526 }
527
528 return convertToSchema(items)
529}
530
531func mapJSONTypeToGenAI(jsonType string) genai.Type {
532 switch jsonType {
533 case "string":
534 return genai.TypeString
535 case "number":
536 return genai.TypeNumber
537 case "integer":
538 return genai.TypeInteger
539 case "boolean":
540 return genai.TypeBoolean
541 case "array":
542 return genai.TypeArray
543 case "object":
544 return genai.TypeObject
545 default:
546 return genai.TypeString // Default to string for unknown types
547 }
548}
549
// contains reports whether s contains any of substrs, ignoring case.
// It returns false when no substrings are given; an empty substring matches
// any s.
func contains(s string, substrs ...string) bool {
	// Lower-case the haystack once instead of once per candidate.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}