1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "log"
8
9 "github.com/google/generative-ai-go/genai"
10 "github.com/google/uuid"
11 "github.com/kujtimiihoxha/termai/internal/llm/models"
12 "github.com/kujtimiihoxha/termai/internal/llm/tools"
13 "github.com/kujtimiihoxha/termai/internal/message"
14 "google.golang.org/api/googleapi"
15 "google.golang.org/api/iterator"
16 "google.golang.org/api/option"
17)
18
// geminiProvider is the Provider implementation that talks to Google's Gemini
// models through the genai client. It is built by NewGeminiProvider and
// configured with GeminiOption values.
type geminiProvider struct {
	// client is the genai API client created in NewGeminiProvider and released by Close.
	client *genai.Client
	// model selects which Gemini model is used; its APIModel field is passed to the client.
	model models.Model
	// maxTokens caps the number of output tokens per request (NewGeminiProvider defaults it to 5000).
	maxTokens int32
	// apiKey is the credential handed to the genai client on construction.
	apiKey string
	// systemMessage is sent as the system instruction on every request; it must be non-empty.
	systemMessage string
}
26
// GeminiOption is a functional option that mutates a geminiProvider.
// NewGeminiProvider applies the options in the order given, before the
// underlying genai client is created.
type GeminiOption func(*geminiProvider)
28
29func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
30 provider := &geminiProvider{
31 maxTokens: 5000,
32 }
33
34 for _, opt := range opts {
35 opt(provider)
36 }
37
38 if provider.systemMessage == "" {
39 return nil, errors.New("system message is required")
40 }
41
42 client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
43 if err != nil {
44 return nil, err
45 }
46 provider.client = client
47
48 return provider, nil
49}
50
51func WithGeminiSystemMessage(message string) GeminiOption {
52 return func(p *geminiProvider) {
53 p.systemMessage = message
54 }
55}
56
57func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
58 return func(p *geminiProvider) {
59 p.maxTokens = maxTokens
60 }
61}
62
63func WithGeminiModel(model models.Model) GeminiOption {
64 return func(p *geminiProvider) {
65 p.model = model
66 }
67}
68
69func WithGeminiKey(apiKey string) GeminiOption {
70 return func(p *geminiProvider) {
71 p.apiKey = apiKey
72 }
73}
74
75func (p *geminiProvider) Close() {
76 if p.client != nil {
77 p.client.Close()
78 }
79}
80
81// convertToGeminiHistory converts the message history to Gemini's format
82func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
83 var history []*genai.Content
84
85 for _, msg := range messages {
86 switch msg.Role {
87 case message.User:
88 history = append(history, &genai.Content{
89 Parts: []genai.Part{genai.Text(msg.Content)},
90 Role: "user",
91 })
92 case message.Assistant:
93 content := &genai.Content{
94 Role: "model",
95 Parts: []genai.Part{},
96 }
97
98 // Handle regular content
99 if msg.Content != "" {
100 content.Parts = append(content.Parts, genai.Text(msg.Content))
101 }
102
103 // Handle tool calls if any
104 if len(msg.ToolCalls) > 0 {
105 for _, call := range msg.ToolCalls {
106 args, _ := parseJsonToMap(call.Input)
107 content.Parts = append(content.Parts, genai.FunctionCall{
108 Name: call.Name,
109 Args: args,
110 })
111 }
112 }
113
114 history = append(history, content)
115 case message.Tool:
116 for _, result := range msg.ToolResults {
117 // Parse response content to map if possible
118 response := map[string]interface{}{"result": result.Content}
119 parsed, err := parseJsonToMap(result.Content)
120 if err == nil {
121 response = parsed
122 }
123 var toolCall message.ToolCall
124 for _, msg := range messages {
125 if msg.Role == message.Assistant {
126 for _, call := range msg.ToolCalls {
127 if call.ID == result.ToolCallID {
128 toolCall = call
129 break
130 }
131 }
132 }
133 }
134
135 history = append(history, &genai.Content{
136 Parts: []genai.Part{genai.FunctionResponse{
137 Name: toolCall.Name,
138 Response: response,
139 }},
140 Role: "function",
141 })
142 }
143 }
144 }
145
146 return history
147}
148
149// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
150func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
151 declarations := make([]*genai.FunctionDeclaration, len(tools))
152
153 for i, tool := range tools {
154 info := tool.Info()
155
156 // Convert parameters to genai.Schema format
157 properties := make(map[string]*genai.Schema)
158 for name, param := range info.Parameters {
159 // Try to extract type and description from the parameter
160 paramMap, ok := param.(map[string]interface{})
161 if !ok {
162 // Default to string if unable to determine type
163 properties[name] = &genai.Schema{Type: genai.TypeString}
164 continue
165 }
166
167 schemaType := genai.TypeString // Default
168 var description string
169 var itemsTypeSchema *genai.Schema
170 if typeVal, found := paramMap["type"]; found {
171 if typeStr, ok := typeVal.(string); ok {
172 switch typeStr {
173 case "string":
174 schemaType = genai.TypeString
175 case "number":
176 schemaType = genai.TypeNumber
177 case "integer":
178 schemaType = genai.TypeInteger
179 case "boolean":
180 schemaType = genai.TypeBoolean
181 case "array":
182 schemaType = genai.TypeArray
183 items, found := paramMap["items"]
184 if found {
185 itemsMap, ok := items.(map[string]interface{})
186 if ok {
187 itemsType, found := itemsMap["type"]
188 if found {
189 itemsTypeStr, ok := itemsType.(string)
190 if ok {
191 switch itemsTypeStr {
192 case "string":
193 itemsTypeSchema = &genai.Schema{
194 Type: genai.TypeString,
195 }
196 case "number":
197 itemsTypeSchema = &genai.Schema{
198 Type: genai.TypeNumber,
199 }
200 case "integer":
201 itemsTypeSchema = &genai.Schema{
202 Type: genai.TypeInteger,
203 }
204 case "boolean":
205 itemsTypeSchema = &genai.Schema{
206 Type: genai.TypeBoolean,
207 }
208 }
209 }
210 }
211 }
212 }
213 case "object":
214 schemaType = genai.TypeObject
215 if _, found := paramMap["properties"]; !found {
216 continue
217 }
218 // TODO: Add support for other types
219 }
220 }
221 }
222
223 if desc, found := paramMap["description"]; found {
224 if descStr, ok := desc.(string); ok {
225 description = descStr
226 }
227 }
228
229 properties[name] = &genai.Schema{
230 Type: schemaType,
231 Description: description,
232 Items: itemsTypeSchema,
233 }
234 }
235
236 declarations[i] = &genai.FunctionDeclaration{
237 Name: info.Name,
238 Description: info.Description,
239 Parameters: &genai.Schema{
240 Type: genai.TypeObject,
241 Properties: properties,
242 Required: info.Required,
243 },
244 }
245 }
246
247 return declarations
248}
249
250// extractTokenUsage extracts token usage information from Gemini's response
251func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
252 if resp == nil || resp.UsageMetadata == nil {
253 return TokenUsage{}
254 }
255
256 return TokenUsage{
257 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
258 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
259 CacheCreationTokens: 0, // Not directly provided by Gemini
260 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
261 }
262}
263
264// SendMessages sends a batch of messages to Gemini and returns the response
265func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
266 // Create a generative model
267 model := p.client.GenerativeModel(p.model.APIModel)
268 model.SetMaxOutputTokens(p.maxTokens)
269
270 // Set system instruction
271 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
272
273 // Set up tools if provided
274 if len(tools) > 0 {
275 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
276 model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
277 }
278
279 // Create chat session and set history
280 chat := model.StartChat()
281 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
282
283 // Get the most recent user message
284 var lastUserMsg message.Message
285 for i := len(messages) - 1; i >= 0; i-- {
286 if messages[i].Role == message.User {
287 lastUserMsg = messages[i]
288 break
289 }
290 }
291
292 // Send the message
293 resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
294 if err != nil {
295 return nil, err
296 }
297
298 // Process the response
299 var content string
300 var toolCalls []message.ToolCall
301
302 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
303 for _, part := range resp.Candidates[0].Content.Parts {
304 switch p := part.(type) {
305 case genai.Text:
306 content = string(p)
307 case genai.FunctionCall:
308 id := "call_" + uuid.New().String()
309 args, _ := json.Marshal(p.Args)
310 toolCalls = append(toolCalls, message.ToolCall{
311 ID: id,
312 Name: p.Name,
313 Input: string(args),
314 Type: "function",
315 })
316 }
317 }
318 }
319
320 // Extract token usage
321 tokenUsage := p.extractTokenUsage(resp)
322
323 return &ProviderResponse{
324 Content: content,
325 ToolCalls: toolCalls,
326 Usage: tokenUsage,
327 }, nil
328}
329
330// StreamResponse streams the response from Gemini
331func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
332 // Create a generative model
333 model := p.client.GenerativeModel(p.model.APIModel)
334 model.SetMaxOutputTokens(p.maxTokens)
335
336 // Set system instruction
337 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
338
339 // Set up tools if provided
340 if len(tools) > 0 {
341 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
342 for _, declaration := range declarations {
343 model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
344 }
345 }
346
347 // Create chat session and set history
348 chat := model.StartChat()
349 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
350
351 lastUserMsg := messages[len(messages)-1]
352
353 // Start streaming
354 iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))
355
356 eventChan := make(chan ProviderEvent)
357
358 go func() {
359 defer close(eventChan)
360
361 var finalResp *genai.GenerateContentResponse
362 currentContent := ""
363 toolCalls := []message.ToolCall{}
364
365 for {
366 resp, err := iter.Next()
367 if err == iterator.Done {
368 break
369 }
370 if err != nil {
371 var apiErr *googleapi.Error
372 if errors.As(err, &apiErr) {
373 log.Printf("%s", apiErr.Body)
374 }
375 eventChan <- ProviderEvent{
376 Type: EventError,
377 Error: err,
378 }
379 return
380 }
381
382 finalResp = resp
383
384 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
385 for _, part := range resp.Candidates[0].Content.Parts {
386 switch p := part.(type) {
387 case genai.Text:
388 newText := string(p)
389 eventChan <- ProviderEvent{
390 Type: EventContentDelta,
391 Content: newText,
392 }
393 currentContent += newText
394 case genai.FunctionCall:
395 // For function calls, we assume they come complete, not streamed in parts
396 id := "call_" + uuid.New().String()
397 args, _ := json.Marshal(p.Args)
398 newCall := message.ToolCall{
399 ID: id,
400 Name: p.Name,
401 Input: string(args),
402 Type: "function",
403 }
404
405 // Check if this is a new tool call
406 isNew := true
407 for _, existing := range toolCalls {
408 if existing.Name == newCall.Name && existing.Input == newCall.Input {
409 isNew = false
410 break
411 }
412 }
413
414 if isNew {
415 toolCalls = append(toolCalls, newCall)
416 }
417 }
418 }
419 }
420 }
421
422 // Extract token usage from the final response
423 tokenUsage := p.extractTokenUsage(finalResp)
424
425 eventChan <- ProviderEvent{
426 Type: EventComplete,
427 Response: &ProviderResponse{
428 Content: currentContent,
429 ToolCalls: toolCalls,
430 Usage: tokenUsage,
431 },
432 }
433 }()
434
435 return eventChan, nil
436}
437
// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
// It returns the decoding error when jsonStr is not valid JSON or its
// top-level value is not an object; the JSON literal null decodes to a nil
// map with no error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var fields map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &fields); err != nil {
		return fields, err
	}
	return fields, nil
}