1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7
8 "github.com/google/generative-ai-go/genai"
9 "github.com/google/uuid"
10 "github.com/kujtimiihoxha/termai/internal/llm/models"
11 "github.com/kujtimiihoxha/termai/internal/llm/tools"
12 "github.com/kujtimiihoxha/termai/internal/message"
13 "google.golang.org/api/iterator"
14 "google.golang.org/api/option"
15)
16
17type geminiProvider struct {
18 client *genai.Client
19 model models.Model
20 maxTokens int32
21 apiKey string
22 systemMessage string
23}
24
25type GeminiOption func(*geminiProvider)
26
27func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
28 provider := &geminiProvider{
29 maxTokens: 5000,
30 }
31
32 for _, opt := range opts {
33 opt(provider)
34 }
35
36 if provider.systemMessage == "" {
37 return nil, errors.New("system message is required")
38 }
39
40 client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
41 if err != nil {
42 return nil, err
43 }
44 provider.client = client
45
46 return provider, nil
47}
48
49func WithGeminiSystemMessage(message string) GeminiOption {
50 return func(p *geminiProvider) {
51 p.systemMessage = message
52 }
53}
54
55func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
56 return func(p *geminiProvider) {
57 p.maxTokens = maxTokens
58 }
59}
60
61func WithGeminiModel(model models.Model) GeminiOption {
62 return func(p *geminiProvider) {
63 p.model = model
64 }
65}
66
67func WithGeminiKey(apiKey string) GeminiOption {
68 return func(p *geminiProvider) {
69 p.apiKey = apiKey
70 }
71}
72
73func (p *geminiProvider) Close() {
74 if p.client != nil {
75 p.client.Close()
76 }
77}
78
79func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
80 var history []*genai.Content
81
82 for _, msg := range messages {
83 switch msg.Role {
84 case message.User:
85 history = append(history, &genai.Content{
86 Parts: []genai.Part{genai.Text(msg.Content().String())},
87 Role: "user",
88 })
89 case message.Assistant:
90 content := &genai.Content{
91 Role: "model",
92 Parts: []genai.Part{},
93 }
94
95 if msg.Content().String() != "" {
96 content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
97 }
98
99 if len(msg.ToolCalls()) > 0 {
100 for _, call := range msg.ToolCalls() {
101 args, _ := parseJsonToMap(call.Input)
102 content.Parts = append(content.Parts, genai.FunctionCall{
103 Name: call.Name,
104 Args: args,
105 })
106 }
107 }
108
109 history = append(history, content)
110 case message.Tool:
111 for _, result := range msg.ToolResults() {
112 response := map[string]interface{}{"result": result.Content}
113 parsed, err := parseJsonToMap(result.Content)
114 if err == nil {
115 response = parsed
116 }
117 var toolCall message.ToolCall
118 for _, msg := range messages {
119 if msg.Role == message.Assistant {
120 for _, call := range msg.ToolCalls() {
121 if call.ID == result.ToolCallID {
122 toolCall = call
123 break
124 }
125 }
126 }
127 }
128
129 history = append(history, &genai.Content{
130 Parts: []genai.Part{genai.FunctionResponse{
131 Name: toolCall.Name,
132 Response: response,
133 }},
134 Role: "function",
135 })
136 }
137 }
138 }
139
140 return history
141}
142
143func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
144 if resp == nil || resp.UsageMetadata == nil {
145 return TokenUsage{}
146 }
147
148 return TokenUsage{
149 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
150 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
151 CacheCreationTokens: 0, // Not directly provided by Gemini
152 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
153 }
154}
155
156func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
157 messages = cleanupMessages(messages)
158 model := p.client.GenerativeModel(p.model.APIModel)
159 model.SetMaxOutputTokens(p.maxTokens)
160
161 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
162
163 if len(tools) > 0 {
164 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
165 for _, declaration := range declarations {
166 model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
167 }
168 }
169
170 chat := model.StartChat()
171 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
172
173 lastUserMsg := messages[len(messages)-1]
174 resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
175 if err != nil {
176 return nil, err
177 }
178
179 var content string
180 var toolCalls []message.ToolCall
181
182 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
183 for _, part := range resp.Candidates[0].Content.Parts {
184 switch p := part.(type) {
185 case genai.Text:
186 content = string(p)
187 case genai.FunctionCall:
188 id := "call_" + uuid.New().String()
189 args, _ := json.Marshal(p.Args)
190 toolCalls = append(toolCalls, message.ToolCall{
191 ID: id,
192 Name: p.Name,
193 Input: string(args),
194 Type: "function",
195 })
196 }
197 }
198 }
199
200 tokenUsage := p.extractTokenUsage(resp)
201
202 return &ProviderResponse{
203 Content: content,
204 ToolCalls: toolCalls,
205 Usage: tokenUsage,
206 }, nil
207}
208
209func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
210 messages = cleanupMessages(messages)
211 model := p.client.GenerativeModel(p.model.APIModel)
212 model.SetMaxOutputTokens(p.maxTokens)
213
214 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
215
216 if len(tools) > 0 {
217 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
218 for _, declaration := range declarations {
219 model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
220 }
221 }
222
223 chat := model.StartChat()
224 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
225
226 lastUserMsg := messages[len(messages)-1]
227
228 iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
229
230 eventChan := make(chan ProviderEvent)
231
232 go func() {
233 defer close(eventChan)
234
235 var finalResp *genai.GenerateContentResponse
236 currentContent := ""
237 toolCalls := []message.ToolCall{}
238
239 for {
240 resp, err := iter.Next()
241 if err == iterator.Done {
242 break
243 }
244 if err != nil {
245 eventChan <- ProviderEvent{
246 Type: EventError,
247 Error: err,
248 }
249 return
250 }
251
252 finalResp = resp
253
254 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
255 for _, part := range resp.Candidates[0].Content.Parts {
256 switch p := part.(type) {
257 case genai.Text:
258 newText := string(p)
259 eventChan <- ProviderEvent{
260 Type: EventContentDelta,
261 Content: newText,
262 }
263 currentContent += newText
264 case genai.FunctionCall:
265 id := "call_" + uuid.New().String()
266 args, _ := json.Marshal(p.Args)
267 newCall := message.ToolCall{
268 ID: id,
269 Name: p.Name,
270 Input: string(args),
271 Type: "function",
272 }
273
274 isNew := true
275 for _, existing := range toolCalls {
276 if existing.Name == newCall.Name && existing.Input == newCall.Input {
277 isNew = false
278 break
279 }
280 }
281
282 if isNew {
283 toolCalls = append(toolCalls, newCall)
284 }
285 }
286 }
287 }
288 }
289
290 tokenUsage := p.extractTokenUsage(finalResp)
291
292 eventChan <- ProviderEvent{
293 Type: EventComplete,
294 Response: &ProviderResponse{
295 Content: currentContent,
296 ToolCalls: toolCalls,
297 Usage: tokenUsage,
298 FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
299 },
300 }
301 }()
302
303 return eventChan, nil
304}
305
306func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
307 declarations := make([]*genai.FunctionDeclaration, len(tools))
308
309 for i, tool := range tools {
310 info := tool.Info()
311 declarations[i] = &genai.FunctionDeclaration{
312 Name: info.Name,
313 Description: info.Description,
314 Parameters: &genai.Schema{
315 Type: genai.TypeObject,
316 Properties: convertSchemaProperties(info.Parameters),
317 Required: info.Required,
318 },
319 }
320 }
321
322 return declarations
323}
324
325func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
326 properties := make(map[string]*genai.Schema)
327
328 for name, param := range parameters {
329 properties[name] = convertToSchema(param)
330 }
331
332 return properties
333}
334
335func convertToSchema(param interface{}) *genai.Schema {
336 schema := &genai.Schema{Type: genai.TypeString}
337
338 paramMap, ok := param.(map[string]interface{})
339 if !ok {
340 return schema
341 }
342
343 if desc, ok := paramMap["description"].(string); ok {
344 schema.Description = desc
345 }
346
347 typeVal, hasType := paramMap["type"]
348 if !hasType {
349 return schema
350 }
351
352 typeStr, ok := typeVal.(string)
353 if !ok {
354 return schema
355 }
356
357 schema.Type = mapJSONTypeToGenAI(typeStr)
358
359 switch typeStr {
360 case "array":
361 schema.Items = processArrayItems(paramMap)
362 case "object":
363 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
364 schema.Properties = convertSchemaProperties(props)
365 }
366 }
367
368 return schema
369}
370
371func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
372 items, ok := paramMap["items"].(map[string]interface{})
373 if !ok {
374 return nil
375 }
376
377 return convertToSchema(items)
378}
379
380func mapJSONTypeToGenAI(jsonType string) genai.Type {
381 switch jsonType {
382 case "string":
383 return genai.TypeString
384 case "number":
385 return genai.TypeNumber
386 case "integer":
387 return genai.TypeInteger
388 case "boolean":
389 return genai.TypeBoolean
390 case "array":
391 return genai.TypeArray
392 case "object":
393 return genai.TypeObject
394 default:
395 return genai.TypeString // Default to string for unknown types
396 }
397}
398
// parseJsonToMap decodes jsonStr into a generic map. It returns the decoded
// map together with whatever error json.Unmarshal reported.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	decoded := map[string]interface{}(nil)
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}