package agent

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"

	"github.com/kujtimiihoxha/termai/internal/app"
	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
	"github.com/kujtimiihoxha/termai/internal/llm/provider"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/logging"
	"github.com/kujtimiihoxha/termai/internal/message"
)

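// Agent generates assistant responses for a session: it streams model output,
// executes requested tool calls, and persists the resulting messages.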
type Agent interface {
	Generate(ctx context.Context, sessionID string, content string) error
}

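// agent is the default Agent implementation. It combines the application
// services with the selected chat model, the tools exposed to that model, and
// two provider clients: one for the conversation itself and one dedicated to
// generating short session titles.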
type agent struct {
	*app.App
	model          models.Model
	tools          []tools.BaseTool
	agent          provider.Provider
	titleGenerator provider.Provider
}

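// handleTitleGeneration asks the title provider for a short session title
// based on the user's first message and saves it on the session. It is
// best-effort: any error simply leaves the existing title in place.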
func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
	response, err := c.titleGenerator.SendMessages(
		ctx,
		[]message.Message{
			{
				Role: message.User,
				Parts: []message.ContentPart{
					message.TextContent{
						Text: content,
					},
				},
			},
		},
		nil,
	)
	if err != nil {
		return
	}

	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		return
	}
	if response.Content != "" {
		session.Title = strings.ReplaceAll(strings.TrimSpace(response.Content), "\n", " ")
		c.Sessions.Save(session)
	}
}

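// TrackUsage adds the cost and token counts of a completed request to the
// session totals. Model rates are priced per one million tokens, so each rate
// is divided by 1e6 before being multiplied by the corresponding token count.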
func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		return err
	}

	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
		model.CostPer1MOut/1e6*float64(usage.OutputTokens)

	session.Cost += cost
	session.CompletionTokens += usage.OutputTokens
	session.PromptTokens += usage.InputTokens

	_, err = c.Sessions.Save(session)
	return err
}

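// processEvent folds a single streaming event from the provider into the
// assistant message: reasoning and content deltas are appended and persisted,
// warnings and infos are logged, errors are surfaced (except cancellation),
// and a completion records the tool calls, finish reason, and token usage.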
func (c *agent) processEvent(
	sessionID string,
	assistantMsg *message.Message,
	event provider.ProviderEvent,
) error {
	switch event.Type {
	case provider.EventThinkingDelta:
		assistantMsg.AppendReasoningContent(event.Content)
		return c.Messages.Update(*assistantMsg)
	case provider.EventContentDelta:
		assistantMsg.AppendContent(event.Content)
		return c.Messages.Update(*assistantMsg)
	case provider.EventError:
		if errors.Is(event.Error, context.Canceled) {
			return nil
		}
		logging.ErrorPersist(event.Error.Error())
		return event.Error
	case provider.EventWarning:
		logging.WarnPersist(event.Info)
		return nil
	case provider.EventInfo:
		logging.InfoPersist(event.Info)
	case provider.EventComplete:
		assistantMsg.SetToolCalls(event.Response.ToolCalls)
		assistantMsg.AddFinish(event.Response.FinishReason)
		err := c.Messages.Update(*assistantMsg)
		if err != nil {
			return err
		}
		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
	}

	return nil
}

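// ExecuteTools runs the requested tool calls concurrently and returns one
// result per call, in the same order as the input. An unknown tool or a tool
// error is reported as an error result instead of failing the whole batch;
// only context cancellation aborts early, in which case the returned slice
// may be only partially filled.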
func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
	var wg sync.WaitGroup
	toolResults := make([]message.ToolResult, len(toolCalls))
	mutex := &sync.Mutex{}
	errChan := make(chan error, 1)

	// Create a child context that can be canceled
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	for i, tc := range toolCalls {
		wg.Add(1)
		go func(index int, toolCall message.ToolCall) {
			defer wg.Done()

			// Check if context is already canceled
			select {
			case <-ctx.Done():
				mutex.Lock()
				toolResults[index] = message.ToolResult{
					ToolCallID: toolCall.ID,
					Content:    "Tool execution canceled",
					IsError:    true,
				}
				mutex.Unlock()

				// Send cancellation error to error channel if it's empty
				select {
				case errChan <- ctx.Err():
				default:
				}
				return
			default:
			}

			response := ""
			isError := false
			found := false

			for _, tool := range tls {
				if tool.Info().Name == toolCall.Name {
					found = true
					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
						ID:    toolCall.ID,
						Name:  toolCall.Name,
						Input: toolCall.Input,
					})

					if toolErr != nil {
						if errors.Is(toolErr, context.Canceled) {
							response = "Tool execution canceled"

							// Send cancellation error to error channel if it's empty
							select {
							case errChan <- ctx.Err():
							default:
							}
						} else {
							response = fmt.Sprintf("error running tool: %s", toolErr)
						}
						isError = true
					} else {
						response = toolResult.Content
						isError = toolResult.IsError
					}
					break
				}
			}

			if !found {
				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
				isError = true
			}

			mutex.Lock()
			defer mutex.Unlock()

			toolResults[index] = message.ToolResult{
				ToolCallID: toolCall.ID,
				Content:    response,
				IsError:    isError,
			}
		}(i, tc)
	}

	// Wait for all goroutines to finish or context to be canceled
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		// All tools completed successfully
	case err := <-errChan:
		// One of the tools encountered a cancellation
		return toolResults, err
	case <-ctx.Done():
		// Context was canceled externally
		return toolResults, ctx.Err()
	}

	return toolResults, nil
}

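// handleToolExecution runs the tool calls attached to an assistant message and
// stores their results as a single tool message in the session. It returns
// nil when the assistant requested no tools.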
func (c *agent) handleToolExecution(
	ctx context.Context,
	assistantMsg message.Message,
) (*message.Message, error) {
	if len(assistantMsg.ToolCalls()) == 0 {
		return nil, nil
	}

	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
	if err != nil {
		return nil, err
	}
	parts := make([]message.ContentPart, 0, len(toolResults))
	for _, toolResult := range toolResults {
		parts = append(parts, toolResult)
	}
	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
		Role:  message.Tool,
		Parts: parts,
	})

	return &msg, err
}

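// generate drives one conversation turn: it stores the user message (kicking
// off title generation when the session is new), then repeatedly streams an
// assistant response and executes any requested tools until the model stops
// asking for them. Cancellation is checked between steps so a canceled context
// always leaves the assistant message finished with "canceled".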
func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
	messages, err := c.Messages.List(sessionID)
	if err != nil {
		return err
	}

	if len(messages) == 0 {
		go c.handleTitleGeneration(ctx, sessionID, content)
	}

	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
		Role: message.User,
		Parts: []message.ContentPart{
			message.TextContent{
				Text: content,
			},
		},
	})
	if err != nil {
		return err
	}

	messages = append(messages, userMsg)
	for {
		select {
		case <-ctx.Done():
			assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
				Role:  message.Assistant,
				Parts: []message.ContentPart{},
			})
			if err != nil {
				return err
			}
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}

		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
					Role:  message.Assistant,
					Parts: []message.ContentPart{},
				})
				if err != nil {
					return err
				}
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			}
			return err
		}

		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
			Role:  message.Assistant,
			Parts: []message.ContentPart{},
			Model: c.model.ID,
		})
		if err != nil {
			return err
		}
		for event := range eventChan {
			err = c.processEvent(sessionID, &assistantMsg, event)
			if err != nil {
				if errors.Is(err, context.Canceled) {
					assistantMsg.AddFinish("canceled")
					c.Messages.Update(assistantMsg)
					return context.Canceled
				}
				assistantMsg.AddFinish("error:" + err.Error())
				c.Messages.Update(assistantMsg)
				return err
			}

			select {
			case <-ctx.Done():
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			default:
			}
		}

		// Check for context cancellation before tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}

		msg, err := c.handleToolExecution(ctx, assistantMsg)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			}
			return err
		}

		c.Messages.Update(assistantMsg)

		if len(assistantMsg.ToolCalls()) == 0 {
			break
		}

		messages = append(messages, assistantMsg)
		if msg != nil {
			messages = append(messages, *msg)
		}

		// Check for context cancellation after tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}
	}
	return nil
}

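// getAgentProviders builds two provider clients for the configured model: the
// main coder agent, and a smaller title generator used only for one-line
// session titles. It fails if the model's provider is not enabled in the
// config.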
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
	maxTokens := config.Get().Model.CoderMaxTokens

	providerConfig, ok := config.Get().Providers[model.Provider]
	if !ok || !providerConfig.Enabled {
		return nil, nil, errors.New("provider is not enabled")
	}
	var agentProvider provider.Provider
	var titleGenerator provider.Provider

	switch model.Provider {
	case models.ProviderOpenAI:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderAnthropic:
		var err error
		agentProvider, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithAnthropicMaxTokens(maxTokens),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithAnthropicMaxTokens(80),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderGemini:
		var err error
		agentProvider, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithGeminiMaxTokens(int32(maxTokens)),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithGeminiMaxTokens(80),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderGROQ:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderBedrock:
		var err error
		agentProvider, err = provider.NewBedrockProvider(
			provider.WithBedrockSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithBedrockMaxTokens(maxTokens),
			provider.WithBedrockModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewBedrockProvider(
			provider.WithBedrockSystemMessage(
				prompt.TitlePrompt(),
			),
			// Cap the title generator at 80 tokens, matching the other providers.
			provider.WithBedrockMaxTokens(80),
			provider.WithBedrockModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
	default:
		return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
	}

	return agentProvider, titleGenerator, nil
}