gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7
  8	"github.com/google/generative-ai-go/genai"
  9	"github.com/google/uuid"
 10	"github.com/kujtimiihoxha/termai/internal/llm/models"
 11	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 12	"github.com/kujtimiihoxha/termai/internal/message"
 13	"google.golang.org/api/iterator"
 14	"google.golang.org/api/option"
 15)
 16
 17type geminiProvider struct {
 18	client        *genai.Client
 19	model         models.Model
 20	maxTokens     int32
 21	apiKey        string
 22	systemMessage string
 23}
 24
 25type GeminiOption func(*geminiProvider)
 26
 27func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
 28	provider := &geminiProvider{
 29		maxTokens: 5000,
 30	}
 31
 32	for _, opt := range opts {
 33		opt(provider)
 34	}
 35
 36	if provider.systemMessage == "" {
 37		return nil, errors.New("system message is required")
 38	}
 39
 40	client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
 41	if err != nil {
 42		return nil, err
 43	}
 44	provider.client = client
 45
 46	return provider, nil
 47}
 48
 49func WithGeminiSystemMessage(message string) GeminiOption {
 50	return func(p *geminiProvider) {
 51		p.systemMessage = message
 52	}
 53}
 54
 55func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
 56	return func(p *geminiProvider) {
 57		p.maxTokens = maxTokens
 58	}
 59}
 60
 61func WithGeminiModel(model models.Model) GeminiOption {
 62	return func(p *geminiProvider) {
 63		p.model = model
 64	}
 65}
 66
 67func WithGeminiKey(apiKey string) GeminiOption {
 68	return func(p *geminiProvider) {
 69		p.apiKey = apiKey
 70	}
 71}
 72
 73func (p *geminiProvider) Close() {
 74	if p.client != nil {
 75		p.client.Close()
 76	}
 77}
 78
 79func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
 80	var history []*genai.Content
 81
 82	for _, msg := range messages {
 83		switch msg.Role {
 84		case message.User:
 85			history = append(history, &genai.Content{
 86				Parts: []genai.Part{genai.Text(msg.Content().String())},
 87				Role:  "user",
 88			})
 89		case message.Assistant:
 90			content := &genai.Content{
 91				Role:  "model",
 92				Parts: []genai.Part{},
 93			}
 94
 95			if msg.Content().String() != "" {
 96				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
 97			}
 98
 99			if len(msg.ToolCalls()) > 0 {
100				for _, call := range msg.ToolCalls() {
101					args, _ := parseJsonToMap(call.Input)
102					content.Parts = append(content.Parts, genai.FunctionCall{
103						Name: call.Name,
104						Args: args,
105					})
106				}
107			}
108
109			history = append(history, content)
110		case message.Tool:
111			for _, result := range msg.ToolResults() {
112				response := map[string]interface{}{"result": result.Content}
113				parsed, err := parseJsonToMap(result.Content)
114				if err == nil {
115					response = parsed
116				}
117				var toolCall message.ToolCall
118				for _, msg := range messages {
119					if msg.Role == message.Assistant {
120						for _, call := range msg.ToolCalls() {
121							if call.ID == result.ToolCallID {
122								toolCall = call
123								break
124							}
125						}
126					}
127				}
128
129				history = append(history, &genai.Content{
130					Parts: []genai.Part{genai.FunctionResponse{
131						Name:     toolCall.Name,
132						Response: response,
133					}},
134					Role: "function",
135				})
136			}
137		}
138	}
139
140	return history
141}
142
143func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
144	if resp == nil || resp.UsageMetadata == nil {
145		return TokenUsage{}
146	}
147
148	return TokenUsage{
149		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
150		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
151		CacheCreationTokens: 0, // Not directly provided by Gemini
152		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
153	}
154}
155
156func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
157	messages = cleanupMessages(messages)
158	model := p.client.GenerativeModel(p.model.APIModel)
159	model.SetMaxOutputTokens(p.maxTokens)
160
161	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
162
163	if len(tools) > 0 {
164		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
165		for _, declaration := range declarations {
166			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
167		}
168	}
169
170	chat := model.StartChat()
171	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
172
173	lastUserMsg := messages[len(messages)-1]
174	resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
175	if err != nil {
176		return nil, err
177	}
178
179	var content string
180	var toolCalls []message.ToolCall
181
182	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
183		for _, part := range resp.Candidates[0].Content.Parts {
184			switch p := part.(type) {
185			case genai.Text:
186				content = string(p)
187			case genai.FunctionCall:
188				id := "call_" + uuid.New().String()
189				args, _ := json.Marshal(p.Args)
190				toolCalls = append(toolCalls, message.ToolCall{
191					ID:    id,
192					Name:  p.Name,
193					Input: string(args),
194					Type:  "function",
195				})
196			}
197		}
198	}
199
200	tokenUsage := p.extractTokenUsage(resp)
201
202	return &ProviderResponse{
203		Content:   content,
204		ToolCalls: toolCalls,
205		Usage:     tokenUsage,
206	}, nil
207}
208
209func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
210	messages = cleanupMessages(messages)
211	model := p.client.GenerativeModel(p.model.APIModel)
212	model.SetMaxOutputTokens(p.maxTokens)
213
214	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
215
216	if len(tools) > 0 {
217		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
218		for _, declaration := range declarations {
219			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
220		}
221	}
222
223	chat := model.StartChat()
224	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
225
226	lastUserMsg := messages[len(messages)-1]
227
228	iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
229
230	eventChan := make(chan ProviderEvent)
231
232	go func() {
233		defer close(eventChan)
234
235		var finalResp *genai.GenerateContentResponse
236		currentContent := ""
237		toolCalls := []message.ToolCall{}
238
239		for {
240			resp, err := iter.Next()
241			if err == iterator.Done {
242				break
243			}
244			if err != nil {
245				eventChan <- ProviderEvent{
246					Type:  EventError,
247					Error: err,
248				}
249				return
250			}
251
252			finalResp = resp
253
254			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
255				for _, part := range resp.Candidates[0].Content.Parts {
256					switch p := part.(type) {
257					case genai.Text:
258						newText := string(p)
259						eventChan <- ProviderEvent{
260							Type:    EventContentDelta,
261							Content: newText,
262						}
263						currentContent += newText
264					case genai.FunctionCall:
265						id := "call_" + uuid.New().String()
266						args, _ := json.Marshal(p.Args)
267						newCall := message.ToolCall{
268							ID:    id,
269							Name:  p.Name,
270							Input: string(args),
271							Type:  "function",
272						}
273
274						isNew := true
275						for _, existing := range toolCalls {
276							if existing.Name == newCall.Name && existing.Input == newCall.Input {
277								isNew = false
278								break
279							}
280						}
281
282						if isNew {
283							toolCalls = append(toolCalls, newCall)
284						}
285					}
286				}
287			}
288		}
289
290		tokenUsage := p.extractTokenUsage(finalResp)
291
292		eventChan <- ProviderEvent{
293			Type: EventComplete,
294			Response: &ProviderResponse{
295				Content:      currentContent,
296				ToolCalls:    toolCalls,
297				Usage:        tokenUsage,
298				FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
299			},
300		}
301	}()
302
303	return eventChan, nil
304}
305
306func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
307	declarations := make([]*genai.FunctionDeclaration, len(tools))
308
309	for i, tool := range tools {
310		info := tool.Info()
311		declarations[i] = &genai.FunctionDeclaration{
312			Name:        info.Name,
313			Description: info.Description,
314			Parameters: &genai.Schema{
315				Type:       genai.TypeObject,
316				Properties: convertSchemaProperties(info.Parameters),
317				Required:   info.Required,
318			},
319		}
320	}
321
322	return declarations
323}
324
325func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
326	properties := make(map[string]*genai.Schema)
327
328	for name, param := range parameters {
329		properties[name] = convertToSchema(param)
330	}
331
332	return properties
333}
334
335func convertToSchema(param interface{}) *genai.Schema {
336	schema := &genai.Schema{Type: genai.TypeString}
337
338	paramMap, ok := param.(map[string]interface{})
339	if !ok {
340		return schema
341	}
342
343	if desc, ok := paramMap["description"].(string); ok {
344		schema.Description = desc
345	}
346
347	typeVal, hasType := paramMap["type"]
348	if !hasType {
349		return schema
350	}
351
352	typeStr, ok := typeVal.(string)
353	if !ok {
354		return schema
355	}
356
357	schema.Type = mapJSONTypeToGenAI(typeStr)
358
359	switch typeStr {
360	case "array":
361		schema.Items = processArrayItems(paramMap)
362	case "object":
363		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
364			schema.Properties = convertSchemaProperties(props)
365		}
366	}
367
368	return schema
369}
370
371func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
372	items, ok := paramMap["items"].(map[string]interface{})
373	if !ok {
374		return nil
375	}
376
377	return convertToSchema(items)
378}
379
380func mapJSONTypeToGenAI(jsonType string) genai.Type {
381	switch jsonType {
382	case "string":
383		return genai.TypeString
384	case "number":
385		return genai.TypeNumber
386	case "integer":
387		return genai.TypeInteger
388	case "boolean":
389		return genai.TypeBoolean
390	case "array":
391		return genai.TypeArray
392	case "object":
393		return genai.TypeObject
394	default:
395		return genai.TypeString // Default to string for unknown types
396	}
397}
398
// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
// It returns whatever json.Unmarshal produced together with its error: the map
// is nil for the JSON literal null and for input that fails to decode.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	raw := []byte(jsonStr)

	var decoded map[string]interface{}
	decodeErr := json.Unmarshal(raw, &decoded)

	return decoded, decodeErr
}