agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/termai/internal/app"
 11	"github.com/kujtimiihoxha/termai/internal/config"
 12	"github.com/kujtimiihoxha/termai/internal/llm/models"
 13	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 14	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17)
 18
 19type Agent interface {
 20	Generate(sessionID string, content string) error
 21}
 22
 23type agent struct {
 24	*app.App
 25	model          models.Model
 26	tools          []tools.BaseTool
 27	agent          provider.Provider
 28	titleGenerator provider.Provider
 29}
 30
 31func (c *agent) handleTitleGeneration(sessionID, content string) {
 32	response, err := c.titleGenerator.SendMessages(
 33		c.Context,
 34		[]message.Message{
 35			{
 36				Role:    message.User,
 37				Content: content,
 38			},
 39		},
 40		nil,
 41	)
 42	if err != nil {
 43		return
 44	}
 45
 46	session, err := c.Sessions.Get(sessionID)
 47	if err != nil {
 48		return
 49	}
 50	if response.Content != "" {
 51		session.Title = response.Content
 52		c.Sessions.Save(session)
 53	}
 54}
 55
 56func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
 57	session, err := c.Sessions.Get(sessionID)
 58	if err != nil {
 59		return err
 60	}
 61
 62	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 63		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 64		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 65		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 66
 67	session.Cost += cost
 68	session.CompletionTokens += usage.OutputTokens
 69	session.PromptTokens += usage.InputTokens
 70
 71	_, err = c.Sessions.Save(session)
 72	return err
 73}
 74
 75func (c *agent) processEvent(
 76	sessionID string,
 77	assistantMsg *message.Message,
 78	event provider.ProviderEvent,
 79) error {
 80	switch event.Type {
 81	case provider.EventThinkingDelta:
 82		assistantMsg.Thinking += event.Thinking
 83		return c.Messages.Update(*assistantMsg)
 84	case provider.EventContentDelta:
 85		assistantMsg.Content += event.Content
 86		return c.Messages.Update(*assistantMsg)
 87	case provider.EventError:
 88		log.Println("error", event.Error)
 89		return event.Error
 90
 91	case provider.EventComplete:
 92		assistantMsg.ToolCalls = event.Response.ToolCalls
 93		err := c.Messages.Update(*assistantMsg)
 94		if err != nil {
 95			return err
 96		}
 97		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
 98	}
 99
100	return nil
101}
102
103func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
104	var wg sync.WaitGroup
105	toolResults := make([]message.ToolResult, len(toolCalls))
106	mutex := &sync.Mutex{}
107
108	for i, tc := range toolCalls {
109		wg.Add(1)
110		go func(index int, toolCall message.ToolCall) {
111			defer wg.Done()
112
113			response := ""
114			isError := false
115			found := false
116
117			for _, tool := range tls {
118				if tool.Info().Name == toolCall.Name {
119					found = true
120					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
121						ID:    toolCall.ID,
122						Name:  toolCall.Name,
123						Input: toolCall.Input,
124					})
125					if toolErr != nil {
126						response = fmt.Sprintf("error running tool: %s", toolErr)
127						isError = true
128					} else {
129						response = toolResult.Content
130						isError = toolResult.IsError
131					}
132					break
133				}
134			}
135
136			if !found {
137				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
138				isError = true
139			}
140
141			mutex.Lock()
142			defer mutex.Unlock()
143
144			toolResults[index] = message.ToolResult{
145				ToolCallID: toolCall.ID,
146				Content:    response,
147				IsError:    isError,
148			}
149		}(i, tc)
150	}
151
152	wg.Wait()
153	return toolResults, nil
154}
155
156func (c *agent) handleToolExecution(
157	ctx context.Context,
158	assistantMsg message.Message,
159) (*message.Message, error) {
160	if len(assistantMsg.ToolCalls) == 0 {
161		return nil, nil
162	}
163
164	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
165	if err != nil {
166		return nil, err
167	}
168
169	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
170		Role:        message.Tool,
171		ToolResults: toolResults,
172	})
173
174	return &msg, err
175}
176
177func (c *agent) generate(sessionID string, content string) error {
178	messages, err := c.Messages.List(sessionID)
179	if err != nil {
180		return err
181	}
182
183	if len(messages) == 0 {
184		go c.handleTitleGeneration(sessionID, content)
185	}
186
187	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
188		Role:    message.User,
189		Content: content,
190	})
191	if err != nil {
192		return err
193	}
194
195	messages = append(messages, userMsg)
196	for {
197
198		eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
199		if err != nil {
200			return err
201		}
202
203		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
204			Role:    message.Assistant,
205			Content: "",
206		})
207		if err != nil {
208			return err
209		}
210		for event := range eventChan {
211			err = c.processEvent(sessionID, &assistantMsg, event)
212			if err != nil {
213				assistantMsg.Finished = true
214				c.Messages.Update(assistantMsg)
215				return err
216			}
217		}
218
219		msg, err := c.handleToolExecution(c.Context, assistantMsg)
220		assistantMsg.Finished = true
221		c.Messages.Update(assistantMsg)
222		if err != nil {
223			return err
224		}
225
226		if len(assistantMsg.ToolCalls) == 0 {
227			break
228		}
229
230		messages = append(messages, assistantMsg)
231		if msg != nil {
232			messages = append(messages, *msg)
233		}
234	}
235	return nil
236}
237
238func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
239	maxTokens := config.Get().Model.CoderMaxTokens
240
241	providerConfig, ok := config.Get().Providers[model.Provider]
242	if !ok || !providerConfig.Enabled {
243		return nil, nil, errors.New("provider is not enabled")
244	}
245	var agentProvider provider.Provider
246	var titleGenerator provider.Provider
247
248	switch model.Provider {
249	case models.ProviderOpenAI:
250		var err error
251		agentProvider, err = provider.NewOpenAIProvider(
252			provider.WithOpenAISystemMessage(
253				prompt.CoderOpenAISystemPrompt(),
254			),
255			provider.WithOpenAIMaxTokens(maxTokens),
256			provider.WithOpenAIModel(model),
257			provider.WithOpenAIKey(providerConfig.APIKey),
258		)
259		if err != nil {
260			return nil, nil, err
261		}
262		titleGenerator, err = provider.NewOpenAIProvider(
263			provider.WithOpenAISystemMessage(
264				prompt.TitlePrompt(),
265			),
266			provider.WithOpenAIMaxTokens(80),
267			provider.WithOpenAIModel(model),
268			provider.WithOpenAIKey(providerConfig.APIKey),
269		)
270		if err != nil {
271			return nil, nil, err
272		}
273	case models.ProviderAnthropic:
274		var err error
275		agentProvider, err = provider.NewAnthropicProvider(
276			provider.WithAnthropicSystemMessage(
277				prompt.CoderAnthropicSystemPrompt(),
278			),
279			provider.WithAnthropicMaxTokens(maxTokens),
280			provider.WithAnthropicKey(providerConfig.APIKey),
281			provider.WithAnthropicModel(model),
282		)
283		if err != nil {
284			return nil, nil, err
285		}
286		titleGenerator, err = provider.NewAnthropicProvider(
287			provider.WithAnthropicSystemMessage(
288				prompt.TitlePrompt(),
289			),
290			provider.WithAnthropicMaxTokens(80),
291			provider.WithAnthropicKey(providerConfig.APIKey),
292			provider.WithAnthropicModel(model),
293		)
294		if err != nil {
295			return nil, nil, err
296		}
297
298	case models.ProviderGemini:
299		var err error
300		agentProvider, err = provider.NewGeminiProvider(
301			ctx,
302			provider.WithGeminiSystemMessage(
303				prompt.CoderOpenAISystemPrompt(),
304			),
305			provider.WithGeminiMaxTokens(int32(maxTokens)),
306			provider.WithGeminiKey(providerConfig.APIKey),
307			provider.WithGeminiModel(model),
308		)
309		if err != nil {
310			return nil, nil, err
311		}
312		titleGenerator, err = provider.NewGeminiProvider(
313			ctx,
314			provider.WithGeminiSystemMessage(
315				prompt.TitlePrompt(),
316			),
317			provider.WithGeminiMaxTokens(80),
318			provider.WithGeminiKey(providerConfig.APIKey),
319			provider.WithGeminiModel(model),
320		)
321		if err != nil {
322			return nil, nil, err
323		}
324	case models.ProviderGROQ:
325		var err error
326		agentProvider, err = provider.NewOpenAIProvider(
327			provider.WithOpenAISystemMessage(
328				prompt.CoderAnthropicSystemPrompt(),
329			),
330			provider.WithOpenAIMaxTokens(maxTokens),
331			provider.WithOpenAIModel(model),
332			provider.WithOpenAIKey(providerConfig.APIKey),
333			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
334		)
335		if err != nil {
336			return nil, nil, err
337		}
338		titleGenerator, err = provider.NewOpenAIProvider(
339			provider.WithOpenAISystemMessage(
340				prompt.TitlePrompt(),
341			),
342			provider.WithOpenAIMaxTokens(80),
343			provider.WithOpenAIModel(model),
344			provider.WithOpenAIKey(providerConfig.APIKey),
345			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
346		)
347		if err != nil {
348			return nil, nil, err
349		}
350
351	}
352
353	return agentProvider, titleGenerator, nil
354}