1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/llm/tools"
14 "github.com/charmbracelet/crush/internal/logging"
15 "github.com/charmbracelet/crush/internal/message"
16 "github.com/google/uuid"
17 "google.golang.org/genai"
18)
19
20type geminiClient struct {
21 providerOptions providerClientOptions
22 client *genai.Client
23}
24
25type GeminiClient ProviderClient
26
27func newGeminiClient(opts providerClientOptions) GeminiClient {
28 client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
29 if err != nil {
30 logging.Error("Failed to create Gemini client", "error", err)
31 return nil
32 }
33
34 return &geminiClient{
35 providerOptions: opts,
36 client: client,
37 }
38}
39
40func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
41 var history []*genai.Content
42 for _, msg := range messages {
43 switch msg.Role {
44 case message.User:
45 var parts []*genai.Part
46 parts = append(parts, &genai.Part{Text: msg.Content().String()})
47 for _, binaryContent := range msg.BinaryContent() {
48 imageFormat := strings.Split(binaryContent.MIMEType, "/")
49 parts = append(parts, &genai.Part{InlineData: &genai.Blob{
50 MIMEType: imageFormat[1],
51 Data: binaryContent.Data,
52 }})
53 }
54 history = append(history, &genai.Content{
55 Parts: parts,
56 Role: "user",
57 })
58 case message.Assistant:
59 var assistantParts []*genai.Part
60
61 if msg.Content().String() != "" {
62 assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()})
63 }
64
65 if len(msg.ToolCalls()) > 0 {
66 for _, call := range msg.ToolCalls() {
67 args, _ := parseJsonToMap(call.Input)
68 assistantParts = append(assistantParts, &genai.Part{
69 FunctionCall: &genai.FunctionCall{
70 Name: call.Name,
71 Args: args,
72 },
73 })
74 }
75 }
76
77 if len(assistantParts) > 0 {
78 history = append(history, &genai.Content{
79 Role: "model",
80 Parts: assistantParts,
81 })
82 }
83
84 case message.Tool:
85 for _, result := range msg.ToolResults() {
86 response := map[string]any{"result": result.Content}
87 parsed, err := parseJsonToMap(result.Content)
88 if err == nil {
89 response = parsed
90 }
91
92 var toolCall message.ToolCall
93 for _, m := range messages {
94 if m.Role == message.Assistant {
95 for _, call := range m.ToolCalls() {
96 if call.ID == result.ToolCallID {
97 toolCall = call
98 break
99 }
100 }
101 }
102 }
103
104 history = append(history, &genai.Content{
105 Parts: []*genai.Part{
106 {
107 FunctionResponse: &genai.FunctionResponse{
108 Name: toolCall.Name,
109 Response: response,
110 },
111 },
112 },
113 Role: "function",
114 })
115 }
116 }
117 }
118
119 return history
120}
121
122func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
123 geminiTool := &genai.Tool{}
124 geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
125
126 for _, tool := range tools {
127 info := tool.Info()
128 declaration := &genai.FunctionDeclaration{
129 Name: info.Name,
130 Description: info.Description,
131 Parameters: &genai.Schema{
132 Type: genai.TypeObject,
133 Properties: convertSchemaProperties(info.Parameters),
134 Required: info.Required,
135 },
136 }
137
138 geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
139 }
140
141 return []*genai.Tool{geminiTool}
142}
143
144func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
145 switch reason {
146 case genai.FinishReasonStop:
147 return message.FinishReasonEndTurn
148 case genai.FinishReasonMaxTokens:
149 return message.FinishReasonMaxTokens
150 default:
151 return message.FinishReasonUnknown
152 }
153}
154
155func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
156 // Convert messages
157 geminiMessages := g.convertMessages(messages)
158
159 cfg := config.Get()
160 if cfg.Debug {
161 jsonData, _ := json.Marshal(geminiMessages)
162 logging.Debug("Prepared messages", "messages", string(jsonData))
163 }
164
165 history := geminiMessages[:len(geminiMessages)-1] // All but last message
166 lastMsg := geminiMessages[len(geminiMessages)-1]
167 config := &genai.GenerateContentConfig{
168 MaxOutputTokens: int32(g.providerOptions.maxTokens),
169 SystemInstruction: &genai.Content{
170 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
171 },
172 }
173 if len(tools) > 0 {
174 config.Tools = g.convertTools(tools)
175 }
176 chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
177
178 attempts := 0
179 for {
180 attempts++
181 var toolCalls []message.ToolCall
182
183 var lastMsgParts []genai.Part
184 for _, part := range lastMsg.Parts {
185 lastMsgParts = append(lastMsgParts, *part)
186 }
187 resp, err := chat.SendMessage(ctx, lastMsgParts...)
188 // If there is an error we are going to see if we can retry the call
189 if err != nil {
190 retry, after, retryErr := g.shouldRetry(attempts, err)
191 if retryErr != nil {
192 return nil, retryErr
193 }
194 if retry {
195 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
196 select {
197 case <-ctx.Done():
198 return nil, ctx.Err()
199 case <-time.After(time.Duration(after) * time.Millisecond):
200 continue
201 }
202 }
203 return nil, retryErr
204 }
205
206 content := ""
207
208 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
209 for _, part := range resp.Candidates[0].Content.Parts {
210 switch {
211 case part.Text != "":
212 content = string(part.Text)
213 case part.FunctionCall != nil:
214 id := "call_" + uuid.New().String()
215 args, _ := json.Marshal(part.FunctionCall.Args)
216 toolCalls = append(toolCalls, message.ToolCall{
217 ID: id,
218 Name: part.FunctionCall.Name,
219 Input: string(args),
220 Type: "function",
221 Finished: true,
222 })
223 }
224 }
225 }
226 finishReason := message.FinishReasonEndTurn
227 if len(resp.Candidates) > 0 {
228 finishReason = g.finishReason(resp.Candidates[0].FinishReason)
229 }
230 if len(toolCalls) > 0 {
231 finishReason = message.FinishReasonToolUse
232 }
233
234 return &ProviderResponse{
235 Content: content,
236 ToolCalls: toolCalls,
237 Usage: g.usage(resp),
238 FinishReason: finishReason,
239 }, nil
240 }
241}
242
243func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
244 // Convert messages
245 geminiMessages := g.convertMessages(messages)
246
247 cfg := config.Get()
248 if cfg.Debug {
249 jsonData, _ := json.Marshal(geminiMessages)
250 logging.Debug("Prepared messages", "messages", string(jsonData))
251 }
252
253 history := geminiMessages[:len(geminiMessages)-1] // All but last message
254 lastMsg := geminiMessages[len(geminiMessages)-1]
255 config := &genai.GenerateContentConfig{
256 MaxOutputTokens: int32(g.providerOptions.maxTokens),
257 SystemInstruction: &genai.Content{
258 Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
259 },
260 }
261 if len(tools) > 0 {
262 config.Tools = g.convertTools(tools)
263 }
264 chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
265
266 attempts := 0
267 eventChan := make(chan ProviderEvent)
268
269 go func() {
270 defer close(eventChan)
271
272 for {
273 attempts++
274
275 currentContent := ""
276 toolCalls := []message.ToolCall{}
277 var finalResp *genai.GenerateContentResponse
278
279 eventChan <- ProviderEvent{Type: EventContentStart}
280
281 var lastMsgParts []genai.Part
282
283 for _, part := range lastMsg.Parts {
284 lastMsgParts = append(lastMsgParts, *part)
285 }
286 for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
287 if err != nil {
288 retry, after, retryErr := g.shouldRetry(attempts, err)
289 if retryErr != nil {
290 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
291 return
292 }
293 if retry {
294 logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
295 select {
296 case <-ctx.Done():
297 if ctx.Err() != nil {
298 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
299 }
300
301 return
302 case <-time.After(time.Duration(after) * time.Millisecond):
303 break
304 }
305 } else {
306 eventChan <- ProviderEvent{Type: EventError, Error: err}
307 return
308 }
309 }
310
311 finalResp = resp
312
313 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
314 for _, part := range resp.Candidates[0].Content.Parts {
315 switch {
316 case part.Text != "":
317 delta := string(part.Text)
318 if delta != "" {
319 eventChan <- ProviderEvent{
320 Type: EventContentDelta,
321 Content: delta,
322 }
323 currentContent += delta
324 }
325 case part.FunctionCall != nil:
326 id := "call_" + uuid.New().String()
327 args, _ := json.Marshal(part.FunctionCall.Args)
328 newCall := message.ToolCall{
329 ID: id,
330 Name: part.FunctionCall.Name,
331 Input: string(args),
332 Type: "function",
333 Finished: true,
334 }
335
336 isNew := true
337 for _, existing := range toolCalls {
338 if existing.Name == newCall.Name && existing.Input == newCall.Input {
339 isNew = false
340 break
341 }
342 }
343
344 if isNew {
345 toolCalls = append(toolCalls, newCall)
346 }
347 }
348 }
349 }
350 }
351
352 eventChan <- ProviderEvent{Type: EventContentStop}
353
354 if finalResp != nil {
355 finishReason := message.FinishReasonEndTurn
356 if len(finalResp.Candidates) > 0 {
357 finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
358 }
359 if len(toolCalls) > 0 {
360 finishReason = message.FinishReasonToolUse
361 }
362 eventChan <- ProviderEvent{
363 Type: EventComplete,
364 Response: &ProviderResponse{
365 Content: currentContent,
366 ToolCalls: toolCalls,
367 Usage: g.usage(finalResp),
368 FinishReason: finishReason,
369 },
370 }
371 return
372 }
373 }
374 }()
375
376 return eventChan
377}
378
379func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
380 // Check if error is a rate limit error
381 if attempts > maxRetries {
382 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
383 }
384
385 // Gemini doesn't have a standard error type we can check against
386 // So we'll check the error message for rate limit indicators
387 if errors.Is(err, io.EOF) {
388 return false, 0, err
389 }
390
391 errMsg := err.Error()
392 isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
393
394 // Check for common rate limit error messages
395
396 if !isRateLimit {
397 return false, 0, err
398 }
399
400 // Calculate backoff with jitter
401 backoffMs := 2000 * (1 << (attempts - 1))
402 jitterMs := int(float64(backoffMs) * 0.2)
403 retryMs := backoffMs + jitterMs
404
405 return true, int64(retryMs), nil
406}
407
408func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
409 var toolCalls []message.ToolCall
410
411 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
412 for _, part := range resp.Candidates[0].Content.Parts {
413 if part.FunctionCall != nil {
414 id := "call_" + uuid.New().String()
415 args, _ := json.Marshal(part.FunctionCall.Args)
416 toolCalls = append(toolCalls, message.ToolCall{
417 ID: id,
418 Name: part.FunctionCall.Name,
419 Input: string(args),
420 Type: "function",
421 })
422 }
423 }
424 }
425
426 return toolCalls
427}
428
429func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
430 if resp == nil || resp.UsageMetadata == nil {
431 return TokenUsage{}
432 }
433
434 return TokenUsage{
435 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
436 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
437 CacheCreationTokens: 0, // Not directly provided by Gemini
438 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
439 }
440}
441
442// Helper functions
// parseJsonToMap decodes jsonStr as a JSON object into a map. The error from
// json.Unmarshal is returned unchanged; callers should ignore the map when
// the error is non-nil.
func parseJsonToMap(jsonStr string) (map[string]any, error) {
	var decoded map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
448
449func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
450 properties := make(map[string]*genai.Schema)
451
452 for name, param := range parameters {
453 properties[name] = convertToSchema(param)
454 }
455
456 return properties
457}
458
459func convertToSchema(param any) *genai.Schema {
460 schema := &genai.Schema{Type: genai.TypeString}
461
462 paramMap, ok := param.(map[string]any)
463 if !ok {
464 return schema
465 }
466
467 if desc, ok := paramMap["description"].(string); ok {
468 schema.Description = desc
469 }
470
471 typeVal, hasType := paramMap["type"]
472 if !hasType {
473 return schema
474 }
475
476 typeStr, ok := typeVal.(string)
477 if !ok {
478 return schema
479 }
480
481 schema.Type = mapJSONTypeToGenAI(typeStr)
482
483 switch typeStr {
484 case "array":
485 schema.Items = processArrayItems(paramMap)
486 case "object":
487 if props, ok := paramMap["properties"].(map[string]any); ok {
488 schema.Properties = convertSchemaProperties(props)
489 }
490 }
491
492 return schema
493}
494
495func processArrayItems(paramMap map[string]any) *genai.Schema {
496 items, ok := paramMap["items"].(map[string]any)
497 if !ok {
498 return nil
499 }
500
501 return convertToSchema(items)
502}
503
504func mapJSONTypeToGenAI(jsonType string) genai.Type {
505 switch jsonType {
506 case "string":
507 return genai.TypeString
508 case "number":
509 return genai.TypeNumber
510 case "integer":
511 return genai.TypeInteger
512 case "boolean":
513 return genai.TypeBoolean
514 case "array":
515 return genai.TypeArray
516 case "object":
517 return genai.TypeObject
518 default:
519 return genai.TypeString // Default to string for unknown types
520 }
521}
522
// contains reports whether s contains any of substrs, comparing the
// strings.ToLower forms of both. It returns false when substrs is empty.
func contains(s string, substrs ...string) bool {
	// Lower-case s once rather than once per candidate substring.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}