gemini.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7
  8	"github.com/google/generative-ai-go/genai"
  9	"github.com/google/uuid"
 10	"github.com/kujtimiihoxha/termai/internal/llm/models"
 11	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 12	"github.com/kujtimiihoxha/termai/internal/message"
 13	"google.golang.org/api/iterator"
 14	"google.golang.org/api/option"
 15)
 16
 17type geminiProvider struct {
 18	client        *genai.Client
 19	model         models.Model
 20	maxTokens     int32
 21	apiKey        string
 22	systemMessage string
 23}
 24
 25type GeminiOption func(*geminiProvider)
 26
 27func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
 28	provider := &geminiProvider{
 29		maxTokens: 5000,
 30	}
 31
 32	for _, opt := range opts {
 33		opt(provider)
 34	}
 35
 36	if provider.systemMessage == "" {
 37		return nil, errors.New("system message is required")
 38	}
 39
 40	client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
 41	if err != nil {
 42		return nil, err
 43	}
 44	provider.client = client
 45
 46	return provider, nil
 47}
 48
 49func WithGeminiSystemMessage(message string) GeminiOption {
 50	return func(p *geminiProvider) {
 51		p.systemMessage = message
 52	}
 53}
 54
 55func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
 56	return func(p *geminiProvider) {
 57		p.maxTokens = maxTokens
 58	}
 59}
 60
 61func WithGeminiModel(model models.Model) GeminiOption {
 62	return func(p *geminiProvider) {
 63		p.model = model
 64	}
 65}
 66
 67func WithGeminiKey(apiKey string) GeminiOption {
 68	return func(p *geminiProvider) {
 69		p.apiKey = apiKey
 70	}
 71}
 72
 73func (p *geminiProvider) Close() {
 74	if p.client != nil {
 75		p.client.Close()
 76	}
 77}
 78
 79func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
 80	var history []*genai.Content
 81
 82	for _, msg := range messages {
 83		switch msg.Role {
 84		case message.User:
 85			history = append(history, &genai.Content{
 86				Parts: []genai.Part{genai.Text(msg.Content().String())},
 87				Role:  "user",
 88			})
 89		case message.Assistant:
 90			content := &genai.Content{
 91				Role:  "model",
 92				Parts: []genai.Part{},
 93			}
 94
 95			if msg.Content().String() != "" {
 96				content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
 97			}
 98
 99			if len(msg.ToolCalls()) > 0 {
100				for _, call := range msg.ToolCalls() {
101					args, _ := parseJsonToMap(call.Input)
102					content.Parts = append(content.Parts, genai.FunctionCall{
103						Name: call.Name,
104						Args: args,
105					})
106				}
107			}
108
109			history = append(history, content)
110		case message.Tool:
111			for _, result := range msg.ToolResults() {
112				response := map[string]interface{}{"result": result.Content}
113				parsed, err := parseJsonToMap(result.Content)
114				if err == nil {
115					response = parsed
116				}
117				var toolCall message.ToolCall
118				for _, msg := range messages {
119					if msg.Role == message.Assistant {
120						for _, call := range msg.ToolCalls() {
121							if call.ID == result.ToolCallID {
122								toolCall = call
123								break
124							}
125						}
126					}
127				}
128
129				history = append(history, &genai.Content{
130					Parts: []genai.Part{genai.FunctionResponse{
131						Name:     toolCall.Name,
132						Response: response,
133					}},
134					Role: "function",
135				})
136			}
137		}
138	}
139
140	return history
141}
142
143func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
144	if resp == nil || resp.UsageMetadata == nil {
145		return TokenUsage{}
146	}
147
148	return TokenUsage{
149		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
150		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
151		CacheCreationTokens: 0, // Not directly provided by Gemini
152		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
153	}
154}
155
156func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
157	model := p.client.GenerativeModel(p.model.APIModel)
158	model.SetMaxOutputTokens(p.maxTokens)
159
160	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
161
162	if len(tools) > 0 {
163		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
164		for _, declaration := range declarations {
165			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
166		}
167	}
168
169	chat := model.StartChat()
170	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
171
172	lastUserMsg := messages[len(messages)-1]
173	resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
174	if err != nil {
175		return nil, err
176	}
177
178	var content string
179	var toolCalls []message.ToolCall
180
181	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
182		for _, part := range resp.Candidates[0].Content.Parts {
183			switch p := part.(type) {
184			case genai.Text:
185				content = string(p)
186			case genai.FunctionCall:
187				id := "call_" + uuid.New().String()
188				args, _ := json.Marshal(p.Args)
189				toolCalls = append(toolCalls, message.ToolCall{
190					ID:    id,
191					Name:  p.Name,
192					Input: string(args),
193					Type:  "function",
194				})
195			}
196		}
197	}
198
199	tokenUsage := p.extractTokenUsage(resp)
200
201	return &ProviderResponse{
202		Content:   content,
203		ToolCalls: toolCalls,
204		Usage:     tokenUsage,
205	}, nil
206}
207
208func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
209	model := p.client.GenerativeModel(p.model.APIModel)
210	model.SetMaxOutputTokens(p.maxTokens)
211
212	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
213
214	if len(tools) > 0 {
215		declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
216		for _, declaration := range declarations {
217			model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
218		}
219	}
220
221	chat := model.StartChat()
222	chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
223
224	lastUserMsg := messages[len(messages)-1]
225
226	iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
227
228	eventChan := make(chan ProviderEvent)
229
230	go func() {
231		defer close(eventChan)
232
233		var finalResp *genai.GenerateContentResponse
234		currentContent := ""
235		toolCalls := []message.ToolCall{}
236
237		for {
238			resp, err := iter.Next()
239			if err == iterator.Done {
240				break
241			}
242			if err != nil {
243				eventChan <- ProviderEvent{
244					Type:  EventError,
245					Error: err,
246				}
247				return
248			}
249
250			finalResp = resp
251
252			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
253				for _, part := range resp.Candidates[0].Content.Parts {
254					switch p := part.(type) {
255					case genai.Text:
256						newText := string(p)
257						eventChan <- ProviderEvent{
258							Type:    EventContentDelta,
259							Content: newText,
260						}
261						currentContent += newText
262					case genai.FunctionCall:
263						id := "call_" + uuid.New().String()
264						args, _ := json.Marshal(p.Args)
265						newCall := message.ToolCall{
266							ID:    id,
267							Name:  p.Name,
268							Input: string(args),
269							Type:  "function",
270						}
271
272						isNew := true
273						for _, existing := range toolCalls {
274							if existing.Name == newCall.Name && existing.Input == newCall.Input {
275								isNew = false
276								break
277							}
278						}
279
280						if isNew {
281							toolCalls = append(toolCalls, newCall)
282						}
283					}
284				}
285			}
286		}
287
288		tokenUsage := p.extractTokenUsage(finalResp)
289
290		eventChan <- ProviderEvent{
291			Type: EventComplete,
292			Response: &ProviderResponse{
293				Content:      currentContent,
294				ToolCalls:    toolCalls,
295				Usage:        tokenUsage,
296				FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
297			},
298		}
299	}()
300
301	return eventChan, nil
302}
303
304func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
305	declarations := make([]*genai.FunctionDeclaration, len(tools))
306
307	for i, tool := range tools {
308		info := tool.Info()
309		declarations[i] = &genai.FunctionDeclaration{
310			Name:        info.Name,
311			Description: info.Description,
312			Parameters: &genai.Schema{
313				Type:       genai.TypeObject,
314				Properties: convertSchemaProperties(info.Parameters),
315				Required:   info.Required,
316			},
317		}
318	}
319
320	return declarations
321}
322
323func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
324	properties := make(map[string]*genai.Schema)
325
326	for name, param := range parameters {
327		properties[name] = convertToSchema(param)
328	}
329
330	return properties
331}
332
333func convertToSchema(param interface{}) *genai.Schema {
334	schema := &genai.Schema{Type: genai.TypeString}
335
336	paramMap, ok := param.(map[string]interface{})
337	if !ok {
338		return schema
339	}
340
341	if desc, ok := paramMap["description"].(string); ok {
342		schema.Description = desc
343	}
344
345	typeVal, hasType := paramMap["type"]
346	if !hasType {
347		return schema
348	}
349
350	typeStr, ok := typeVal.(string)
351	if !ok {
352		return schema
353	}
354
355	schema.Type = mapJSONTypeToGenAI(typeStr)
356
357	switch typeStr {
358	case "array":
359		schema.Items = processArrayItems(paramMap)
360	case "object":
361		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
362			schema.Properties = convertSchemaProperties(props)
363		}
364	}
365
366	return schema
367}
368
369func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
370	items, ok := paramMap["items"].(map[string]interface{})
371	if !ok {
372		return nil
373	}
374
375	return convertToSchema(items)
376}
377
378func mapJSONTypeToGenAI(jsonType string) genai.Type {
379	switch jsonType {
380	case "string":
381		return genai.TypeString
382	case "number":
383		return genai.TypeNumber
384	case "integer":
385		return genai.TypeInteger
386	case "boolean":
387		return genai.TypeBoolean
388	case "array":
389		return genai.TypeArray
390	case "object":
391		return genai.TypeObject
392	default:
393		return genai.TypeString // Default to string for unknown types
394	}
395}
396
// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
// It returns the json.Unmarshal error when the input is not valid JSON or is
// not an object; the JSON literal `null` decodes to a nil map with no error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}