gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"log"
  8
  9	"github.com/google/generative-ai-go/genai"
 10	"github.com/google/uuid"
 11	"github.com/kujtimiihoxha/termai/internal/llm/models"
 12	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 13	"github.com/kujtimiihoxha/termai/internal/message"
 14	"google.golang.org/api/googleapi"
 15	"google.golang.org/api/iterator"
 16	"google.golang.org/api/option"
 17)
 18
// geminiProvider bundles the genai client together with the settings that are
// applied to every request made through it.
type geminiProvider struct {
	// client is assigned by NewGeminiProvider and released by Close.
	client        *genai.Client
	// model supplies the APIModel name passed to client.GenerativeModel.
	model         models.Model
	// maxTokens is passed to SetMaxOutputTokens; NewGeminiProvider defaults it to 5000.
	maxTokens     int32
	// apiKey is handed to option.WithAPIKey when the client is created.
	apiKey        string
	// systemMessage is sent as the model's system instruction; NewGeminiProvider
	// rejects an empty value.
	systemMessage string
}
 26
// GeminiOption is a functional option that NewGeminiProvider applies to the
// geminiProvider it is building.
type GeminiOption func(*geminiProvider)
 28
 29func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
 30	provider := &geminiProvider{
 31		maxTokens: 5000,
 32	}
 33
 34	for _, opt := range opts {
 35		opt(provider)
 36	}
 37
 38	if provider.systemMessage == "" {
 39		return nil, errors.New("system message is required")
 40	}
 41
 42	client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
 43	if err != nil {
 44		return nil, err
 45	}
 46	provider.client = client
 47
 48	return provider, nil
 49}
 50
 51func WithGeminiSystemMessage(message string) GeminiOption {
 52	return func(p *geminiProvider) {
 53		p.systemMessage = message
 54	}
 55}
 56
 57func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
 58	return func(p *geminiProvider) {
 59		p.maxTokens = maxTokens
 60	}
 61}
 62
 63func WithGeminiModel(model models.Model) GeminiOption {
 64	return func(p *geminiProvider) {
 65		p.model = model
 66	}
 67}
 68
 69func WithGeminiKey(apiKey string) GeminiOption {
 70	return func(p *geminiProvider) {
 71		p.apiKey = apiKey
 72	}
 73}
 74
 75func (p *geminiProvider) Close() {
 76	if p.client != nil {
 77		p.client.Close()
 78	}
 79}
 80
 81func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
 82	var history []*genai.Content
 83
 84	for _, msg := range messages {
 85		switch msg.Role {
 86		case message.User:
 87			history = append(history, &genai.Content{
 88				Parts: []genai.Part{genai.Text(msg.Content().String())},
 89				Role:  "user",
 90			})
 91		case message.Assistant:
 92			content := &genai.Content{
 93				Role:  "model",
 94				Parts: []genai.Part{},
 95			}
 96
 97			if msg.Content().String() != "" {
 98				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
 99			}
100
101			if len(msg.ToolCalls()) > 0 {
102				for _, call := range msg.ToolCalls() {
103					args, _ := parseJsonToMap(call.Input)
104					content.Parts = append(content.Parts, genai.FunctionCall{
105						Name: call.Name,
106						Args: args,
107					})
108				}
109			}
110
111			history = append(history, content)
112		case message.Tool:
113			for _, result := range msg.ToolResults() {
114				response := map[string]interface{}{"result": result.Content}
115				parsed, err := parseJsonToMap(result.Content)
116				if err == nil {
117					response = parsed
118				}
119				var toolCall message.ToolCall
120				for _, msg := range messages {
121					if msg.Role == message.Assistant {
122						for _, call := range msg.ToolCalls() {
123							if call.ID == result.ToolCallID {
124								toolCall = call
125								break
126							}
127						}
128					}
129				}
130
131				history = append(history, &genai.Content{
132					Parts: []genai.Part{genai.FunctionResponse{
133						Name:     toolCall.Name,
134						Response: response,
135					}},
136					Role: "function",
137				})
138			}
139		}
140	}
141
142	return history
143}
144
145func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
146	if resp == nil || resp.UsageMetadata == nil {
147		return TokenUsage{}
148	}
149
150	return TokenUsage{
151		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
152		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
153		CacheCreationTokens: 0, // Not directly provided by Gemini
154		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
155	}
156}
157
158func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
159	model := p.client.GenerativeModel(p.model.APIModel)
160	model.SetMaxOutputTokens(p.maxTokens)
161
162	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
163
164	if len(tools) > 0 {
165		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
166		for _, declaration := range declarations {
167			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
168		}
169	}
170
171	chat := model.StartChat()
172	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
173
174	lastUserMsg := messages[len(messages)-1]
175	resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
176	if err != nil {
177		return nil, err
178	}
179
180	var content string
181	var toolCalls []message.ToolCall
182
183	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
184		for _, part := range resp.Candidates[0].Content.Parts {
185			switch p := part.(type) {
186			case genai.Text:
187				content = string(p)
188			case genai.FunctionCall:
189				id := "call_" + uuid.New().String()
190				args, _ := json.Marshal(p.Args)
191				toolCalls = append(toolCalls, message.ToolCall{
192					ID:    id,
193					Name:  p.Name,
194					Input: string(args),
195					Type:  "function",
196				})
197			}
198		}
199	}
200
201	tokenUsage := p.extractTokenUsage(resp)
202
203	return &ProviderResponse{
204		Content:   content,
205		ToolCalls: toolCalls,
206		Usage:     tokenUsage,
207	}, nil
208}
209
210func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
211	model := p.client.GenerativeModel(p.model.APIModel)
212	model.SetMaxOutputTokens(p.maxTokens)
213
214	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
215
216	if len(tools) > 0 {
217		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
218		for _, declaration := range declarations {
219			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
220		}
221	}
222
223	chat := model.StartChat()
224	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
225
226	lastUserMsg := messages[len(messages)-1]
227
228	iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
229
230	eventChan := make(chan ProviderEvent)
231
232	go func() {
233		defer close(eventChan)
234
235		var finalResp *genai.GenerateContentResponse
236		currentContent := ""
237		toolCalls := []message.ToolCall{}
238
239		for {
240			resp, err := iter.Next()
241			if err == iterator.Done {
242				break
243			}
244			if err != nil {
245				var apiErr *googleapi.Error
246				if errors.As(err, &apiErr) {
247					log.Printf("%s", apiErr.Body)
248				}
249				eventChan <- ProviderEvent{
250					Type:  EventError,
251					Error: err,
252				}
253				return
254			}
255
256			finalResp = resp
257
258			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
259				for _, part := range resp.Candidates[0].Content.Parts {
260					switch p := part.(type) {
261					case genai.Text:
262						newText := string(p)
263						eventChan <- ProviderEvent{
264							Type:    EventContentDelta,
265							Content: newText,
266						}
267						currentContent += newText
268					case genai.FunctionCall:
269						id := "call_" + uuid.New().String()
270						args, _ := json.Marshal(p.Args)
271						newCall := message.ToolCall{
272							ID:    id,
273							Name:  p.Name,
274							Input: string(args),
275							Type:  "function",
276						}
277
278						isNew := true
279						for _, existing := range toolCalls {
280							if existing.Name == newCall.Name && existing.Input == newCall.Input {
281								isNew = false
282								break
283							}
284						}
285
286						if isNew {
287							toolCalls = append(toolCalls, newCall)
288						}
289					}
290				}
291			}
292		}
293
294		tokenUsage := p.extractTokenUsage(finalResp)
295
296		eventChan <- ProviderEvent{
297			Type: EventComplete,
298			Response: &ProviderResponse{
299				Content:      currentContent,
300				ToolCalls:    toolCalls,
301				Usage:        tokenUsage,
302				FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
303			},
304		}
305	}()
306
307	return eventChan, nil
308}
309
310func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
311	declarations := make([]*genai.FunctionDeclaration, len(tools))
312
313	for i, tool := range tools {
314		info := tool.Info()
315		declarations[i] = &genai.FunctionDeclaration{
316			Name:        info.Name,
317			Description: info.Description,
318			Parameters: &genai.Schema{
319				Type:       genai.TypeObject,
320				Properties: convertSchemaProperties(info.Parameters),
321				Required:   info.Required,
322			},
323		}
324	}
325
326	return declarations
327}
328
329func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
330	properties := make(map[string]*genai.Schema)
331
332	for name, param := range parameters {
333		properties[name] = convertToSchema(param)
334	}
335
336	return properties
337}
338
339func convertToSchema(param interface{}) *genai.Schema {
340	schema := &genai.Schema{Type: genai.TypeString}
341
342	paramMap, ok := param.(map[string]interface{})
343	if !ok {
344		return schema
345	}
346
347	if desc, ok := paramMap["description"].(string); ok {
348		schema.Description = desc
349	}
350
351	typeVal, hasType := paramMap["type"]
352	if !hasType {
353		return schema
354	}
355
356	typeStr, ok := typeVal.(string)
357	if !ok {
358		return schema
359	}
360
361	schema.Type = mapJSONTypeToGenAI(typeStr)
362
363	switch typeStr {
364	case "array":
365		schema.Items = processArrayItems(paramMap)
366	case "object":
367		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
368			schema.Properties = convertSchemaProperties(props)
369		}
370	}
371
372	return schema
373}
374
375func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
376	items, ok := paramMap["items"].(map[string]interface{})
377	if !ok {
378		return nil
379	}
380
381	return convertToSchema(items)
382}
383
384func mapJSONTypeToGenAI(jsonType string) genai.Type {
385	switch jsonType {
386	case "string":
387		return genai.TypeString
388	case "number":
389		return genai.TypeNumber
390	case "integer":
391		return genai.TypeInteger
392	case "boolean":
393		return genai.TypeBoolean
394	case "array":
395		return genai.TypeArray
396	case "object":
397		return genai.TypeObject
398	default:
399		return genai.TypeString // Default to string for unknown types
400	}
401}
402
// parseJsonToMap decodes jsonStr as a JSON object into a generic map. The
// error from json.Unmarshal is returned unchanged, alongside whatever the
// decoder left in the map.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}