agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/termai/internal/app"
 11	"github.com/kujtimiihoxha/termai/internal/config"
 12	"github.com/kujtimiihoxha/termai/internal/llm/models"
 13	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 14	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17)
 18
 19type Agent interface {
 20	Generate(sessionID string, content string) error
 21}
 22
 23type agent struct {
 24	*app.App
 25	model          models.Model
 26	tools          []tools.BaseTool
 27	agent          provider.Provider
 28	titleGenerator provider.Provider
 29}
 30
 31func (c *agent) handleTitleGeneration(sessionID, content string) {
 32	response, err := c.titleGenerator.SendMessages(
 33		c.Context,
 34		[]message.Message{
 35			{
 36				Role: message.User,
 37				Parts: []message.ContentPart{
 38					message.TextContent{
 39						Text: content,
 40					},
 41				},
 42			},
 43		},
 44		nil,
 45	)
 46	if err != nil {
 47		return
 48	}
 49
 50	session, err := c.Sessions.Get(sessionID)
 51	if err != nil {
 52		return
 53	}
 54	if response.Content != "" {
 55		session.Title = response.Content
 56		session.Title = strings.TrimSpace(session.Title)
 57		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
 58		c.Sessions.Save(session)
 59	}
 60}
 61
 62func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
 63	session, err := c.Sessions.Get(sessionID)
 64	if err != nil {
 65		return err
 66	}
 67
 68	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 69		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 70		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 71		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 72
 73	session.Cost += cost
 74	session.CompletionTokens += usage.OutputTokens
 75	session.PromptTokens += usage.InputTokens
 76
 77	_, err = c.Sessions.Save(session)
 78	return err
 79}
 80
 81func (c *agent) processEvent(
 82	sessionID string,
 83	assistantMsg *message.Message,
 84	event provider.ProviderEvent,
 85) error {
 86	switch event.Type {
 87	case provider.EventThinkingDelta:
 88		assistantMsg.AppendReasoningContent(event.Content)
 89		return c.Messages.Update(*assistantMsg)
 90	case provider.EventContentDelta:
 91		assistantMsg.AppendContent(event.Content)
 92		return c.Messages.Update(*assistantMsg)
 93	case provider.EventError:
 94		c.App.Logger.PersistError(event.Error.Error())
 95		return event.Error
 96	case provider.EventWarning:
 97		c.App.Logger.PersistWarn(event.Info)
 98		return nil
 99	case provider.EventInfo:
100		c.App.Logger.PersistInfo(event.Info)
101	case provider.EventComplete:
102		assistantMsg.SetToolCalls(event.Response.ToolCalls)
103		assistantMsg.AddFinish(event.Response.FinishReason)
104		err := c.Messages.Update(*assistantMsg)
105		if err != nil {
106			return err
107		}
108		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
109	}
110
111	return nil
112}
113
114func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
115	var wg sync.WaitGroup
116	toolResults := make([]message.ToolResult, len(toolCalls))
117	mutex := &sync.Mutex{}
118
119	for i, tc := range toolCalls {
120		wg.Add(1)
121		go func(index int, toolCall message.ToolCall) {
122			defer wg.Done()
123
124			response := ""
125			isError := false
126			found := false
127
128			for _, tool := range tls {
129				if tool.Info().Name == toolCall.Name {
130					found = true
131					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
132						ID:    toolCall.ID,
133						Name:  toolCall.Name,
134						Input: toolCall.Input,
135					})
136					if toolErr != nil {
137						response = fmt.Sprintf("error running tool: %s", toolErr)
138						isError = true
139					} else {
140						response = toolResult.Content
141						isError = toolResult.IsError
142					}
143					break
144				}
145			}
146
147			if !found {
148				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
149				isError = true
150			}
151
152			mutex.Lock()
153			defer mutex.Unlock()
154
155			toolResults[index] = message.ToolResult{
156				ToolCallID: toolCall.ID,
157				Content:    response,
158				IsError:    isError,
159			}
160		}(i, tc)
161	}
162
163	wg.Wait()
164	return toolResults, nil
165}
166
167func (c *agent) handleToolExecution(
168	ctx context.Context,
169	assistantMsg message.Message,
170) (*message.Message, error) {
171	if len(assistantMsg.ToolCalls()) == 0 {
172		return nil, nil
173	}
174
175	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
176	if err != nil {
177		return nil, err
178	}
179	parts := make([]message.ContentPart, 0)
180	for _, toolResult := range toolResults {
181		parts = append(parts, toolResult)
182	}
183	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
184		Role:  message.Tool,
185		Parts: parts,
186	})
187
188	return &msg, err
189}
190
191func (c *agent) generate(sessionID string, content string) error {
192	messages, err := c.Messages.List(sessionID)
193	if err != nil {
194		return err
195	}
196
197	if len(messages) == 0 {
198		go c.handleTitleGeneration(sessionID, content)
199	}
200
201	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
202		Role: message.User,
203		Parts: []message.ContentPart{
204			message.TextContent{
205				Text: content,
206			},
207		},
208	})
209	if err != nil {
210		return err
211	}
212
213	messages = append(messages, userMsg)
214	for {
215
216		eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
217		if err != nil {
218			return err
219		}
220
221		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
222			Role:  message.Assistant,
223			Parts: []message.ContentPart{},
224		})
225		if err != nil {
226			return err
227		}
228		for event := range eventChan {
229			err = c.processEvent(sessionID, &assistantMsg, event)
230			if err != nil {
231				assistantMsg.AddFinish("error:" + err.Error())
232				c.Messages.Update(assistantMsg)
233				return err
234			}
235		}
236
237		msg, err := c.handleToolExecution(c.Context, assistantMsg)
238
239		c.Messages.Update(assistantMsg)
240		if err != nil {
241			return err
242		}
243
244		if len(assistantMsg.ToolCalls()) == 0 {
245			break
246		}
247
248		messages = append(messages, assistantMsg)
249		if msg != nil {
250			messages = append(messages, *msg)
251		}
252	}
253	return nil
254}
255
256func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
257	maxTokens := config.Get().Model.CoderMaxTokens
258
259	providerConfig, ok := config.Get().Providers[model.Provider]
260	if !ok || !providerConfig.Enabled {
261		return nil, nil, errors.New("provider is not enabled")
262	}
263	var agentProvider provider.Provider
264	var titleGenerator provider.Provider
265
266	switch model.Provider {
267	case models.ProviderOpenAI:
268		var err error
269		agentProvider, err = provider.NewOpenAIProvider(
270			provider.WithOpenAISystemMessage(
271				prompt.CoderOpenAISystemPrompt(),
272			),
273			provider.WithOpenAIMaxTokens(maxTokens),
274			provider.WithOpenAIModel(model),
275			provider.WithOpenAIKey(providerConfig.APIKey),
276		)
277		if err != nil {
278			return nil, nil, err
279		}
280		titleGenerator, err = provider.NewOpenAIProvider(
281			provider.WithOpenAISystemMessage(
282				prompt.TitlePrompt(),
283			),
284			provider.WithOpenAIMaxTokens(80),
285			provider.WithOpenAIModel(model),
286			provider.WithOpenAIKey(providerConfig.APIKey),
287		)
288		if err != nil {
289			return nil, nil, err
290		}
291	case models.ProviderAnthropic:
292		var err error
293		agentProvider, err = provider.NewAnthropicProvider(
294			provider.WithAnthropicSystemMessage(
295				prompt.CoderAnthropicSystemPrompt(),
296			),
297			provider.WithAnthropicMaxTokens(maxTokens),
298			provider.WithAnthropicKey(providerConfig.APIKey),
299			provider.WithAnthropicModel(model),
300		)
301		if err != nil {
302			return nil, nil, err
303		}
304		titleGenerator, err = provider.NewAnthropicProvider(
305			provider.WithAnthropicSystemMessage(
306				prompt.TitlePrompt(),
307			),
308			provider.WithAnthropicMaxTokens(80),
309			provider.WithAnthropicKey(providerConfig.APIKey),
310			provider.WithAnthropicModel(model),
311		)
312		if err != nil {
313			return nil, nil, err
314		}
315
316	case models.ProviderGemini:
317		var err error
318		agentProvider, err = provider.NewGeminiProvider(
319			ctx,
320			provider.WithGeminiSystemMessage(
321				prompt.CoderOpenAISystemPrompt(),
322			),
323			provider.WithGeminiMaxTokens(int32(maxTokens)),
324			provider.WithGeminiKey(providerConfig.APIKey),
325			provider.WithGeminiModel(model),
326		)
327		if err != nil {
328			return nil, nil, err
329		}
330		titleGenerator, err = provider.NewGeminiProvider(
331			ctx,
332			provider.WithGeminiSystemMessage(
333				prompt.TitlePrompt(),
334			),
335			provider.WithGeminiMaxTokens(80),
336			provider.WithGeminiKey(providerConfig.APIKey),
337			provider.WithGeminiModel(model),
338		)
339		if err != nil {
340			return nil, nil, err
341		}
342	case models.ProviderGROQ:
343		var err error
344		agentProvider, err = provider.NewOpenAIProvider(
345			provider.WithOpenAISystemMessage(
346				prompt.CoderAnthropicSystemPrompt(),
347			),
348			provider.WithOpenAIMaxTokens(maxTokens),
349			provider.WithOpenAIModel(model),
350			provider.WithOpenAIKey(providerConfig.APIKey),
351			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
352		)
353		if err != nil {
354			return nil, nil, err
355		}
356		titleGenerator, err = provider.NewOpenAIProvider(
357			provider.WithOpenAISystemMessage(
358				prompt.TitlePrompt(),
359			),
360			provider.WithOpenAIMaxTokens(80),
361			provider.WithOpenAIModel(model),
362			provider.WithOpenAIKey(providerConfig.APIKey),
363			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
364		)
365		if err != nil {
366			return nil, nil, err
367		}
368
369	case models.ProviderBedrock:
370		var err error
371		agentProvider, err = provider.NewBedrockProvider(
372			provider.WithBedrockSystemMessage(
373				prompt.CoderAnthropicSystemPrompt(),
374			),
375			provider.WithBedrockMaxTokens(maxTokens),
376			provider.WithBedrockModel(model),
377		)
378		if err != nil {
379			return nil, nil, err
380		}
381		titleGenerator, err = provider.NewBedrockProvider(
382			provider.WithBedrockSystemMessage(
383				prompt.TitlePrompt(),
384			),
385			provider.WithBedrockMaxTokens(maxTokens),
386			provider.WithBedrockModel(model),
387		)
388		if err != nil {
389			return nil, nil, err
390		}
391
392	}
393
394	return agentProvider, titleGenerator, nil
395}