package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

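// geminiClient talks to the Gemini API through the google.golang.org/genai SDK.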
type geminiClient struct {
	providerOptions providerClientOptions
	client          *genai.Client
}

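// GeminiClient is a ProviderClient backed by the Gemini API.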
type GeminiClient ProviderClient

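// newGeminiClient creates a Gemini-backed client from the provider options,
// returning nil if the underlying SDK client cannot be constructed.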
func newGeminiClient(opts providerClientOptions) GeminiClient {
	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		logging.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		client:          client,
	}
}

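// convertMessages translates internal messages into genai.Content history:
// user text and attachments, assistant text and tool calls, and tool results
// are mapped onto the corresponding Gemini roles and part types.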
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					// Pass the full MIME type (e.g. "image/png"), not just the subtype.
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			var assistantParts []*genai.Part

			if msg.Content().String() != "" {
				assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
			}

			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJsonToMap(call.Input)
					assistantParts = append(assistantParts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			if len(assistantParts) > 0 {
				history = append(history, &genai.Content{
					Role:  "model",
					Parts: assistantParts,
				})
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				response := map[string]any{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}

				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

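// convertTools wraps the available tools in a single genai.Tool whose function
// declarations mirror each tool's name, description, and parameter schema.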
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

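// finishReason maps Gemini finish reasons onto provider-agnostic ones.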
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

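// send performs a non-streaming generation request, retrying on rate limit
// errors, and converts the result into a ProviderResponse.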
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)
	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models.Large
	if g.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, err := g.client.Chats.Create(ctx, model.ID, config, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content = string(part.Text)
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

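// stream performs a streaming generation request, emitting content deltas,
// tool calls, errors, and a completion event on the returned channel.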
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	model := g.providerOptions.model(g.providerOptions.modelType)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	modelConfig := cfg.Models.Large
	if g.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, chatErr := g.client.Chats.Create(ctx, model.ID, config, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		// Surface chat-creation failures as an error event instead of
		// streaming against a nil chat.
		if chatErr != nil {
			eventChan <- ProviderEvent{Type: EventError, Error: chatErr}
			return
		}

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}

							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// Skip the failed chunk rather than processing a nil
							// response; the outer loop retries once the stream ends.
							continue
						}
					} else {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := string(part.Text)
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

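// shouldRetry reports whether a failed request looks like a rate limit error
// that should be retried, and how many milliseconds to wait before retrying.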
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// Gemini doesn't expose a standard error type to check against, so look for
	// common rate limit indicators in the error message.
	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
	if !isRateLimit {
		return false, 0, err
	}

	// Calculate backoff with jitter
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

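// usage extracts token usage from a response, returning zero values when the
// usage metadata is absent.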
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

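// Model returns the currently configured model.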
func (g *geminiClient) Model() config.Model {
	return g.providerOptions.model(g.providerOptions.modelType)
}

// Helper functions

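// parseJsonToMap unmarshals a JSON object string into a map.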
func parseJsonToMap(jsonStr string) (map[string]any, error) {
	var result map[string]any
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

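// convertSchemaProperties converts JSON schema parameter definitions into
// genai.Schema properties.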
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

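// convertToSchema converts a single JSON schema parameter definition into a
// genai.Schema, defaulting to a string schema when the shape is unrecognized.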
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

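// processArrayItems converts the "items" definition of an array parameter.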
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

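// mapJSONTypeToGenAI maps a JSON schema type name to the corresponding
// genai.Type.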
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

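// contains reports whether s contains any of the given substrings,
// case-insensitively.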
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}