1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "log"
8
9 "github.com/google/generative-ai-go/genai"
10 "github.com/google/uuid"
11 "github.com/kujtimiihoxha/termai/internal/llm/models"
12 "github.com/kujtimiihoxha/termai/internal/llm/tools"
13 "github.com/kujtimiihoxha/termai/internal/message"
14 "google.golang.org/api/googleapi"
15 "google.golang.org/api/iterator"
16 "google.golang.org/api/option"
17)
18
19type geminiProvider struct {
20 client *genai.Client
21 model models.Model
22 maxTokens int32
23 apiKey string
24 systemMessage string
25}
26
27type GeminiOption func(*geminiProvider)
28
29func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
30 provider := &geminiProvider{
31 maxTokens: 5000,
32 }
33
34 for _, opt := range opts {
35 opt(provider)
36 }
37
38 if provider.systemMessage == "" {
39 return nil, errors.New("system message is required")
40 }
41
42 client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
43 if err != nil {
44 return nil, err
45 }
46 provider.client = client
47
48 return provider, nil
49}
50
51func WithGeminiSystemMessage(message string) GeminiOption {
52 return func(p *geminiProvider) {
53 p.systemMessage = message
54 }
55}
56
57func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
58 return func(p *geminiProvider) {
59 p.maxTokens = maxTokens
60 }
61}
62
63func WithGeminiModel(model models.Model) GeminiOption {
64 return func(p *geminiProvider) {
65 p.model = model
66 }
67}
68
69func WithGeminiKey(apiKey string) GeminiOption {
70 return func(p *geminiProvider) {
71 p.apiKey = apiKey
72 }
73}
74
75func (p *geminiProvider) Close() {
76 if p.client != nil {
77 p.client.Close()
78 }
79}
80
81func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
82 var history []*genai.Content
83
84 for _, msg := range messages {
85 switch msg.Role {
86 case message.User:
87 history = append(history, &genai.Content{
88 Parts: []genai.Part{genai.Text(msg.Content().String())},
89 Role: "user",
90 })
91 case message.Assistant:
92 content := &genai.Content{
93 Role: "model",
94 Parts: []genai.Part{},
95 }
96
97 if msg.Content().String() != "" {
98 content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
99 }
100
101 if len(msg.ToolCalls()) > 0 {
102 for _, call := range msg.ToolCalls() {
103 args, _ := parseJsonToMap(call.Input)
104 content.Parts = append(content.Parts, genai.FunctionCall{
105 Name: call.Name,
106 Args: args,
107 })
108 }
109 }
110
111 history = append(history, content)
112 case message.Tool:
113 for _, result := range msg.ToolResults() {
114 response := map[string]interface{}{"result": result.Content}
115 parsed, err := parseJsonToMap(result.Content)
116 if err == nil {
117 response = parsed
118 }
119 var toolCall message.ToolCall
120 for _, msg := range messages {
121 if msg.Role == message.Assistant {
122 for _, call := range msg.ToolCalls() {
123 if call.ID == result.ToolCallID {
124 toolCall = call
125 break
126 }
127 }
128 }
129 }
130
131 history = append(history, &genai.Content{
132 Parts: []genai.Part{genai.FunctionResponse{
133 Name: toolCall.Name,
134 Response: response,
135 }},
136 Role: "function",
137 })
138 }
139 }
140 }
141
142 return history
143}
144
145func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
146 if resp == nil || resp.UsageMetadata == nil {
147 return TokenUsage{}
148 }
149
150 return TokenUsage{
151 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
152 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
153 CacheCreationTokens: 0, // Not directly provided by Gemini
154 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
155 }
156}
157
158func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
159 model := p.client.GenerativeModel(p.model.APIModel)
160 model.SetMaxOutputTokens(p.maxTokens)
161
162 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
163
164 if len(tools) > 0 {
165 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
166 for _, declaration := range declarations {
167 model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
168 }
169 }
170
171 chat := model.StartChat()
172 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
173
174 lastUserMsg := messages[len(messages)-1]
175 resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
176 if err != nil {
177 return nil, err
178 }
179
180 var content string
181 var toolCalls []message.ToolCall
182
183 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
184 for _, part := range resp.Candidates[0].Content.Parts {
185 switch p := part.(type) {
186 case genai.Text:
187 content = string(p)
188 case genai.FunctionCall:
189 id := "call_" + uuid.New().String()
190 args, _ := json.Marshal(p.Args)
191 toolCalls = append(toolCalls, message.ToolCall{
192 ID: id,
193 Name: p.Name,
194 Input: string(args),
195 Type: "function",
196 })
197 }
198 }
199 }
200
201 tokenUsage := p.extractTokenUsage(resp)
202
203 return &ProviderResponse{
204 Content: content,
205 ToolCalls: toolCalls,
206 Usage: tokenUsage,
207 }, nil
208}
209
210func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
211 model := p.client.GenerativeModel(p.model.APIModel)
212 model.SetMaxOutputTokens(p.maxTokens)
213
214 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
215
216 if len(tools) > 0 {
217 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
218 for _, declaration := range declarations {
219 model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
220 }
221 }
222
223 chat := model.StartChat()
224 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
225
226 lastUserMsg := messages[len(messages)-1]
227
228 iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
229
230 eventChan := make(chan ProviderEvent)
231
232 go func() {
233 defer close(eventChan)
234
235 var finalResp *genai.GenerateContentResponse
236 currentContent := ""
237 toolCalls := []message.ToolCall{}
238
239 for {
240 resp, err := iter.Next()
241 if err == iterator.Done {
242 break
243 }
244 if err != nil {
245 var apiErr *googleapi.Error
246 if errors.As(err, &apiErr) {
247 log.Printf("%s", apiErr.Body)
248 }
249 eventChan <- ProviderEvent{
250 Type: EventError,
251 Error: err,
252 }
253 return
254 }
255
256 finalResp = resp
257
258 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
259 for _, part := range resp.Candidates[0].Content.Parts {
260 switch p := part.(type) {
261 case genai.Text:
262 newText := string(p)
263 eventChan <- ProviderEvent{
264 Type: EventContentDelta,
265 Content: newText,
266 }
267 currentContent += newText
268 case genai.FunctionCall:
269 id := "call_" + uuid.New().String()
270 args, _ := json.Marshal(p.Args)
271 newCall := message.ToolCall{
272 ID: id,
273 Name: p.Name,
274 Input: string(args),
275 Type: "function",
276 }
277
278 isNew := true
279 for _, existing := range toolCalls {
280 if existing.Name == newCall.Name && existing.Input == newCall.Input {
281 isNew = false
282 break
283 }
284 }
285
286 if isNew {
287 toolCalls = append(toolCalls, newCall)
288 }
289 }
290 }
291 }
292 }
293
294 tokenUsage := p.extractTokenUsage(finalResp)
295
296 eventChan <- ProviderEvent{
297 Type: EventComplete,
298 Response: &ProviderResponse{
299 Content: currentContent,
300 ToolCalls: toolCalls,
301 Usage: tokenUsage,
302 FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
303 },
304 }
305 }()
306
307 return eventChan, nil
308}
309
310func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
311 declarations := make([]*genai.FunctionDeclaration, len(tools))
312
313 for i, tool := range tools {
314 info := tool.Info()
315 declarations[i] = &genai.FunctionDeclaration{
316 Name: info.Name,
317 Description: info.Description,
318 Parameters: &genai.Schema{
319 Type: genai.TypeObject,
320 Properties: convertSchemaProperties(info.Parameters),
321 Required: info.Required,
322 },
323 }
324 }
325
326 return declarations
327}
328
329func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
330 properties := make(map[string]*genai.Schema)
331
332 for name, param := range parameters {
333 properties[name] = convertToSchema(param)
334 }
335
336 return properties
337}
338
339func convertToSchema(param interface{}) *genai.Schema {
340 schema := &genai.Schema{Type: genai.TypeString}
341
342 paramMap, ok := param.(map[string]interface{})
343 if !ok {
344 return schema
345 }
346
347 if desc, ok := paramMap["description"].(string); ok {
348 schema.Description = desc
349 }
350
351 typeVal, hasType := paramMap["type"]
352 if !hasType {
353 return schema
354 }
355
356 typeStr, ok := typeVal.(string)
357 if !ok {
358 return schema
359 }
360
361 schema.Type = mapJSONTypeToGenAI(typeStr)
362
363 switch typeStr {
364 case "array":
365 schema.Items = processArrayItems(paramMap)
366 case "object":
367 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
368 schema.Properties = convertSchemaProperties(props)
369 }
370 }
371
372 return schema
373}
374
375func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
376 items, ok := paramMap["items"].(map[string]interface{})
377 if !ok {
378 return nil
379 }
380
381 return convertToSchema(items)
382}
383
384func mapJSONTypeToGenAI(jsonType string) genai.Type {
385 switch jsonType {
386 case "string":
387 return genai.TypeString
388 case "number":
389 return genai.TypeNumber
390 case "integer":
391 return genai.TypeInteger
392 case "boolean":
393 return genai.TypeBoolean
394 case "array":
395 return genai.TypeArray
396 case "object":
397 return genai.TypeObject
398 default:
399 return genai.TypeString // Default to string for unknown types
400 }
401}
402
// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
// Input that is not a JSON object (invalid JSON, arrays, scalars) returns a nil map and
// the decoding error. The literal JSON null returns a nil map and a nil error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}