1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/kujtimiihoxha/termai/internal/app"
11 "github.com/kujtimiihoxha/termai/internal/config"
12 "github.com/kujtimiihoxha/termai/internal/llm/models"
13 "github.com/kujtimiihoxha/termai/internal/llm/prompt"
14 "github.com/kujtimiihoxha/termai/internal/llm/provider"
15 "github.com/kujtimiihoxha/termai/internal/llm/tools"
16 "github.com/kujtimiihoxha/termai/internal/logging"
17 "github.com/kujtimiihoxha/termai/internal/message"
18)
19
// Agent is the entry point for producing assistant output in a chat session.
type Agent interface {
	// Generate handles content for the session identified by sessionID and
	// reports any failure as an error.
	// NOTE(review): the implementation is not in this chunk; it presumably
	// delegates to agent.generate (store user message, stream reply, run
	// tools) — confirm.
	Generate(ctx context.Context, sessionID, content string) error
}
23
24type agent struct {
25 *app.App
26 model models.Model
27 tools []tools.BaseTool
28 agent provider.Provider
29 titleGenerator provider.Provider
30}
31
32func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
33 response, err := c.titleGenerator.SendMessages(
34 ctx,
35 []message.Message{
36 {
37 Role: message.User,
38 Parts: []message.ContentPart{
39 message.TextContent{
40 Text: content,
41 },
42 },
43 },
44 },
45 nil,
46 )
47 if err != nil {
48 return
49 }
50
51 session, err := c.Sessions.Get(ctx, sessionID)
52 if err != nil {
53 return
54 }
55 if response.Content != "" {
56 session.Title = response.Content
57 session.Title = strings.TrimSpace(session.Title)
58 session.Title = strings.ReplaceAll(session.Title, "\n", " ")
59 c.Sessions.Save(ctx, session)
60 }
61}
62
63func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
64 session, err := c.Sessions.Get(ctx, sessionID)
65 if err != nil {
66 return err
67 }
68
69 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
70 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
71 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
72 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
73
74 session.Cost += cost
75 session.CompletionTokens += usage.OutputTokens
76 session.PromptTokens += usage.InputTokens
77
78 _, err = c.Sessions.Save(ctx, session)
79 return err
80}
81
82func (c *agent) processEvent(
83 ctx context.Context,
84 sessionID string,
85 assistantMsg *message.Message,
86 event provider.ProviderEvent,
87) error {
88 switch event.Type {
89 case provider.EventThinkingDelta:
90 assistantMsg.AppendReasoningContent(event.Content)
91 return c.Messages.Update(ctx, *assistantMsg)
92 case provider.EventContentDelta:
93 assistantMsg.AppendContent(event.Content)
94 return c.Messages.Update(ctx, *assistantMsg)
95 case provider.EventError:
96 if errors.Is(event.Error, context.Canceled) {
97 return nil
98 }
99 logging.ErrorPersist(event.Error.Error())
100 return event.Error
101 case provider.EventWarning:
102 logging.WarnPersist(event.Info)
103 return nil
104 case provider.EventInfo:
105 logging.InfoPersist(event.Info)
106 case provider.EventComplete:
107 assistantMsg.SetToolCalls(event.Response.ToolCalls)
108 assistantMsg.AddFinish(event.Response.FinishReason)
109 err := c.Messages.Update(ctx, *assistantMsg)
110 if err != nil {
111 return err
112 }
113 return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage)
114 }
115
116 return nil
117}
118
119func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
120 var wg sync.WaitGroup
121 toolResults := make([]message.ToolResult, len(toolCalls))
122 mutex := &sync.Mutex{}
123 errChan := make(chan error, 1)
124
125 // Create a child context that can be canceled
126 ctx, cancel := context.WithCancel(ctx)
127 defer cancel()
128
129 for i, tc := range toolCalls {
130 wg.Add(1)
131 go func(index int, toolCall message.ToolCall) {
132 defer wg.Done()
133
134 // Check if context is already canceled
135 select {
136 case <-ctx.Done():
137 mutex.Lock()
138 toolResults[index] = message.ToolResult{
139 ToolCallID: toolCall.ID,
140 Content: "Tool execution canceled",
141 IsError: true,
142 }
143 mutex.Unlock()
144
145 // Send cancellation error to error channel if it's empty
146 select {
147 case errChan <- ctx.Err():
148 default:
149 }
150 return
151 default:
152 }
153
154 response := ""
155 isError := false
156 found := false
157
158 for _, tool := range tls {
159 if tool.Info().Name == toolCall.Name {
160 found = true
161 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
162 ID: toolCall.ID,
163 Name: toolCall.Name,
164 Input: toolCall.Input,
165 })
166
167 if toolErr != nil {
168 if errors.Is(toolErr, context.Canceled) {
169 response = "Tool execution canceled"
170
171 // Send cancellation error to error channel if it's empty
172 select {
173 case errChan <- ctx.Err():
174 default:
175 }
176 } else {
177 response = fmt.Sprintf("error running tool: %s", toolErr)
178 }
179 isError = true
180 } else {
181 response = toolResult.Content
182 isError = toolResult.IsError
183 }
184 break
185 }
186 }
187
188 if !found {
189 response = fmt.Sprintf("tool not found: %s", toolCall.Name)
190 isError = true
191 }
192
193 mutex.Lock()
194 defer mutex.Unlock()
195
196 toolResults[index] = message.ToolResult{
197 ToolCallID: toolCall.ID,
198 Content: response,
199 IsError: isError,
200 }
201 }(i, tc)
202 }
203
204 // Wait for all goroutines to finish or context to be canceled
205 done := make(chan struct{})
206 go func() {
207 wg.Wait()
208 close(done)
209 }()
210
211 select {
212 case <-done:
213 // All tools completed successfully
214 case err := <-errChan:
215 // One of the tools encountered a cancellation
216 return toolResults, err
217 case <-ctx.Done():
218 // Context was canceled externally
219 return toolResults, ctx.Err()
220 }
221
222 return toolResults, nil
223}
224
225func (c *agent) handleToolExecution(
226 ctx context.Context,
227 assistantMsg message.Message,
228) (*message.Message, error) {
229 if len(assistantMsg.ToolCalls()) == 0 {
230 return nil, nil
231 }
232
233 toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
234 if err != nil {
235 return nil, err
236 }
237 parts := make([]message.ContentPart, 0)
238 for _, toolResult := range toolResults {
239 parts = append(parts, toolResult)
240 }
241 msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
242 Role: message.Tool,
243 Parts: parts,
244 })
245
246 return &msg, err
247}
248
249func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
250 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
251 messages, err := c.Messages.List(ctx, sessionID)
252 if err != nil {
253 return err
254 }
255
256 if len(messages) == 0 {
257 go c.handleTitleGeneration(ctx, sessionID, content)
258 }
259
260 userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
261 Role: message.User,
262 Parts: []message.ContentPart{
263 message.TextContent{
264 Text: content,
265 },
266 },
267 })
268 if err != nil {
269 return err
270 }
271
272 messages = append(messages, userMsg)
273 for {
274 select {
275 case <-ctx.Done():
276 assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
277 Role: message.Assistant,
278 Parts: []message.ContentPart{},
279 })
280 if err != nil {
281 return err
282 }
283 assistantMsg.AddFinish("canceled")
284 c.Messages.Update(ctx, assistantMsg)
285 return context.Canceled
286 default:
287 // Continue processing
288 }
289
290 eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
291 if err != nil {
292 if errors.Is(err, context.Canceled) {
293 assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
294 Role: message.Assistant,
295 Parts: []message.ContentPart{},
296 })
297 if err != nil {
298 return err
299 }
300 assistantMsg.AddFinish("canceled")
301 c.Messages.Update(ctx, assistantMsg)
302 return context.Canceled
303 }
304 return err
305 }
306
307 assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
308 Role: message.Assistant,
309 Parts: []message.ContentPart{},
310 Model: c.model.ID,
311 })
312 if err != nil {
313 return err
314 }
315
316 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
317 for event := range eventChan {
318 err = c.processEvent(ctx, sessionID, &assistantMsg, event)
319 if err != nil {
320 if errors.Is(err, context.Canceled) {
321 assistantMsg.AddFinish("canceled")
322 c.Messages.Update(ctx, assistantMsg)
323 return context.Canceled
324 }
325 assistantMsg.AddFinish("error:" + err.Error())
326 c.Messages.Update(ctx, assistantMsg)
327 return err
328 }
329
330 select {
331 case <-ctx.Done():
332 assistantMsg.AddFinish("canceled")
333 c.Messages.Update(ctx, assistantMsg)
334 return context.Canceled
335 default:
336 }
337 }
338
339 // Check for context cancellation before tool execution
340 select {
341 case <-ctx.Done():
342 assistantMsg.AddFinish("canceled")
343 c.Messages.Update(ctx, assistantMsg)
344 return context.Canceled
345 default:
346 // Continue processing
347 }
348
349 msg, err := c.handleToolExecution(ctx, assistantMsg)
350 if err != nil {
351 if errors.Is(err, context.Canceled) {
352 assistantMsg.AddFinish("canceled")
353 c.Messages.Update(ctx, assistantMsg)
354 return context.Canceled
355 }
356 return err
357 }
358
359 c.Messages.Update(ctx, assistantMsg)
360
361 if len(assistantMsg.ToolCalls()) == 0 {
362 break
363 }
364
365 messages = append(messages, assistantMsg)
366 if msg != nil {
367 messages = append(messages, *msg)
368 }
369
370 // Check for context cancellation after tool execution
371 select {
372 case <-ctx.Done():
373 assistantMsg.AddFinish("canceled")
374 c.Messages.Update(ctx, assistantMsg)
375 return context.Canceled
376 default:
377 // Continue processing
378 }
379 }
380 return nil
381}
382
383func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
384 maxTokens := config.Get().Model.CoderMaxTokens
385
386 providerConfig, ok := config.Get().Providers[model.Provider]
387 if !ok || providerConfig.Disabled {
388 return nil, nil, errors.New("provider is not enabled")
389 }
390 var agentProvider provider.Provider
391 var titleGenerator provider.Provider
392
393 switch model.Provider {
394 case models.ProviderOpenAI:
395 var err error
396 agentProvider, err = provider.NewOpenAIProvider(
397 provider.WithOpenAISystemMessage(
398 prompt.CoderOpenAISystemPrompt(),
399 ),
400 provider.WithOpenAIMaxTokens(maxTokens),
401 provider.WithOpenAIModel(model),
402 provider.WithOpenAIKey(providerConfig.APIKey),
403 )
404 if err != nil {
405 return nil, nil, err
406 }
407 titleGenerator, err = provider.NewOpenAIProvider(
408 provider.WithOpenAISystemMessage(
409 prompt.TitlePrompt(),
410 ),
411 provider.WithOpenAIMaxTokens(80),
412 provider.WithOpenAIModel(model),
413 provider.WithOpenAIKey(providerConfig.APIKey),
414 )
415 if err != nil {
416 return nil, nil, err
417 }
418 case models.ProviderAnthropic:
419 var err error
420 agentProvider, err = provider.NewAnthropicProvider(
421 provider.WithAnthropicSystemMessage(
422 prompt.CoderAnthropicSystemPrompt(),
423 ),
424 provider.WithAnthropicMaxTokens(maxTokens),
425 provider.WithAnthropicKey(providerConfig.APIKey),
426 provider.WithAnthropicModel(model),
427 )
428 if err != nil {
429 return nil, nil, err
430 }
431 titleGenerator, err = provider.NewAnthropicProvider(
432 provider.WithAnthropicSystemMessage(
433 prompt.TitlePrompt(),
434 ),
435 provider.WithAnthropicMaxTokens(80),
436 provider.WithAnthropicKey(providerConfig.APIKey),
437 provider.WithAnthropicModel(model),
438 )
439 if err != nil {
440 return nil, nil, err
441 }
442
443 case models.ProviderGemini:
444 var err error
445 agentProvider, err = provider.NewGeminiProvider(
446 ctx,
447 provider.WithGeminiSystemMessage(
448 prompt.CoderOpenAISystemPrompt(),
449 ),
450 provider.WithGeminiMaxTokens(int32(maxTokens)),
451 provider.WithGeminiKey(providerConfig.APIKey),
452 provider.WithGeminiModel(model),
453 )
454 if err != nil {
455 return nil, nil, err
456 }
457 titleGenerator, err = provider.NewGeminiProvider(
458 ctx,
459 provider.WithGeminiSystemMessage(
460 prompt.TitlePrompt(),
461 ),
462 provider.WithGeminiMaxTokens(80),
463 provider.WithGeminiKey(providerConfig.APIKey),
464 provider.WithGeminiModel(model),
465 )
466 if err != nil {
467 return nil, nil, err
468 }
469 case models.ProviderGROQ:
470 var err error
471 agentProvider, err = provider.NewOpenAIProvider(
472 provider.WithOpenAISystemMessage(
473 prompt.CoderAnthropicSystemPrompt(),
474 ),
475 provider.WithOpenAIMaxTokens(maxTokens),
476 provider.WithOpenAIModel(model),
477 provider.WithOpenAIKey(providerConfig.APIKey),
478 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
479 )
480 if err != nil {
481 return nil, nil, err
482 }
483 titleGenerator, err = provider.NewOpenAIProvider(
484 provider.WithOpenAISystemMessage(
485 prompt.TitlePrompt(),
486 ),
487 provider.WithOpenAIMaxTokens(80),
488 provider.WithOpenAIModel(model),
489 provider.WithOpenAIKey(providerConfig.APIKey),
490 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
491 )
492 if err != nil {
493 return nil, nil, err
494 }
495
496 case models.ProviderBedrock:
497 var err error
498 agentProvider, err = provider.NewBedrockProvider(
499 provider.WithBedrockSystemMessage(
500 prompt.CoderAnthropicSystemPrompt(),
501 ),
502 provider.WithBedrockMaxTokens(maxTokens),
503 provider.WithBedrockModel(model),
504 )
505 if err != nil {
506 return nil, nil, err
507 }
508 titleGenerator, err = provider.NewBedrockProvider(
509 provider.WithBedrockSystemMessage(
510 prompt.TitlePrompt(),
511 ),
512 provider.WithBedrockMaxTokens(maxTokens),
513 provider.WithBedrockModel(model),
514 )
515 if err != nil {
516 return nil, nil, err
517 }
518
519 }
520
521 return agentProvider, titleGenerator, nil
522}