agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/termai/internal/app"
 11	"github.com/kujtimiihoxha/termai/internal/config"
 12	"github.com/kujtimiihoxha/termai/internal/llm/models"
 13	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 14	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/logging"
 17	"github.com/kujtimiihoxha/termai/internal/message"
 18)
 19
 20type Agent interface {
 21	Generate(ctx context.Context, sessionID string, content string) error
 22}
 23
 24type agent struct {
 25	*app.App
 26	model          models.Model
 27	tools          []tools.BaseTool
 28	agent          provider.Provider
 29	titleGenerator provider.Provider
 30}
 31
 32func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
 33	response, err := c.titleGenerator.SendMessages(
 34		ctx,
 35		[]message.Message{
 36			{
 37				Role: message.User,
 38				Parts: []message.ContentPart{
 39					message.TextContent{
 40						Text: content,
 41					},
 42				},
 43			},
 44		},
 45		nil,
 46	)
 47	if err != nil {
 48		return
 49	}
 50
 51	session, err := c.Sessions.Get(ctx, sessionID)
 52	if err != nil {
 53		return
 54	}
 55	if response.Content != "" {
 56		session.Title = response.Content
 57		session.Title = strings.TrimSpace(session.Title)
 58		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
 59		c.Sessions.Save(ctx, session)
 60	}
 61}
 62
 63func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
 64	session, err := c.Sessions.Get(ctx, sessionID)
 65	if err != nil {
 66		return err
 67	}
 68
 69	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 70		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 71		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 72		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 73
 74	session.Cost += cost
 75	session.CompletionTokens += usage.OutputTokens
 76	session.PromptTokens += usage.InputTokens
 77
 78	_, err = c.Sessions.Save(ctx, session)
 79	return err
 80}
 81
 82func (c *agent) processEvent(
 83	ctx context.Context,
 84	sessionID string,
 85	assistantMsg *message.Message,
 86	event provider.ProviderEvent,
 87) error {
 88	switch event.Type {
 89	case provider.EventThinkingDelta:
 90		assistantMsg.AppendReasoningContent(event.Content)
 91		return c.Messages.Update(ctx, *assistantMsg)
 92	case provider.EventContentDelta:
 93		assistantMsg.AppendContent(event.Content)
 94		return c.Messages.Update(ctx, *assistantMsg)
 95	case provider.EventError:
 96		if errors.Is(event.Error, context.Canceled) {
 97			return nil
 98		}
 99		logging.ErrorPersist(event.Error.Error())
100		return event.Error
101	case provider.EventWarning:
102		logging.WarnPersist(event.Info)
103		return nil
104	case provider.EventInfo:
105		logging.InfoPersist(event.Info)
106	case provider.EventComplete:
107		assistantMsg.SetToolCalls(event.Response.ToolCalls)
108		assistantMsg.AddFinish(event.Response.FinishReason)
109		err := c.Messages.Update(ctx, *assistantMsg)
110		if err != nil {
111			return err
112		}
113		return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage)
114	}
115
116	return nil
117}
118
119func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
120	var wg sync.WaitGroup
121	toolResults := make([]message.ToolResult, len(toolCalls))
122	mutex := &sync.Mutex{}
123	errChan := make(chan error, 1)
124
125	// Create a child context that can be canceled
126	ctx, cancel := context.WithCancel(ctx)
127	defer cancel()
128
129	for i, tc := range toolCalls {
130		wg.Add(1)
131		go func(index int, toolCall message.ToolCall) {
132			defer wg.Done()
133
134			// Check if context is already canceled
135			select {
136			case <-ctx.Done():
137				mutex.Lock()
138				toolResults[index] = message.ToolResult{
139					ToolCallID: toolCall.ID,
140					Content:    "Tool execution canceled",
141					IsError:    true,
142				}
143				mutex.Unlock()
144
145				// Send cancellation error to error channel if it's empty
146				select {
147				case errChan <- ctx.Err():
148				default:
149				}
150				return
151			default:
152			}
153
154			response := ""
155			isError := false
156			found := false
157
158			for _, tool := range tls {
159				if tool.Info().Name == toolCall.Name {
160					found = true
161					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
162						ID:    toolCall.ID,
163						Name:  toolCall.Name,
164						Input: toolCall.Input,
165					})
166
167					if toolErr != nil {
168						if errors.Is(toolErr, context.Canceled) {
169							response = "Tool execution canceled"
170
171							// Send cancellation error to error channel if it's empty
172							select {
173							case errChan <- ctx.Err():
174							default:
175							}
176						} else {
177							response = fmt.Sprintf("error running tool: %s", toolErr)
178						}
179						isError = true
180					} else {
181						response = toolResult.Content
182						isError = toolResult.IsError
183					}
184					break
185				}
186			}
187
188			if !found {
189				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
190				isError = true
191			}
192
193			mutex.Lock()
194			defer mutex.Unlock()
195
196			toolResults[index] = message.ToolResult{
197				ToolCallID: toolCall.ID,
198				Content:    response,
199				IsError:    isError,
200			}
201		}(i, tc)
202	}
203
204	// Wait for all goroutines to finish or context to be canceled
205	done := make(chan struct{})
206	go func() {
207		wg.Wait()
208		close(done)
209	}()
210
211	select {
212	case <-done:
213		// All tools completed successfully
214	case err := <-errChan:
215		// One of the tools encountered a cancellation
216		return toolResults, err
217	case <-ctx.Done():
218		// Context was canceled externally
219		return toolResults, ctx.Err()
220	}
221
222	return toolResults, nil
223}
224
225func (c *agent) handleToolExecution(
226	ctx context.Context,
227	assistantMsg message.Message,
228) (*message.Message, error) {
229	if len(assistantMsg.ToolCalls()) == 0 {
230		return nil, nil
231	}
232
233	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
234	if err != nil {
235		return nil, err
236	}
237	parts := make([]message.ContentPart, 0)
238	for _, toolResult := range toolResults {
239		parts = append(parts, toolResult)
240	}
241	msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
242		Role:  message.Tool,
243		Parts: parts,
244	})
245
246	return &msg, err
247}
248
249func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
250	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
251	messages, err := c.Messages.List(ctx, sessionID)
252	if err != nil {
253		return err
254	}
255
256	if len(messages) == 0 {
257		go c.handleTitleGeneration(ctx, sessionID, content)
258	}
259
260	userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
261		Role: message.User,
262		Parts: []message.ContentPart{
263			message.TextContent{
264				Text: content,
265			},
266		},
267	})
268	if err != nil {
269		return err
270	}
271
272	messages = append(messages, userMsg)
273	for {
274		select {
275		case <-ctx.Done():
276			assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
277				Role:  message.Assistant,
278				Parts: []message.ContentPart{},
279			})
280			if err != nil {
281				return err
282			}
283			assistantMsg.AddFinish("canceled")
284			c.Messages.Update(ctx, assistantMsg)
285			return context.Canceled
286		default:
287			// Continue processing
288		}
289
290		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
291		if err != nil {
292			if errors.Is(err, context.Canceled) {
293				assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
294					Role:  message.Assistant,
295					Parts: []message.ContentPart{},
296				})
297				if err != nil {
298					return err
299				}
300				assistantMsg.AddFinish("canceled")
301				c.Messages.Update(ctx, assistantMsg)
302				return context.Canceled
303			}
304			return err
305		}
306
307		assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
308			Role:  message.Assistant,
309			Parts: []message.ContentPart{},
310			Model: c.model.ID,
311		})
312		if err != nil {
313			return err
314		}
315
316		ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
317		for event := range eventChan {
318			err = c.processEvent(ctx, sessionID, &assistantMsg, event)
319			if err != nil {
320				if errors.Is(err, context.Canceled) {
321					assistantMsg.AddFinish("canceled")
322					c.Messages.Update(ctx, assistantMsg)
323					return context.Canceled
324				}
325				assistantMsg.AddFinish("error:" + err.Error())
326				c.Messages.Update(ctx, assistantMsg)
327				return err
328			}
329
330			select {
331			case <-ctx.Done():
332				assistantMsg.AddFinish("canceled")
333				c.Messages.Update(ctx, assistantMsg)
334				return context.Canceled
335			default:
336			}
337		}
338
339		// Check for context cancellation before tool execution
340		select {
341		case <-ctx.Done():
342			assistantMsg.AddFinish("canceled")
343			c.Messages.Update(ctx, assistantMsg)
344			return context.Canceled
345		default:
346			// Continue processing
347		}
348
349		msg, err := c.handleToolExecution(ctx, assistantMsg)
350		if err != nil {
351			if errors.Is(err, context.Canceled) {
352				assistantMsg.AddFinish("canceled")
353				c.Messages.Update(ctx, assistantMsg)
354				return context.Canceled
355			}
356			return err
357		}
358
359		c.Messages.Update(ctx, assistantMsg)
360
361		if len(assistantMsg.ToolCalls()) == 0 {
362			break
363		}
364
365		messages = append(messages, assistantMsg)
366		if msg != nil {
367			messages = append(messages, *msg)
368		}
369
370		// Check for context cancellation after tool execution
371		select {
372		case <-ctx.Done():
373			assistantMsg.AddFinish("canceled")
374			c.Messages.Update(ctx, assistantMsg)
375			return context.Canceled
376		default:
377			// Continue processing
378		}
379	}
380	return nil
381}
382
383func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
384	maxTokens := config.Get().Model.CoderMaxTokens
385
386	providerConfig, ok := config.Get().Providers[model.Provider]
387	if !ok || providerConfig.Disabled {
388		return nil, nil, errors.New("provider is not enabled")
389	}
390	var agentProvider provider.Provider
391	var titleGenerator provider.Provider
392
393	switch model.Provider {
394	case models.ProviderOpenAI:
395		var err error
396		agentProvider, err = provider.NewOpenAIProvider(
397			provider.WithOpenAISystemMessage(
398				prompt.CoderOpenAISystemPrompt(),
399			),
400			provider.WithOpenAIMaxTokens(maxTokens),
401			provider.WithOpenAIModel(model),
402			provider.WithOpenAIKey(providerConfig.APIKey),
403		)
404		if err != nil {
405			return nil, nil, err
406		}
407		titleGenerator, err = provider.NewOpenAIProvider(
408			provider.WithOpenAISystemMessage(
409				prompt.TitlePrompt(),
410			),
411			provider.WithOpenAIMaxTokens(80),
412			provider.WithOpenAIModel(model),
413			provider.WithOpenAIKey(providerConfig.APIKey),
414		)
415		if err != nil {
416			return nil, nil, err
417		}
418	case models.ProviderAnthropic:
419		var err error
420		agentProvider, err = provider.NewAnthropicProvider(
421			provider.WithAnthropicSystemMessage(
422				prompt.CoderAnthropicSystemPrompt(),
423			),
424			provider.WithAnthropicMaxTokens(maxTokens),
425			provider.WithAnthropicKey(providerConfig.APIKey),
426			provider.WithAnthropicModel(model),
427		)
428		if err != nil {
429			return nil, nil, err
430		}
431		titleGenerator, err = provider.NewAnthropicProvider(
432			provider.WithAnthropicSystemMessage(
433				prompt.TitlePrompt(),
434			),
435			provider.WithAnthropicMaxTokens(80),
436			provider.WithAnthropicKey(providerConfig.APIKey),
437			provider.WithAnthropicModel(model),
438		)
439		if err != nil {
440			return nil, nil, err
441		}
442
443	case models.ProviderGemini:
444		var err error
445		agentProvider, err = provider.NewGeminiProvider(
446			ctx,
447			provider.WithGeminiSystemMessage(
448				prompt.CoderOpenAISystemPrompt(),
449			),
450			provider.WithGeminiMaxTokens(int32(maxTokens)),
451			provider.WithGeminiKey(providerConfig.APIKey),
452			provider.WithGeminiModel(model),
453		)
454		if err != nil {
455			return nil, nil, err
456		}
457		titleGenerator, err = provider.NewGeminiProvider(
458			ctx,
459			provider.WithGeminiSystemMessage(
460				prompt.TitlePrompt(),
461			),
462			provider.WithGeminiMaxTokens(80),
463			provider.WithGeminiKey(providerConfig.APIKey),
464			provider.WithGeminiModel(model),
465		)
466		if err != nil {
467			return nil, nil, err
468		}
469	case models.ProviderGROQ:
470		var err error
471		agentProvider, err = provider.NewOpenAIProvider(
472			provider.WithOpenAISystemMessage(
473				prompt.CoderAnthropicSystemPrompt(),
474			),
475			provider.WithOpenAIMaxTokens(maxTokens),
476			provider.WithOpenAIModel(model),
477			provider.WithOpenAIKey(providerConfig.APIKey),
478			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
479		)
480		if err != nil {
481			return nil, nil, err
482		}
483		titleGenerator, err = provider.NewOpenAIProvider(
484			provider.WithOpenAISystemMessage(
485				prompt.TitlePrompt(),
486			),
487			provider.WithOpenAIMaxTokens(80),
488			provider.WithOpenAIModel(model),
489			provider.WithOpenAIKey(providerConfig.APIKey),
490			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
491		)
492		if err != nil {
493			return nil, nil, err
494		}
495
496	case models.ProviderBedrock:
497		var err error
498		agentProvider, err = provider.NewBedrockProvider(
499			provider.WithBedrockSystemMessage(
500				prompt.CoderAnthropicSystemPrompt(),
501			),
502			provider.WithBedrockMaxTokens(maxTokens),
503			provider.WithBedrockModel(model),
504		)
505		if err != nil {
506			return nil, nil, err
507		}
508		titleGenerator, err = provider.NewBedrockProvider(
509			provider.WithBedrockSystemMessage(
510				prompt.TitlePrompt(),
511			),
512			provider.WithBedrockMaxTokens(maxTokens),
513			provider.WithBedrockModel(model),
514		)
515		if err != nil {
516			return nil, nil, err
517		}
518
519	}
520
521	return agentProvider, titleGenerator, nil
522}