package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/google/generative-ai-go/genai"
	"github.com/google/uuid"
	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/logging"
	"github.com/kujtimiihoxha/termai/internal/message"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

type geminiOptions struct {
	disableCache bool
}

type GeminiOption func(*geminiOptions)

type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

type GeminiClient ProviderClient

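// newGeminiClient constructs a Gemini-backed client from the shared provider
// options. Note that it returns nil when the underlying genai client cannot be
// created, so callers should treat a nil result as a construction failure.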
func newGeminiClient(opts providerClientOptions) GeminiClient {
	geminiOpts := geminiOptions{}
	for _, o := range opts.geminiOptions {
		o(&geminiOpts)
	}

	client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
	if err != nil {
		logging.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		options:         geminiOpts,
		client:          client,
	}
}

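// convertMessages maps the internal conversation history to genai.Content
// values. The system prompt is injected as a leading user/model exchange, and
// tool results are replayed with the "function" role so Gemini can associate
// them with the originating function call.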
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content

	// Add the system message first
	history = append(history, &genai.Content{
		Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
		Role:  "user",
	})

	// Add a model response acknowledging the system message
	history = append(history, &genai.Content{
		Parts: []genai.Part{genai.Text("I'll help you with that.")},
		Role:  "model",
	})

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			history = append(history, &genai.Content{
				Parts: []genai.Part{genai.Text(msg.Content().String())},
				Role:  "user",
			})

		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []genai.Part{},
			}

			if msg.Content().String() != "" {
				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
			}

			for _, call := range msg.ToolCalls() {
				args, _ := parseJsonToMap(call.Input)
				content.Parts = append(content.Parts, genai.FunctionCall{
					Name: call.Name,
					Args: args,
				})
			}

			history = append(history, content)

		case message.Tool:
			for _, result := range msg.ToolResults() {
				// Prefer a structured JSON result; fall back to wrapping the raw text.
				response := map[string]interface{}{"result": result.Content}
				if parsed, err := parseJsonToMap(result.Content); err == nil {
					response = parsed
				}

				// Find the tool call this result belongs to so its name can be reused.
				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []genai.Part{genai.FunctionResponse{
						Name:     toolCall.Name,
						Response: response,
					}},
					Role: "function",
				})
			}
		}
	}

	return history
}

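// convertTools maps the internal tool definitions to genai function
// declarations, wrapping each declaration in its own genai.Tool.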
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTools := make([]*genai.Tool, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTools = append(geminiTools, &genai.Tool{
			FunctionDeclarations: []*genai.FunctionDeclaration{declaration},
		})
	}

	return geminiTools
}

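// finishReason translates a Gemini finish reason into the provider-agnostic
// message.FinishReason. Matching on the string form avoids depending on
// specific genai enum constants.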
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	reasonStr := reason.String()
	switch {
	case reasonStr == "STOP":
		return message.FinishReasonEndTurn
	case reasonStr == "MAX_TOKENS":
		return message.FinishReasonMaxTokens
	case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

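// send performs a single (non-streaming) generation request. The last
// converted message is sent as the prompt and everything before it becomes the
// chat history; rate-limited requests are retried with exponential backoff up
// to maxRetries attempts.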
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))

	// Convert tools
	if len(tools) > 0 {
		model.Tools = g.convertTools(tools)
	}

	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	for {
		attempts++
		chat := model.StartChat()
		chat.History = geminiMessages[:len(geminiMessages)-1] // All but the last message

		lastMsg := geminiMessages[len(geminiMessages)-1]
		var lastText string
		for _, part := range lastMsg.Parts {
			if text, ok := part.(genai.Text); ok {
				lastText = string(text)
				break
			}
		}

		resp, err := chat.SendMessage(ctx, genai.Text(lastText))
		// If there is an error, check whether the call can be retried
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""
		var toolCalls []message.ToolCall

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch p := part.(type) {
				case genai.Text:
					content += string(p)
				case genai.FunctionCall:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(p.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:    id,
						Name:  p.Name,
						Input: string(args),
						Type:  "function",
					})
				}
			}
		}

		finishReason := message.FinishReasonUnknown
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

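// stream performs a streaming generation request and reports progress on the
// returned channel: a content-start event, one delta per streamed text chunk,
// a content-stop event, and finally either a complete or an error event.
// Rate-limited attempts are retried with the same backoff as send.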
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))

	// Convert tools
	if len(tools) > 0 {
		model.Tools = g.convertTools(tools)
	}

	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

	attempt:
		for {
			attempts++
			chat := model.StartChat()
			chat.History = geminiMessages[:len(geminiMessages)-1] // All but the last message

			lastMsg := geminiMessages[len(geminiMessages)-1]
			var lastText string
			for _, part := range lastMsg.Parts {
				if text, ok := part.(genai.Text); ok {
					lastText = string(text)
					break
				}
			}

			iter := chat.SendMessageStream(ctx, genai.Text(lastText))

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			for {
				resp, err := iter.Next()
				if err == iterator.Done {
					break
				}
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}
							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// Start a fresh attempt rather than continuing with a failed iterator.
							continue attempt
						}
					}
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch p := part.(type) {
						case genai.Text:
							// Each streamed chunk carries new text; accumulate it and forward the delta.
							delta := string(p)
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case genai.FunctionCall:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(p.Args)
							newCall := message.ToolCall{
								ID:    id,
								Name:  p.Name,
								Input: string(args),
								Type:  "function",
							}

							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonUnknown
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}

			// No response was received; retry if attempts remain
			if attempts > maxRetries {
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
				}
				return
			}

			// Wait before retrying
			select {
			case <-ctx.Done():
				if ctx.Err() != nil {
					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
				}
				return
			case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
				continue
			}
		}
	}()

	return eventChan
}

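// shouldRetry decides whether an error from the Gemini API warrants another
// attempt. Only rate-limit style errors are retried. The returned delay is
// 2000ms * 2^(attempts-1) plus a flat 20% cushion, e.g. 2400ms after the
// first attempt, 4800ms after the second, 9600ms after the third.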
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// The Gemini SDK doesn't expose a typed rate-limit error,
	// so check the error message for common rate-limit indicators.
	errMsg := err.Error()
	isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")

	if !isRateLimit {
		return false, 0, err
	}

	// Exponential backoff with a flat 20% cushion (not randomized jitter)
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

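// toolCalls extracts function calls from a response and assigns each one a
// synthetic ID, since Gemini does not return tool-call identifiers of its own.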
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			if funcCall, ok := part.(genai.FunctionCall); ok {
				id := "call_" + uuid.New().String()
				args, _ := json.Marshal(funcCall.Args)
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    id,
					Name:  funcCall.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	return toolCalls
}

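// usage converts Gemini usage metadata into the provider-agnostic TokenUsage.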
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

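// WithGeminiDisableCache returns an option that disables caching for the
// Gemini client.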
func WithGeminiDisableCache() GeminiOption {
	return func(options *geminiOptions) {
		options.disableCache = true
	}
}

// Helper functions
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

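// convertSchemaProperties maps a JSON-Schema-style "properties" map onto
// genai.Schema values. As a rough sketch of the expected input shape (the
// concrete keys come from each tool's Info().Parameters):
//
//	{"path": {"type": "string", "description": "File to read"},
//	 "lines": {"type": "array", "items": {"type": "integer"}}}
//
// becomes a string schema and an array-of-integer schema, respectively.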
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

func convertToSchema(param interface{}) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]interface{})
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

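// processArrayItems converts the "items" schema of an array parameter,
// returning nil when no usable items definition is present.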
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
	items, ok := paramMap["items"].(map[string]interface{})
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

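// contains reports whether s contains any of the given substrings,
// case-insensitively.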
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}