agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/termai/internal/app"
 11	"github.com/kujtimiihoxha/termai/internal/config"
 12	"github.com/kujtimiihoxha/termai/internal/llm/models"
 13	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 14	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/logging"
 17	"github.com/kujtimiihoxha/termai/internal/message"
 18)
 19
 20type Agent interface {
 21	Generate(ctx context.Context, sessionID string, content string) error
 22}
 23
 24type agent struct {
 25	*app.App
 26	model          models.Model
 27	tools          []tools.BaseTool
 28	agent          provider.Provider
 29	titleGenerator provider.Provider
 30}
 31
 32func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
 33	response, err := c.titleGenerator.SendMessages(
 34		ctx,
 35		[]message.Message{
 36			{
 37				Role: message.User,
 38				Parts: []message.ContentPart{
 39					message.TextContent{
 40						Text: content,
 41					},
 42				},
 43			},
 44		},
 45		nil,
 46	)
 47	if err != nil {
 48		return
 49	}
 50
 51	session, err := c.Sessions.Get(sessionID)
 52	if err != nil {
 53		return
 54	}
 55	if response.Content != "" {
 56		session.Title = response.Content
 57		session.Title = strings.TrimSpace(session.Title)
 58		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
 59		c.Sessions.Save(session)
 60	}
 61}
 62
 63func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
 64	session, err := c.Sessions.Get(sessionID)
 65	if err != nil {
 66		return err
 67	}
 68
 69	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 70		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 71		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 72		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 73
 74	session.Cost += cost
 75	session.CompletionTokens += usage.OutputTokens
 76	session.PromptTokens += usage.InputTokens
 77
 78	_, err = c.Sessions.Save(session)
 79	return err
 80}
 81
 82func (c *agent) processEvent(
 83	sessionID string,
 84	assistantMsg *message.Message,
 85	event provider.ProviderEvent,
 86) error {
 87	switch event.Type {
 88	case provider.EventThinkingDelta:
 89		assistantMsg.AppendReasoningContent(event.Content)
 90		return c.Messages.Update(*assistantMsg)
 91	case provider.EventContentDelta:
 92		assistantMsg.AppendContent(event.Content)
 93		return c.Messages.Update(*assistantMsg)
 94	case provider.EventError:
 95		if errors.Is(event.Error, context.Canceled) {
 96			return nil
 97		}
 98		logging.ErrorPersist(event.Error.Error())
 99		return event.Error
100	case provider.EventWarning:
101		logging.WarnPersist(event.Info)
102		return nil
103	case provider.EventInfo:
104		logging.InfoPersist(event.Info)
105	case provider.EventComplete:
106		assistantMsg.SetToolCalls(event.Response.ToolCalls)
107		assistantMsg.AddFinish(event.Response.FinishReason)
108		err := c.Messages.Update(*assistantMsg)
109		if err != nil {
110			return err
111		}
112		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
113	}
114
115	return nil
116}
117
118func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
119	var wg sync.WaitGroup
120	toolResults := make([]message.ToolResult, len(toolCalls))
121	mutex := &sync.Mutex{}
122	errChan := make(chan error, 1)
123
124	// Create a child context that can be canceled
125	ctx, cancel := context.WithCancel(ctx)
126	defer cancel()
127
128	for i, tc := range toolCalls {
129		wg.Add(1)
130		go func(index int, toolCall message.ToolCall) {
131			defer wg.Done()
132
133			// Check if context is already canceled
134			select {
135			case <-ctx.Done():
136				mutex.Lock()
137				toolResults[index] = message.ToolResult{
138					ToolCallID: toolCall.ID,
139					Content:    "Tool execution canceled",
140					IsError:    true,
141				}
142				mutex.Unlock()
143
144				// Send cancellation error to error channel if it's empty
145				select {
146				case errChan <- ctx.Err():
147				default:
148				}
149				return
150			default:
151			}
152
153			response := ""
154			isError := false
155			found := false
156
157			for _, tool := range tls {
158				if tool.Info().Name == toolCall.Name {
159					found = true
160					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
161						ID:    toolCall.ID,
162						Name:  toolCall.Name,
163						Input: toolCall.Input,
164					})
165
166					if toolErr != nil {
167						if errors.Is(toolErr, context.Canceled) {
168							response = "Tool execution canceled"
169
170							// Send cancellation error to error channel if it's empty
171							select {
172							case errChan <- ctx.Err():
173							default:
174							}
175						} else {
176							response = fmt.Sprintf("error running tool: %s", toolErr)
177						}
178						isError = true
179					} else {
180						response = toolResult.Content
181						isError = toolResult.IsError
182					}
183					break
184				}
185			}
186
187			if !found {
188				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
189				isError = true
190			}
191
192			mutex.Lock()
193			defer mutex.Unlock()
194
195			toolResults[index] = message.ToolResult{
196				ToolCallID: toolCall.ID,
197				Content:    response,
198				IsError:    isError,
199			}
200		}(i, tc)
201	}
202
203	// Wait for all goroutines to finish or context to be canceled
204	done := make(chan struct{})
205	go func() {
206		wg.Wait()
207		close(done)
208	}()
209
210	select {
211	case <-done:
212		// All tools completed successfully
213	case err := <-errChan:
214		// One of the tools encountered a cancellation
215		return toolResults, err
216	case <-ctx.Done():
217		// Context was canceled externally
218		return toolResults, ctx.Err()
219	}
220
221	return toolResults, nil
222}
223
224func (c *agent) handleToolExecution(
225	ctx context.Context,
226	assistantMsg message.Message,
227) (*message.Message, error) {
228	if len(assistantMsg.ToolCalls()) == 0 {
229		return nil, nil
230	}
231
232	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
233	if err != nil {
234		return nil, err
235	}
236	parts := make([]message.ContentPart, 0)
237	for _, toolResult := range toolResults {
238		parts = append(parts, toolResult)
239	}
240	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
241		Role:  message.Tool,
242		Parts: parts,
243	})
244
245	return &msg, err
246}
247
248func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
249	messages, err := c.Messages.List(sessionID)
250	if err != nil {
251		return err
252	}
253
254	if len(messages) == 0 {
255		go c.handleTitleGeneration(ctx, sessionID, content)
256	}
257
258	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
259		Role: message.User,
260		Parts: []message.ContentPart{
261			message.TextContent{
262				Text: content,
263			},
264		},
265	})
266	if err != nil {
267		return err
268	}
269
270	messages = append(messages, userMsg)
271	for {
272		select {
273		case <-ctx.Done():
274			assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
275				Role:  message.Assistant,
276				Parts: []message.ContentPart{},
277			})
278			if err != nil {
279				return err
280			}
281			assistantMsg.AddFinish("canceled")
282			c.Messages.Update(assistantMsg)
283			return context.Canceled
284		default:
285			// Continue processing
286		}
287
288		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
289		if err != nil {
290			if errors.Is(err, context.Canceled) {
291				assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
292					Role:  message.Assistant,
293					Parts: []message.ContentPart{},
294				})
295				if err != nil {
296					return err
297				}
298				assistantMsg.AddFinish("canceled")
299				c.Messages.Update(assistantMsg)
300				return context.Canceled
301			}
302			return err
303		}
304
305		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
306			Role:  message.Assistant,
307			Parts: []message.ContentPart{},
308		})
309		if err != nil {
310			return err
311		}
312		for event := range eventChan {
313			err = c.processEvent(sessionID, &assistantMsg, event)
314			if err != nil {
315				if errors.Is(err, context.Canceled) {
316					assistantMsg.AddFinish("canceled")
317					c.Messages.Update(assistantMsg)
318					return context.Canceled
319				}
320				assistantMsg.AddFinish("error:" + err.Error())
321				c.Messages.Update(assistantMsg)
322				return err
323			}
324
325			select {
326			case <-ctx.Done():
327				assistantMsg.AddFinish("canceled")
328				c.Messages.Update(assistantMsg)
329				return context.Canceled
330			default:
331			}
332		}
333
334		// Check for context cancellation before tool execution
335		select {
336		case <-ctx.Done():
337			assistantMsg.AddFinish("canceled")
338			c.Messages.Update(assistantMsg)
339			return context.Canceled
340		default:
341			// Continue processing
342		}
343
344		msg, err := c.handleToolExecution(ctx, assistantMsg)
345		if err != nil {
346			if errors.Is(err, context.Canceled) {
347				assistantMsg.AddFinish("canceled")
348				c.Messages.Update(assistantMsg)
349				return context.Canceled
350			}
351			return err
352		}
353
354		c.Messages.Update(assistantMsg)
355
356		if len(assistantMsg.ToolCalls()) == 0 {
357			break
358		}
359
360		messages = append(messages, assistantMsg)
361		if msg != nil {
362			messages = append(messages, *msg)
363		}
364
365		// Check for context cancellation after tool execution
366		select {
367		case <-ctx.Done():
368			assistantMsg.AddFinish("canceled")
369			c.Messages.Update(assistantMsg)
370			return context.Canceled
371		default:
372			// Continue processing
373		}
374	}
375	return nil
376}
377
378func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
379	maxTokens := config.Get().Model.CoderMaxTokens
380
381	providerConfig, ok := config.Get().Providers[model.Provider]
382	if !ok || !providerConfig.Enabled {
383		return nil, nil, errors.New("provider is not enabled")
384	}
385	var agentProvider provider.Provider
386	var titleGenerator provider.Provider
387
388	switch model.Provider {
389	case models.ProviderOpenAI:
390		var err error
391		agentProvider, err = provider.NewOpenAIProvider(
392			provider.WithOpenAISystemMessage(
393				prompt.CoderOpenAISystemPrompt(),
394			),
395			provider.WithOpenAIMaxTokens(maxTokens),
396			provider.WithOpenAIModel(model),
397			provider.WithOpenAIKey(providerConfig.APIKey),
398		)
399		if err != nil {
400			return nil, nil, err
401		}
402		titleGenerator, err = provider.NewOpenAIProvider(
403			provider.WithOpenAISystemMessage(
404				prompt.TitlePrompt(),
405			),
406			provider.WithOpenAIMaxTokens(80),
407			provider.WithOpenAIModel(model),
408			provider.WithOpenAIKey(providerConfig.APIKey),
409		)
410		if err != nil {
411			return nil, nil, err
412		}
413	case models.ProviderAnthropic:
414		var err error
415		agentProvider, err = provider.NewAnthropicProvider(
416			provider.WithAnthropicSystemMessage(
417				prompt.CoderAnthropicSystemPrompt(),
418			),
419			provider.WithAnthropicMaxTokens(maxTokens),
420			provider.WithAnthropicKey(providerConfig.APIKey),
421			provider.WithAnthropicModel(model),
422		)
423		if err != nil {
424			return nil, nil, err
425		}
426		titleGenerator, err = provider.NewAnthropicProvider(
427			provider.WithAnthropicSystemMessage(
428				prompt.TitlePrompt(),
429			),
430			provider.WithAnthropicMaxTokens(80),
431			provider.WithAnthropicKey(providerConfig.APIKey),
432			provider.WithAnthropicModel(model),
433		)
434		if err != nil {
435			return nil, nil, err
436		}
437
438	case models.ProviderGemini:
439		var err error
440		agentProvider, err = provider.NewGeminiProvider(
441			ctx,
442			provider.WithGeminiSystemMessage(
443				prompt.CoderOpenAISystemPrompt(),
444			),
445			provider.WithGeminiMaxTokens(int32(maxTokens)),
446			provider.WithGeminiKey(providerConfig.APIKey),
447			provider.WithGeminiModel(model),
448		)
449		if err != nil {
450			return nil, nil, err
451		}
452		titleGenerator, err = provider.NewGeminiProvider(
453			ctx,
454			provider.WithGeminiSystemMessage(
455				prompt.TitlePrompt(),
456			),
457			provider.WithGeminiMaxTokens(80),
458			provider.WithGeminiKey(providerConfig.APIKey),
459			provider.WithGeminiModel(model),
460		)
461		if err != nil {
462			return nil, nil, err
463		}
464	case models.ProviderGROQ:
465		var err error
466		agentProvider, err = provider.NewOpenAIProvider(
467			provider.WithOpenAISystemMessage(
468				prompt.CoderAnthropicSystemPrompt(),
469			),
470			provider.WithOpenAIMaxTokens(maxTokens),
471			provider.WithOpenAIModel(model),
472			provider.WithOpenAIKey(providerConfig.APIKey),
473			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
474		)
475		if err != nil {
476			return nil, nil, err
477		}
478		titleGenerator, err = provider.NewOpenAIProvider(
479			provider.WithOpenAISystemMessage(
480				prompt.TitlePrompt(),
481			),
482			provider.WithOpenAIMaxTokens(80),
483			provider.WithOpenAIModel(model),
484			provider.WithOpenAIKey(providerConfig.APIKey),
485			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
486		)
487		if err != nil {
488			return nil, nil, err
489		}
490
491	case models.ProviderBedrock:
492		var err error
493		agentProvider, err = provider.NewBedrockProvider(
494			provider.WithBedrockSystemMessage(
495				prompt.CoderAnthropicSystemPrompt(),
496			),
497			provider.WithBedrockMaxTokens(maxTokens),
498			provider.WithBedrockModel(model),
499		)
500		if err != nil {
501			return nil, nil, err
502		}
503		titleGenerator, err = provider.NewBedrockProvider(
504			provider.WithBedrockSystemMessage(
505				prompt.TitlePrompt(),
506			),
507			provider.WithBedrockMaxTokens(maxTokens),
508			provider.WithBedrockModel(model),
509		)
510		if err != nil {
511			return nil, nil, err
512		}
513
514	}
515
516	return agentProvider, titleGenerator, nil
517}