agent.go

package agent

import (
	"context"
	"errors"
	"fmt"
	"log"
	"strings"
	"sync"

	"github.com/kujtimiihoxha/termai/internal/app"
	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
	"github.com/kujtimiihoxha/termai/internal/llm/provider"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

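// Agent drives a model-backed conversation for a session, streaming the
// response and executing any tool calls the model requests.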
type Agent interface {
	Generate(sessionID string, content string) error
}

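// agent is the default Agent implementation. It holds the active model, the
// tools exposed to it, and two providers: one for the main coding loop and a
// smaller one used only to generate session titles.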
type agent struct {
	*app.App
	model          models.Model
	tools          []tools.BaseTool
	agent          provider.Provider
	titleGenerator provider.Provider
}

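// handleTitleGeneration asks the title provider for a short session title
// based on the first user message and stores it on the session. Errors are
// ignored because a missing title is not fatal.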
func (c *agent) handleTitleGeneration(sessionID, content string) {
	response, err := c.titleGenerator.SendMessages(
		c.Context,
		[]message.Message{
			{
				Role: message.User,
				Parts: []message.ContentPart{
					message.TextContent{
						Text: content,
					},
				},
			},
		},
		nil,
	)
	if err != nil {
		return
	}

	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		return
	}
	if response.Content != "" {
		session.Title = response.Content
		session.Title = strings.TrimSpace(session.Title)
		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
		c.Sessions.Save(session)
	}
}

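// TrackUsage adds the cost and token counts of a single provider response to
// the session totals. Model prices are expressed per one million tokens.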
func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		return err
	}

	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
		model.CostPer1MOut/1e6*float64(usage.OutputTokens)

	session.Cost += cost
	session.CompletionTokens += usage.OutputTokens
	session.PromptTokens += usage.InputTokens

	_, err = c.Sessions.Save(session)
	return err
}

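// processEvent applies a single streaming event from the provider to the
// in-progress assistant message and persists the update.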
func (c *agent) processEvent(
	sessionID string,
	assistantMsg *message.Message,
	event provider.ProviderEvent,
) error {
	switch event.Type {
	case provider.EventThinkingDelta:
		assistantMsg.AppendReasoningContent(event.Content)
		return c.Messages.Update(*assistantMsg)
	case provider.EventContentDelta:
		assistantMsg.AppendContent(event.Content)
		return c.Messages.Update(*assistantMsg)
	case provider.EventError:
		log.Println("error", event.Error)
		return event.Error
	case provider.EventComplete:
		assistantMsg.SetToolCalls(event.Response.ToolCalls)
		assistantMsg.AddFinish(event.Response.FinishReason)
		err := c.Messages.Update(*assistantMsg)
		if err != nil {
			return err
		}
		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
	}

	return nil
}

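// ExecuteTools runs every requested tool call concurrently and collects the
// results in the original call order. Unknown tools and tool failures are
// reported as error results rather than aborting the batch.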
func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
	var wg sync.WaitGroup
	toolResults := make([]message.ToolResult, len(toolCalls))
	mutex := &sync.Mutex{}

	for i, tc := range toolCalls {
		wg.Add(1)
		go func(index int, toolCall message.ToolCall) {
			defer wg.Done()

			response := ""
			isError := false
			found := false

			for _, tool := range tls {
				if tool.Info().Name == toolCall.Name {
					found = true
					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
						ID:    toolCall.ID,
						Name:  toolCall.Name,
						Input: toolCall.Input,
					})
					if toolErr != nil {
						response = fmt.Sprintf("error running tool: %s", toolErr)
						isError = true
					} else {
						response = toolResult.Content
						isError = toolResult.IsError
					}
					break
				}
			}

			if !found {
				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
				isError = true
			}

			mutex.Lock()
			defer mutex.Unlock()

			toolResults[index] = message.ToolResult{
				ToolCallID: toolCall.ID,
				Content:    response,
				IsError:    isError,
			}
		}(i, tc)
	}

	wg.Wait()
	return toolResults, nil
}

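// handleToolExecution executes the tool calls attached to an assistant message
// and records their results as a single tool message in the session. It
// returns nil when the assistant requested no tools.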
func (c *agent) handleToolExecution(
	ctx context.Context,
	assistantMsg message.Message,
) (*message.Message, error) {
	if len(assistantMsg.ToolCalls()) == 0 {
		return nil, nil
	}

	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
	if err != nil {
		return nil, err
	}
	parts := make([]message.ContentPart, 0)
	for _, toolResult := range toolResults {
		parts = append(parts, toolResult)
	}
	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
		Role:  message.Tool,
		Parts: parts,
	})

	return &msg, err
}

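// generate runs the main agent loop for one user request: it stores the user
// message, streams an assistant response, executes any tool calls, and repeats
// until the model finishes without requesting tools. The first message in a
// session also kicks off title generation in the background.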
func (c *agent) generate(sessionID string, content string) error {
	messages, err := c.Messages.List(sessionID)
	if err != nil {
		return err
	}

	if len(messages) == 0 {
		go c.handleTitleGeneration(sessionID, content)
	}

	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
		Role: message.User,
		Parts: []message.ContentPart{
			message.TextContent{
				Text: content,
			},
		},
	})
	if err != nil {
		return err
	}

	messages = append(messages, userMsg)
	for {
		eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
		if err != nil {
			return err
		}

		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
			Role:  message.Assistant,
			Parts: []message.ContentPart{},
		})
		if err != nil {
			return err
		}
		for event := range eventChan {
			err = c.processEvent(sessionID, &assistantMsg, event)
			if err != nil {
				assistantMsg.AddFinish("error:" + err.Error())
				c.Messages.Update(assistantMsg)
				return err
			}
		}

		msg, err := c.handleToolExecution(c.Context, assistantMsg)

		c.Messages.Update(assistantMsg)
		if err != nil {
			return err
		}

		if len(assistantMsg.ToolCalls()) == 0 {
			break
		}

		messages = append(messages, assistantMsg)
		if msg != nil {
			messages = append(messages, *msg)
		}
	}
	return nil
}

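// getAgentProviders builds the coder provider and the title provider for the
// configured model, returning an error when the model's provider is not
// enabled in the configuration.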
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
	maxTokens := config.Get().Model.CoderMaxTokens

	providerConfig, ok := config.Get().Providers[model.Provider]
	if !ok || !providerConfig.Enabled {
		return nil, nil, errors.New("provider is not enabled")
	}
	var agentProvider provider.Provider
	var titleGenerator provider.Provider

	switch model.Provider {
	case models.ProviderOpenAI:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderAnthropic:
		var err error
		agentProvider, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithAnthropicMaxTokens(maxTokens),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithAnthropicMaxTokens(80),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderGemini:
		var err error
		agentProvider, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithGeminiMaxTokens(int32(maxTokens)),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithGeminiMaxTokens(80),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderGROQ:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}
	}

	return agentProvider, titleGenerator, nil
}