gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"log"
  8
  9	"github.com/google/generative-ai-go/genai"
 10	"github.com/google/uuid"
 11	"github.com/kujtimiihoxha/termai/internal/llm/models"
 12	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 13	"github.com/kujtimiihoxha/termai/internal/message"
 14	"google.golang.org/api/googleapi"
 15	"google.golang.org/api/iterator"
 16	"google.golang.org/api/option"
 17)
 18
 19type geminiProvider struct {
 20	client        *genai.Client
 21	model         models.Model
 22	maxTokens     int32
 23	apiKey        string
 24	systemMessage string
 25}
 26
 27type GeminiOption func(*geminiProvider)
 28
 29func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
 30	provider := &geminiProvider{
 31		maxTokens: 5000,
 32	}
 33
 34	for _, opt := range opts {
 35		opt(provider)
 36	}
 37
 38	if provider.systemMessage == "" {
 39		return nil, errors.New("system message is required")
 40	}
 41
 42	client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
 43	if err != nil {
 44		return nil, err
 45	}
 46	provider.client = client
 47
 48	return provider, nil
 49}
 50
 51func WithGeminiSystemMessage(message string) GeminiOption {
 52	return func(p *geminiProvider) {
 53		p.systemMessage = message
 54	}
 55}
 56
 57func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
 58	return func(p *geminiProvider) {
 59		p.maxTokens = maxTokens
 60	}
 61}
 62
 63func WithGeminiModel(model models.Model) GeminiOption {
 64	return func(p *geminiProvider) {
 65		p.model = model
 66	}
 67}
 68
 69func WithGeminiKey(apiKey string) GeminiOption {
 70	return func(p *geminiProvider) {
 71		p.apiKey = apiKey
 72	}
 73}
 74
 75func (p *geminiProvider) Close() {
 76	if p.client != nil {
 77		p.client.Close()
 78	}
 79}
 80
 81// convertToGeminiHistory converts the message history to Gemini's format
 82func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
 83	var history []*genai.Content
 84
 85	for _, msg := range messages {
 86		switch msg.Role {
 87		case message.User:
 88			history = append(history, &genai.Content{
 89				Parts: []genai.Part{genai.Text(msg.Content)},
 90				Role:  "user",
 91			})
 92		case message.Assistant:
 93			content := &genai.Content{
 94				Role:  "model",
 95				Parts: []genai.Part{},
 96			}
 97
 98			// Handle regular content
 99			if msg.Content != "" {
100				content.Parts = append(content.Parts, genai.Text(msg.Content))
101			}
102
103			// Handle tool calls if any
104			if len(msg.ToolCalls) > 0 {
105				for _, call := range msg.ToolCalls {
106					args, _ := parseJsonToMap(call.Input)
107					content.Parts = append(content.Parts, genai.FunctionCall{
108						Name: call.Name,
109						Args: args,
110					})
111				}
112			}
113
114			history = append(history, content)
115		case message.Tool:
116			for _, result := range msg.ToolResults {
117				// Parse response content to map if possible
118				response := map[string]interface{}{"result": result.Content}
119				parsed, err := parseJsonToMap(result.Content)
120				if err == nil {
121					response = parsed
122				}
123				var toolCall message.ToolCall
124				for _, msg := range messages {
125					if msg.Role == message.Assistant {
126						for _, call := range msg.ToolCalls {
127							if call.ID == result.ToolCallID {
128								toolCall = call
129								break
130							}
131						}
132					}
133				}
134
135				history = append(history, &genai.Content{
136					Parts: []genai.Part{genai.FunctionResponse{
137						Name:     toolCall.Name,
138						Response: response,
139					}},
140					Role: "function",
141				})
142			}
143		}
144	}
145
146	return history
147}
148
149// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
150func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
151	declarations := make([]*genai.FunctionDeclaration, len(tools))
152
153	for i, tool := range tools {
154		info := tool.Info()
155
156		// Convert parameters to genai.Schema format
157		properties := make(map[string]*genai.Schema)
158		for name, param := range info.Parameters {
159			// Try to extract type and description from the parameter
160			paramMap, ok := param.(map[string]interface{})
161			if !ok {
162				// Default to string if unable to determine type
163				properties[name] = &genai.Schema{Type: genai.TypeString}
164				continue
165			}
166
167			schemaType := genai.TypeString // Default
168			var description string
169			var itemsTypeSchema *genai.Schema
170			if typeVal, found := paramMap["type"]; found {
171				if typeStr, ok := typeVal.(string); ok {
172					switch typeStr {
173					case "string":
174						schemaType = genai.TypeString
175					case "number":
176						schemaType = genai.TypeNumber
177					case "integer":
178						schemaType = genai.TypeInteger
179					case "boolean":
180						schemaType = genai.TypeBoolean
181					case "array":
182						schemaType = genai.TypeArray
183						items, found := paramMap["items"]
184						if found {
185							itemsMap, ok := items.(map[string]interface{})
186							if ok {
187								itemsType, found := itemsMap["type"]
188								if found {
189									itemsTypeStr, ok := itemsType.(string)
190									if ok {
191										switch itemsTypeStr {
192										case "string":
193											itemsTypeSchema = &genai.Schema{
194												Type: genai.TypeString,
195											}
196										case "number":
197											itemsTypeSchema = &genai.Schema{
198												Type: genai.TypeNumber,
199											}
200										case "integer":
201											itemsTypeSchema = &genai.Schema{
202												Type: genai.TypeInteger,
203											}
204										case "boolean":
205											itemsTypeSchema = &genai.Schema{
206												Type: genai.TypeBoolean,
207											}
208										}
209									}
210								}
211							}
212						}
213					case "object":
214						schemaType = genai.TypeObject
215						if _, found := paramMap["properties"]; !found {
216							continue
217						}
218						// TODO: Add support for other types
219					}
220				}
221			}
222
223			if desc, found := paramMap["description"]; found {
224				if descStr, ok := desc.(string); ok {
225					description = descStr
226				}
227			}
228
229			properties[name] = &genai.Schema{
230				Type:        schemaType,
231				Description: description,
232				Items:       itemsTypeSchema,
233			}
234		}
235
236		declarations[i] = &genai.FunctionDeclaration{
237			Name:        info.Name,
238			Description: info.Description,
239			Parameters: &genai.Schema{
240				Type:       genai.TypeObject,
241				Properties: properties,
242				Required:   info.Required,
243			},
244		}
245	}
246
247	return declarations
248}
249
250// extractTokenUsage extracts token usage information from Gemini's response
251func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
252	if resp == nil || resp.UsageMetadata == nil {
253		return TokenUsage{}
254	}
255
256	return TokenUsage{
257		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
258		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
259		CacheCreationTokens: 0, // Not directly provided by Gemini
260		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
261	}
262}
263
264// SendMessages sends a batch of messages to Gemini and returns the response
265func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
266	// Create a generative model
267	model := p.client.GenerativeModel(p.model.APIModel)
268	model.SetMaxOutputTokens(p.maxTokens)
269
270	// Set system instruction
271	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
272
273	// Set up tools if provided
274	if len(tools) > 0 {
275		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
276		model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
277	}
278
279	// Create chat session and set history
280	chat := model.StartChat()
281	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
282
283	// Get the most recent user message
284	var lastUserMsg message.Message
285	for i := len(messages) - 1; i >= 0; i-- {
286		if messages[i].Role == message.User {
287			lastUserMsg = messages[i]
288			break
289		}
290	}
291
292	// Send the message
293	resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
294	if err != nil {
295		return nil, err
296	}
297
298	// Process the response
299	var content string
300	var toolCalls []message.ToolCall
301
302	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
303		for _, part := range resp.Candidates[0].Content.Parts {
304			switch p := part.(type) {
305			case genai.Text:
306				content = string(p)
307			case genai.FunctionCall:
308				id := "call_" + uuid.New().String()
309				args, _ := json.Marshal(p.Args)
310				toolCalls = append(toolCalls, message.ToolCall{
311					ID:    id,
312					Name:  p.Name,
313					Input: string(args),
314					Type:  "function",
315				})
316			}
317		}
318	}
319
320	// Extract token usage
321	tokenUsage := p.extractTokenUsage(resp)
322
323	return &ProviderResponse{
324		Content:   content,
325		ToolCalls: toolCalls,
326		Usage:     tokenUsage,
327	}, nil
328}
329
330// StreamResponse streams the response from Gemini
331func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
332	// Create a generative model
333	model := p.client.GenerativeModel(p.model.APIModel)
334	model.SetMaxOutputTokens(p.maxTokens)
335
336	// Set system instruction
337	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
338
339	// Set up tools if provided
340	if len(tools) > 0 {
341		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
342		for _, declaration := range declarations {
343			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
344		}
345	}
346
347	// Create chat session and set history
348	chat := model.StartChat()
349	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
350
351	lastUserMsg := messages[len(messages)-1]
352
353	// Start streaming
354	iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))
355
356	eventChan := make(chan ProviderEvent)
357
358	go func() {
359		defer close(eventChan)
360
361		var finalResp *genai.GenerateContentResponse
362		currentContent := ""
363		toolCalls := []message.ToolCall{}
364
365		for {
366			resp, err := iter.Next()
367			if err == iterator.Done {
368				break
369			}
370			if err != nil {
371				var apiErr *googleapi.Error
372				if errors.As(err, &apiErr) {
373					log.Printf("%s", apiErr.Body)
374				}
375				eventChan <- ProviderEvent{
376					Type:  EventError,
377					Error: err,
378				}
379				return
380			}
381
382			finalResp = resp
383
384			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
385				for _, part := range resp.Candidates[0].Content.Parts {
386					switch p := part.(type) {
387					case genai.Text:
388						newText := string(p)
389						eventChan <- ProviderEvent{
390							Type:    EventContentDelta,
391							Content: newText,
392						}
393						currentContent += newText
394					case genai.FunctionCall:
395						// For function calls, we assume they come complete, not streamed in parts
396						id := "call_" + uuid.New().String()
397						args, _ := json.Marshal(p.Args)
398						newCall := message.ToolCall{
399							ID:    id,
400							Name:  p.Name,
401							Input: string(args),
402							Type:  "function",
403						}
404
405						// Check if this is a new tool call
406						isNew := true
407						for _, existing := range toolCalls {
408							if existing.Name == newCall.Name && existing.Input == newCall.Input {
409								isNew = false
410								break
411							}
412						}
413
414						if isNew {
415							toolCalls = append(toolCalls, newCall)
416						}
417					}
418				}
419			}
420		}
421
422		// Extract token usage from the final response
423		tokenUsage := p.extractTokenUsage(finalResp)
424
425		eventChan <- ProviderEvent{
426			Type: EventComplete,
427			Response: &ProviderResponse{
428				Content:   currentContent,
429				ToolCalls: toolCalls,
430				Usage:     tokenUsage,
431			},
432		}
433	}()
434
435	return eventChan, nil
436}
437
// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
// It returns the decoding error when jsonStr is not valid JSON or does not
// hold an object; the JSON literal null decodes to a nil map without error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}