1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7
8 "github.com/google/generative-ai-go/genai"
9 "github.com/google/uuid"
10 "github.com/kujtimiihoxha/termai/internal/llm/models"
11 "github.com/kujtimiihoxha/termai/internal/llm/tools"
12 "github.com/kujtimiihoxha/termai/internal/message"
13 "google.golang.org/api/iterator"
14 "google.golang.org/api/option"
15)
16
17type geminiProvider struct {
18 client *genai.Client
19 model models.Model
20 maxTokens int32
21 apiKey string
22 systemMessage string
23}
24
25type GeminiOption func(*geminiProvider)
26
27func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
28 provider := &geminiProvider{
29 maxTokens: 5000,
30 }
31
32 for _, opt := range opts {
33 opt(provider)
34 }
35
36 if provider.systemMessage == "" {
37 return nil, errors.New("system message is required")
38 }
39
40 client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
41 if err != nil {
42 return nil, err
43 }
44 provider.client = client
45
46 return provider, nil
47}
48
49func WithGeminiSystemMessage(message string) GeminiOption {
50 return func(p *geminiProvider) {
51 p.systemMessage = message
52 }
53}
54
55func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
56 return func(p *geminiProvider) {
57 p.maxTokens = maxTokens
58 }
59}
60
61func WithGeminiModel(model models.Model) GeminiOption {
62 return func(p *geminiProvider) {
63 p.model = model
64 }
65}
66
67func WithGeminiKey(apiKey string) GeminiOption {
68 return func(p *geminiProvider) {
69 p.apiKey = apiKey
70 }
71}
72
73func (p *geminiProvider) Close() {
74 if p.client != nil {
75 p.client.Close()
76 }
77}
78
79func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
80 var history []*genai.Content
81
82 for _, msg := range messages {
83 switch msg.Role {
84 case message.User:
85 history = append(history, &genai.Content{
86 Parts: []genai.Part{genai.Text(msg.Content().String())},
87 Role: "user",
88 })
89 case message.Assistant:
90 content := &genai.Content{
91 Role: "model",
92 Parts: []genai.Part{},
93 }
94
95 if msg.Content().String() != "" {
96 content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
97 }
98
99 if len(msg.ToolCalls()) > 0 {
100 for _, call := range msg.ToolCalls() {
101 args, _ := parseJsonToMap(call.Input)
102 content.Parts = append(content.Parts, genai.FunctionCall{
103 Name: call.Name,
104 Args: args,
105 })
106 }
107 }
108
109 history = append(history, content)
110 case message.Tool:
111 for _, result := range msg.ToolResults() {
112 response := map[string]interface{}{"result": result.Content}
113 parsed, err := parseJsonToMap(result.Content)
114 if err == nil {
115 response = parsed
116 }
117 var toolCall message.ToolCall
118 for _, msg := range messages {
119 if msg.Role == message.Assistant {
120 for _, call := range msg.ToolCalls() {
121 if call.ID == result.ToolCallID {
122 toolCall = call
123 break
124 }
125 }
126 }
127 }
128
129 history = append(history, &genai.Content{
130 Parts: []genai.Part{genai.FunctionResponse{
131 Name: toolCall.Name,
132 Response: response,
133 }},
134 Role: "function",
135 })
136 }
137 }
138 }
139
140 return history
141}
142
143func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
144 if resp == nil || resp.UsageMetadata == nil {
145 return TokenUsage{}
146 }
147
148 return TokenUsage{
149 InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
150 OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
151 CacheCreationTokens: 0, // Not directly provided by Gemini
152 CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
153 }
154}
155
156func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
157 model := p.client.GenerativeModel(p.model.APIModel)
158 model.SetMaxOutputTokens(p.maxTokens)
159
160 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
161
162 if len(tools) > 0 {
163 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
164 for _, declaration := range declarations {
165 model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
166 }
167 }
168
169 chat := model.StartChat()
170 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
171
172 lastUserMsg := messages[len(messages)-1]
173 resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
174 if err != nil {
175 return nil, err
176 }
177
178 var content string
179 var toolCalls []message.ToolCall
180
181 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
182 for _, part := range resp.Candidates[0].Content.Parts {
183 switch p := part.(type) {
184 case genai.Text:
185 content = string(p)
186 case genai.FunctionCall:
187 id := "call_" + uuid.New().String()
188 args, _ := json.Marshal(p.Args)
189 toolCalls = append(toolCalls, message.ToolCall{
190 ID: id,
191 Name: p.Name,
192 Input: string(args),
193 Type: "function",
194 })
195 }
196 }
197 }
198
199 tokenUsage := p.extractTokenUsage(resp)
200
201 return &ProviderResponse{
202 Content: content,
203 ToolCalls: toolCalls,
204 Usage: tokenUsage,
205 }, nil
206}
207
208func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
209 model := p.client.GenerativeModel(p.model.APIModel)
210 model.SetMaxOutputTokens(p.maxTokens)
211
212 model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
213
214 if len(tools) > 0 {
215 declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
216 for _, declaration := range declarations {
217 model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
218 }
219 }
220
221 chat := model.StartChat()
222 chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
223
224 lastUserMsg := messages[len(messages)-1]
225
226 iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
227
228 eventChan := make(chan ProviderEvent)
229
230 go func() {
231 defer close(eventChan)
232
233 var finalResp *genai.GenerateContentResponse
234 currentContent := ""
235 toolCalls := []message.ToolCall{}
236
237 for {
238 resp, err := iter.Next()
239 if err == iterator.Done {
240 break
241 }
242 if err != nil {
243 eventChan <- ProviderEvent{
244 Type: EventError,
245 Error: err,
246 }
247 return
248 }
249
250 finalResp = resp
251
252 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
253 for _, part := range resp.Candidates[0].Content.Parts {
254 switch p := part.(type) {
255 case genai.Text:
256 newText := string(p)
257 eventChan <- ProviderEvent{
258 Type: EventContentDelta,
259 Content: newText,
260 }
261 currentContent += newText
262 case genai.FunctionCall:
263 id := "call_" + uuid.New().String()
264 args, _ := json.Marshal(p.Args)
265 newCall := message.ToolCall{
266 ID: id,
267 Name: p.Name,
268 Input: string(args),
269 Type: "function",
270 }
271
272 isNew := true
273 for _, existing := range toolCalls {
274 if existing.Name == newCall.Name && existing.Input == newCall.Input {
275 isNew = false
276 break
277 }
278 }
279
280 if isNew {
281 toolCalls = append(toolCalls, newCall)
282 }
283 }
284 }
285 }
286 }
287
288 tokenUsage := p.extractTokenUsage(finalResp)
289
290 eventChan <- ProviderEvent{
291 Type: EventComplete,
292 Response: &ProviderResponse{
293 Content: currentContent,
294 ToolCalls: toolCalls,
295 Usage: tokenUsage,
296 FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
297 },
298 }
299 }()
300
301 return eventChan, nil
302}
303
304func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
305 declarations := make([]*genai.FunctionDeclaration, len(tools))
306
307 for i, tool := range tools {
308 info := tool.Info()
309 declarations[i] = &genai.FunctionDeclaration{
310 Name: info.Name,
311 Description: info.Description,
312 Parameters: &genai.Schema{
313 Type: genai.TypeObject,
314 Properties: convertSchemaProperties(info.Parameters),
315 Required: info.Required,
316 },
317 }
318 }
319
320 return declarations
321}
322
323func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
324 properties := make(map[string]*genai.Schema)
325
326 for name, param := range parameters {
327 properties[name] = convertToSchema(param)
328 }
329
330 return properties
331}
332
333func convertToSchema(param interface{}) *genai.Schema {
334 schema := &genai.Schema{Type: genai.TypeString}
335
336 paramMap, ok := param.(map[string]interface{})
337 if !ok {
338 return schema
339 }
340
341 if desc, ok := paramMap["description"].(string); ok {
342 schema.Description = desc
343 }
344
345 typeVal, hasType := paramMap["type"]
346 if !hasType {
347 return schema
348 }
349
350 typeStr, ok := typeVal.(string)
351 if !ok {
352 return schema
353 }
354
355 schema.Type = mapJSONTypeToGenAI(typeStr)
356
357 switch typeStr {
358 case "array":
359 schema.Items = processArrayItems(paramMap)
360 case "object":
361 if props, ok := paramMap["properties"].(map[string]interface{}); ok {
362 schema.Properties = convertSchemaProperties(props)
363 }
364 }
365
366 return schema
367}
368
369func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
370 items, ok := paramMap["items"].(map[string]interface{})
371 if !ok {
372 return nil
373 }
374
375 return convertToSchema(items)
376}
377
378func mapJSONTypeToGenAI(jsonType string) genai.Type {
379 switch jsonType {
380 case "string":
381 return genai.TypeString
382 case "number":
383 return genai.TypeNumber
384 case "integer":
385 return genai.TypeInteger
386 case "boolean":
387 return genai.TypeBoolean
388 case "array":
389 return genai.TypeArray
390 case "object":
391 return genai.TypeObject
392 default:
393 return genai.TypeString // Default to string for unknown types
394 }
395}
396
// parseJsonToMap decodes jsonStr as a JSON object into a generic map. It
// returns the error from json.Unmarshal when the input is not valid JSON or
// its top-level value cannot be stored in a map; the literal null decodes to
// a nil map with no error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}