agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/termai/internal/app"
 11	"github.com/kujtimiihoxha/termai/internal/config"
 12	"github.com/kujtimiihoxha/termai/internal/llm/models"
 13	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 14	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/logging"
 17	"github.com/kujtimiihoxha/termai/internal/message"
 18)
 19
 20type Agent interface {
 21	Generate(ctx context.Context, sessionID string, content string) error
 22}
 23
// agent bundles the application services with the model, tools and providers
// that the methods in this file use to run a conversation.
type agent struct {
	// Embedded application; the methods below reach c.Sessions and c.Messages
	// through it.
	*app.App
	// model is recorded on assistant messages (its ID) and used to price token
	// usage in TrackUsage.
	model          models.Model
	// tools is the tool set offered to the provider and searched by name when
	// executing tool calls.
	tools          []tools.BaseTool
	// agent is the provider that streams the main assistant response.
	agent          provider.Provider
	// titleGenerator is the provider used to produce a session title.
	titleGenerator provider.Provider
}
 31
 32func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
 33	response, err := c.titleGenerator.SendMessages(
 34		ctx,
 35		[]message.Message{
 36			{
 37				Role: message.User,
 38				Parts: []message.ContentPart{
 39					message.TextContent{
 40						Text: content,
 41					},
 42				},
 43			},
 44		},
 45		nil,
 46	)
 47	if err != nil {
 48		return
 49	}
 50
 51	session, err := c.Sessions.Get(sessionID)
 52	if err != nil {
 53		return
 54	}
 55	if response.Content != "" {
 56		session.Title = response.Content
 57		session.Title = strings.TrimSpace(session.Title)
 58		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
 59		c.Sessions.Save(session)
 60	}
 61}
 62
 63func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
 64	session, err := c.Sessions.Get(sessionID)
 65	if err != nil {
 66		return err
 67	}
 68
 69	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 70		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 71		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 72		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 73
 74	session.Cost += cost
 75	session.CompletionTokens += usage.OutputTokens
 76	session.PromptTokens += usage.InputTokens
 77
 78	_, err = c.Sessions.Save(session)
 79	return err
 80}
 81
 82func (c *agent) processEvent(
 83	sessionID string,
 84	assistantMsg *message.Message,
 85	event provider.ProviderEvent,
 86) error {
 87	switch event.Type {
 88	case provider.EventThinkingDelta:
 89		assistantMsg.AppendReasoningContent(event.Content)
 90		return c.Messages.Update(*assistantMsg)
 91	case provider.EventContentDelta:
 92		assistantMsg.AppendContent(event.Content)
 93		return c.Messages.Update(*assistantMsg)
 94	case provider.EventError:
 95		if errors.Is(event.Error, context.Canceled) {
 96			return nil
 97		}
 98		logging.ErrorPersist(event.Error.Error())
 99		return event.Error
100	case provider.EventWarning:
101		logging.WarnPersist(event.Info)
102		return nil
103	case provider.EventInfo:
104		logging.InfoPersist(event.Info)
105	case provider.EventComplete:
106		assistantMsg.SetToolCalls(event.Response.ToolCalls)
107		assistantMsg.AddFinish(event.Response.FinishReason)
108		err := c.Messages.Update(*assistantMsg)
109		if err != nil {
110			return err
111		}
112		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
113	}
114
115	return nil
116}
117
118func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
119	var wg sync.WaitGroup
120	toolResults := make([]message.ToolResult, len(toolCalls))
121	mutex := &sync.Mutex{}
122	errChan := make(chan error, 1)
123
124	// Create a child context that can be canceled
125	ctx, cancel := context.WithCancel(ctx)
126	defer cancel()
127
128	for i, tc := range toolCalls {
129		wg.Add(1)
130		go func(index int, toolCall message.ToolCall) {
131			defer wg.Done()
132
133			// Check if context is already canceled
134			select {
135			case <-ctx.Done():
136				mutex.Lock()
137				toolResults[index] = message.ToolResult{
138					ToolCallID: toolCall.ID,
139					Content:    "Tool execution canceled",
140					IsError:    true,
141				}
142				mutex.Unlock()
143
144				// Send cancellation error to error channel if it's empty
145				select {
146				case errChan <- ctx.Err():
147				default:
148				}
149				return
150			default:
151			}
152
153			response := ""
154			isError := false
155			found := false
156
157			for _, tool := range tls {
158				if tool.Info().Name == toolCall.Name {
159					found = true
160					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
161						ID:    toolCall.ID,
162						Name:  toolCall.Name,
163						Input: toolCall.Input,
164					})
165
166					if toolErr != nil {
167						if errors.Is(toolErr, context.Canceled) {
168							response = "Tool execution canceled"
169
170							// Send cancellation error to error channel if it's empty
171							select {
172							case errChan <- ctx.Err():
173							default:
174							}
175						} else {
176							response = fmt.Sprintf("error running tool: %s", toolErr)
177						}
178						isError = true
179					} else {
180						response = toolResult.Content
181						isError = toolResult.IsError
182					}
183					break
184				}
185			}
186
187			if !found {
188				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
189				isError = true
190			}
191
192			mutex.Lock()
193			defer mutex.Unlock()
194
195			toolResults[index] = message.ToolResult{
196				ToolCallID: toolCall.ID,
197				Content:    response,
198				IsError:    isError,
199			}
200		}(i, tc)
201	}
202
203	// Wait for all goroutines to finish or context to be canceled
204	done := make(chan struct{})
205	go func() {
206		wg.Wait()
207		close(done)
208	}()
209
210	select {
211	case <-done:
212		// All tools completed successfully
213	case err := <-errChan:
214		// One of the tools encountered a cancellation
215		return toolResults, err
216	case <-ctx.Done():
217		// Context was canceled externally
218		return toolResults, ctx.Err()
219	}
220
221	return toolResults, nil
222}
223
224func (c *agent) handleToolExecution(
225	ctx context.Context,
226	assistantMsg message.Message,
227) (*message.Message, error) {
228	if len(assistantMsg.ToolCalls()) == 0 {
229		return nil, nil
230	}
231
232	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
233	if err != nil {
234		return nil, err
235	}
236	parts := make([]message.ContentPart, 0)
237	for _, toolResult := range toolResults {
238		parts = append(parts, toolResult)
239	}
240	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
241		Role:  message.Tool,
242		Parts: parts,
243	})
244
245	return &msg, err
246}
247
248func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
249	messages, err := c.Messages.List(sessionID)
250	if err != nil {
251		return err
252	}
253
254	if len(messages) == 0 {
255		go c.handleTitleGeneration(ctx, sessionID, content)
256	}
257
258	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
259		Role: message.User,
260		Parts: []message.ContentPart{
261			message.TextContent{
262				Text: content,
263			},
264		},
265	})
266	if err != nil {
267		return err
268	}
269
270	messages = append(messages, userMsg)
271	for {
272		select {
273		case <-ctx.Done():
274			assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
275				Role:  message.Assistant,
276				Parts: []message.ContentPart{},
277			})
278			if err != nil {
279				return err
280			}
281			assistantMsg.AddFinish("canceled")
282			c.Messages.Update(assistantMsg)
283			return context.Canceled
284		default:
285			// Continue processing
286		}
287
288		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
289		if err != nil {
290			if errors.Is(err, context.Canceled) {
291				assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
292					Role:  message.Assistant,
293					Parts: []message.ContentPart{},
294				})
295				if err != nil {
296					return err
297				}
298				assistantMsg.AddFinish("canceled")
299				c.Messages.Update(assistantMsg)
300				return context.Canceled
301			}
302			return err
303		}
304
305		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
306			Role:  message.Assistant,
307			Parts: []message.ContentPart{},
308			Model: c.model.ID,
309		})
310		if err != nil {
311			return err
312		}
313		for event := range eventChan {
314			err = c.processEvent(sessionID, &assistantMsg, event)
315			if err != nil {
316				if errors.Is(err, context.Canceled) {
317					assistantMsg.AddFinish("canceled")
318					c.Messages.Update(assistantMsg)
319					return context.Canceled
320				}
321				assistantMsg.AddFinish("error:" + err.Error())
322				c.Messages.Update(assistantMsg)
323				return err
324			}
325
326			select {
327			case <-ctx.Done():
328				assistantMsg.AddFinish("canceled")
329				c.Messages.Update(assistantMsg)
330				return context.Canceled
331			default:
332			}
333		}
334
335		// Check for context cancellation before tool execution
336		select {
337		case <-ctx.Done():
338			assistantMsg.AddFinish("canceled")
339			c.Messages.Update(assistantMsg)
340			return context.Canceled
341		default:
342			// Continue processing
343		}
344
345		msg, err := c.handleToolExecution(ctx, assistantMsg)
346		if err != nil {
347			if errors.Is(err, context.Canceled) {
348				assistantMsg.AddFinish("canceled")
349				c.Messages.Update(assistantMsg)
350				return context.Canceled
351			}
352			return err
353		}
354
355		c.Messages.Update(assistantMsg)
356
357		if len(assistantMsg.ToolCalls()) == 0 {
358			break
359		}
360
361		messages = append(messages, assistantMsg)
362		if msg != nil {
363			messages = append(messages, *msg)
364		}
365
366		// Check for context cancellation after tool execution
367		select {
368		case <-ctx.Done():
369			assistantMsg.AddFinish("canceled")
370			c.Messages.Update(assistantMsg)
371			return context.Canceled
372		default:
373			// Continue processing
374		}
375	}
376	return nil
377}
378
379func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
380	maxTokens := config.Get().Model.CoderMaxTokens
381
382	providerConfig, ok := config.Get().Providers[model.Provider]
383	if !ok || !providerConfig.Enabled {
384		return nil, nil, errors.New("provider is not enabled")
385	}
386	var agentProvider provider.Provider
387	var titleGenerator provider.Provider
388
389	switch model.Provider {
390	case models.ProviderOpenAI:
391		var err error
392		agentProvider, err = provider.NewOpenAIProvider(
393			provider.WithOpenAISystemMessage(
394				prompt.CoderOpenAISystemPrompt(),
395			),
396			provider.WithOpenAIMaxTokens(maxTokens),
397			provider.WithOpenAIModel(model),
398			provider.WithOpenAIKey(providerConfig.APIKey),
399		)
400		if err != nil {
401			return nil, nil, err
402		}
403		titleGenerator, err = provider.NewOpenAIProvider(
404			provider.WithOpenAISystemMessage(
405				prompt.TitlePrompt(),
406			),
407			provider.WithOpenAIMaxTokens(80),
408			provider.WithOpenAIModel(model),
409			provider.WithOpenAIKey(providerConfig.APIKey),
410		)
411		if err != nil {
412			return nil, nil, err
413		}
414	case models.ProviderAnthropic:
415		var err error
416		agentProvider, err = provider.NewAnthropicProvider(
417			provider.WithAnthropicSystemMessage(
418				prompt.CoderAnthropicSystemPrompt(),
419			),
420			provider.WithAnthropicMaxTokens(maxTokens),
421			provider.WithAnthropicKey(providerConfig.APIKey),
422			provider.WithAnthropicModel(model),
423		)
424		if err != nil {
425			return nil, nil, err
426		}
427		titleGenerator, err = provider.NewAnthropicProvider(
428			provider.WithAnthropicSystemMessage(
429				prompt.TitlePrompt(),
430			),
431			provider.WithAnthropicMaxTokens(80),
432			provider.WithAnthropicKey(providerConfig.APIKey),
433			provider.WithAnthropicModel(model),
434		)
435		if err != nil {
436			return nil, nil, err
437		}
438
439	case models.ProviderGemini:
440		var err error
441		agentProvider, err = provider.NewGeminiProvider(
442			ctx,
443			provider.WithGeminiSystemMessage(
444				prompt.CoderOpenAISystemPrompt(),
445			),
446			provider.WithGeminiMaxTokens(int32(maxTokens)),
447			provider.WithGeminiKey(providerConfig.APIKey),
448			provider.WithGeminiModel(model),
449		)
450		if err != nil {
451			return nil, nil, err
452		}
453		titleGenerator, err = provider.NewGeminiProvider(
454			ctx,
455			provider.WithGeminiSystemMessage(
456				prompt.TitlePrompt(),
457			),
458			provider.WithGeminiMaxTokens(80),
459			provider.WithGeminiKey(providerConfig.APIKey),
460			provider.WithGeminiModel(model),
461		)
462		if err != nil {
463			return nil, nil, err
464		}
465	case models.ProviderGROQ:
466		var err error
467		agentProvider, err = provider.NewOpenAIProvider(
468			provider.WithOpenAISystemMessage(
469				prompt.CoderAnthropicSystemPrompt(),
470			),
471			provider.WithOpenAIMaxTokens(maxTokens),
472			provider.WithOpenAIModel(model),
473			provider.WithOpenAIKey(providerConfig.APIKey),
474			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
475		)
476		if err != nil {
477			return nil, nil, err
478		}
479		titleGenerator, err = provider.NewOpenAIProvider(
480			provider.WithOpenAISystemMessage(
481				prompt.TitlePrompt(),
482			),
483			provider.WithOpenAIMaxTokens(80),
484			provider.WithOpenAIModel(model),
485			provider.WithOpenAIKey(providerConfig.APIKey),
486			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
487		)
488		if err != nil {
489			return nil, nil, err
490		}
491
492	case models.ProviderBedrock:
493		var err error
494		agentProvider, err = provider.NewBedrockProvider(
495			provider.WithBedrockSystemMessage(
496				prompt.CoderAnthropicSystemPrompt(),
497			),
498			provider.WithBedrockMaxTokens(maxTokens),
499			provider.WithBedrockModel(model),
500		)
501		if err != nil {
502			return nil, nil, err
503		}
504		titleGenerator, err = provider.NewBedrockProvider(
505			provider.WithBedrockSystemMessage(
506				prompt.TitlePrompt(),
507			),
508			provider.WithBedrockMaxTokens(maxTokens),
509			provider.WithBedrockModel(model),
510		)
511		if err != nil {
512			return nil, nil, err
513		}
514
515	}
516
517	return agentProvider, titleGenerator, nil
518}