1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/kujtimiihoxha/termai/internal/app"
11 "github.com/kujtimiihoxha/termai/internal/config"
12 "github.com/kujtimiihoxha/termai/internal/llm/models"
13 "github.com/kujtimiihoxha/termai/internal/llm/prompt"
14 "github.com/kujtimiihoxha/termai/internal/llm/provider"
15 "github.com/kujtimiihoxha/termai/internal/llm/tools"
16 "github.com/kujtimiihoxha/termai/internal/message"
17)
18
// Agent is the public entry point for producing a model reply to user input
// within a session.
type Agent interface {
	// Generate submits content as user input for the session identified by
	// sessionID and returns an error if producing the reply fails.
	Generate(sessionID, content string) error
}
22
// agent bundles the shared application services with the providers and tools
// used to run a conversation. The embedded *app.App supplies the Sessions,
// Messages, Logger and Context values that the methods in this file use.
type agent struct {
	*app.App
	// model is the model passed to TrackUsage when a response completes.
	model models.Model
	// tools are offered to the provider when streaming and are looked up by
	// name when executing tool calls.
	tools []tools.BaseTool
	// agent is the provider that streams the main assistant responses.
	agent provider.Provider
	// titleGenerator is a separate provider used only to produce session titles.
	titleGenerator provider.Provider
}
30
31func (c *agent) handleTitleGeneration(sessionID, content string) {
32 response, err := c.titleGenerator.SendMessages(
33 c.Context,
34 []message.Message{
35 {
36 Role: message.User,
37 Parts: []message.ContentPart{
38 message.TextContent{
39 Text: content,
40 },
41 },
42 },
43 },
44 nil,
45 )
46 if err != nil {
47 return
48 }
49
50 session, err := c.Sessions.Get(sessionID)
51 if err != nil {
52 return
53 }
54 if response.Content != "" {
55 session.Title = response.Content
56 session.Title = strings.TrimSpace(session.Title)
57 session.Title = strings.ReplaceAll(session.Title, "\n", " ")
58 c.Sessions.Save(session)
59 }
60}
61
62func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
63 session, err := c.Sessions.Get(sessionID)
64 if err != nil {
65 return err
66 }
67
68 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
69 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
70 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
71 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
72
73 session.Cost += cost
74 session.CompletionTokens += usage.OutputTokens
75 session.PromptTokens += usage.InputTokens
76
77 _, err = c.Sessions.Save(session)
78 return err
79}
80
81func (c *agent) processEvent(
82 sessionID string,
83 assistantMsg *message.Message,
84 event provider.ProviderEvent,
85) error {
86 switch event.Type {
87 case provider.EventThinkingDelta:
88 assistantMsg.AppendReasoningContent(event.Content)
89 return c.Messages.Update(*assistantMsg)
90 case provider.EventContentDelta:
91 assistantMsg.AppendContent(event.Content)
92 return c.Messages.Update(*assistantMsg)
93 case provider.EventError:
94 c.App.Logger.PersistError(event.Error.Error())
95 return event.Error
96 case provider.EventWarning:
97 c.App.Logger.PersistWarn(event.Info)
98 return nil
99 case provider.EventInfo:
100 c.App.Logger.PersistInfo(event.Info)
101 case provider.EventComplete:
102 assistantMsg.SetToolCalls(event.Response.ToolCalls)
103 assistantMsg.AddFinish(event.Response.FinishReason)
104 err := c.Messages.Update(*assistantMsg)
105 if err != nil {
106 return err
107 }
108 return c.TrackUsage(sessionID, c.model, event.Response.Usage)
109 }
110
111 return nil
112}
113
114func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
115 var wg sync.WaitGroup
116 toolResults := make([]message.ToolResult, len(toolCalls))
117 mutex := &sync.Mutex{}
118
119 for i, tc := range toolCalls {
120 wg.Add(1)
121 go func(index int, toolCall message.ToolCall) {
122 defer wg.Done()
123
124 response := ""
125 isError := false
126 found := false
127
128 for _, tool := range tls {
129 if tool.Info().Name == toolCall.Name {
130 found = true
131 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
132 ID: toolCall.ID,
133 Name: toolCall.Name,
134 Input: toolCall.Input,
135 })
136 if toolErr != nil {
137 response = fmt.Sprintf("error running tool: %s", toolErr)
138 isError = true
139 } else {
140 response = toolResult.Content
141 isError = toolResult.IsError
142 }
143 break
144 }
145 }
146
147 if !found {
148 response = fmt.Sprintf("tool not found: %s", toolCall.Name)
149 isError = true
150 }
151
152 mutex.Lock()
153 defer mutex.Unlock()
154
155 toolResults[index] = message.ToolResult{
156 ToolCallID: toolCall.ID,
157 Content: response,
158 IsError: isError,
159 }
160 }(i, tc)
161 }
162
163 wg.Wait()
164 return toolResults, nil
165}
166
167func (c *agent) handleToolExecution(
168 ctx context.Context,
169 assistantMsg message.Message,
170) (*message.Message, error) {
171 if len(assistantMsg.ToolCalls()) == 0 {
172 return nil, nil
173 }
174
175 toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
176 if err != nil {
177 return nil, err
178 }
179 parts := make([]message.ContentPart, 0)
180 for _, toolResult := range toolResults {
181 parts = append(parts, toolResult)
182 }
183 msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
184 Role: message.Tool,
185 Parts: parts,
186 })
187
188 return &msg, err
189}
190
191func (c *agent) generate(sessionID string, content string) error {
192 messages, err := c.Messages.List(sessionID)
193 if err != nil {
194 return err
195 }
196
197 if len(messages) == 0 {
198 go c.handleTitleGeneration(sessionID, content)
199 }
200
201 userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
202 Role: message.User,
203 Parts: []message.ContentPart{
204 message.TextContent{
205 Text: content,
206 },
207 },
208 })
209 if err != nil {
210 return err
211 }
212
213 messages = append(messages, userMsg)
214 for {
215
216 eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
217 if err != nil {
218 return err
219 }
220
221 assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
222 Role: message.Assistant,
223 Parts: []message.ContentPart{},
224 })
225 if err != nil {
226 return err
227 }
228 for event := range eventChan {
229 err = c.processEvent(sessionID, &assistantMsg, event)
230 if err != nil {
231 assistantMsg.AddFinish("error:" + err.Error())
232 c.Messages.Update(assistantMsg)
233 return err
234 }
235 }
236
237 msg, err := c.handleToolExecution(c.Context, assistantMsg)
238
239 c.Messages.Update(assistantMsg)
240 if err != nil {
241 return err
242 }
243
244 if len(assistantMsg.ToolCalls()) == 0 {
245 break
246 }
247
248 messages = append(messages, assistantMsg)
249 if msg != nil {
250 messages = append(messages, *msg)
251 }
252 }
253 return nil
254}
255
256func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
257 maxTokens := config.Get().Model.CoderMaxTokens
258
259 providerConfig, ok := config.Get().Providers[model.Provider]
260 if !ok || !providerConfig.Enabled {
261 return nil, nil, errors.New("provider is not enabled")
262 }
263 var agentProvider provider.Provider
264 var titleGenerator provider.Provider
265
266 switch model.Provider {
267 case models.ProviderOpenAI:
268 var err error
269 agentProvider, err = provider.NewOpenAIProvider(
270 provider.WithOpenAISystemMessage(
271 prompt.CoderOpenAISystemPrompt(),
272 ),
273 provider.WithOpenAIMaxTokens(maxTokens),
274 provider.WithOpenAIModel(model),
275 provider.WithOpenAIKey(providerConfig.APIKey),
276 )
277 if err != nil {
278 return nil, nil, err
279 }
280 titleGenerator, err = provider.NewOpenAIProvider(
281 provider.WithOpenAISystemMessage(
282 prompt.TitlePrompt(),
283 ),
284 provider.WithOpenAIMaxTokens(80),
285 provider.WithOpenAIModel(model),
286 provider.WithOpenAIKey(providerConfig.APIKey),
287 )
288 if err != nil {
289 return nil, nil, err
290 }
291 case models.ProviderAnthropic:
292 var err error
293 agentProvider, err = provider.NewAnthropicProvider(
294 provider.WithAnthropicSystemMessage(
295 prompt.CoderAnthropicSystemPrompt(),
296 ),
297 provider.WithAnthropicMaxTokens(maxTokens),
298 provider.WithAnthropicKey(providerConfig.APIKey),
299 provider.WithAnthropicModel(model),
300 )
301 if err != nil {
302 return nil, nil, err
303 }
304 titleGenerator, err = provider.NewAnthropicProvider(
305 provider.WithAnthropicSystemMessage(
306 prompt.TitlePrompt(),
307 ),
308 provider.WithAnthropicMaxTokens(80),
309 provider.WithAnthropicKey(providerConfig.APIKey),
310 provider.WithAnthropicModel(model),
311 )
312 if err != nil {
313 return nil, nil, err
314 }
315
316 case models.ProviderGemini:
317 var err error
318 agentProvider, err = provider.NewGeminiProvider(
319 ctx,
320 provider.WithGeminiSystemMessage(
321 prompt.CoderOpenAISystemPrompt(),
322 ),
323 provider.WithGeminiMaxTokens(int32(maxTokens)),
324 provider.WithGeminiKey(providerConfig.APIKey),
325 provider.WithGeminiModel(model),
326 )
327 if err != nil {
328 return nil, nil, err
329 }
330 titleGenerator, err = provider.NewGeminiProvider(
331 ctx,
332 provider.WithGeminiSystemMessage(
333 prompt.TitlePrompt(),
334 ),
335 provider.WithGeminiMaxTokens(80),
336 provider.WithGeminiKey(providerConfig.APIKey),
337 provider.WithGeminiModel(model),
338 )
339 if err != nil {
340 return nil, nil, err
341 }
342 case models.ProviderGROQ:
343 var err error
344 agentProvider, err = provider.NewOpenAIProvider(
345 provider.WithOpenAISystemMessage(
346 prompt.CoderAnthropicSystemPrompt(),
347 ),
348 provider.WithOpenAIMaxTokens(maxTokens),
349 provider.WithOpenAIModel(model),
350 provider.WithOpenAIKey(providerConfig.APIKey),
351 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
352 )
353 if err != nil {
354 return nil, nil, err
355 }
356 titleGenerator, err = provider.NewOpenAIProvider(
357 provider.WithOpenAISystemMessage(
358 prompt.TitlePrompt(),
359 ),
360 provider.WithOpenAIMaxTokens(80),
361 provider.WithOpenAIModel(model),
362 provider.WithOpenAIKey(providerConfig.APIKey),
363 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
364 )
365 if err != nil {
366 return nil, nil, err
367 }
368
369 case models.ProviderBedrock:
370 var err error
371 agentProvider, err = provider.NewBedrockProvider(
372 provider.WithBedrockSystemMessage(
373 prompt.CoderAnthropicSystemPrompt(),
374 ),
375 provider.WithBedrockMaxTokens(maxTokens),
376 provider.WithBedrockModel(model),
377 )
378 if err != nil {
379 return nil, nil, err
380 }
381 titleGenerator, err = provider.NewBedrockProvider(
382 provider.WithBedrockSystemMessage(
383 prompt.TitlePrompt(),
384 ),
385 provider.WithBedrockMaxTokens(maxTokens),
386 provider.WithBedrockModel(model),
387 )
388 if err != nil {
389 return nil, nil, err
390 }
391
392 }
393
394 return agentProvider, titleGenerator, nil
395}