1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log"
8 "strings"
9 "sync"
10
11 "github.com/kujtimiihoxha/termai/internal/app"
12 "github.com/kujtimiihoxha/termai/internal/config"
13 "github.com/kujtimiihoxha/termai/internal/llm/models"
14 "github.com/kujtimiihoxha/termai/internal/llm/prompt"
15 "github.com/kujtimiihoxha/termai/internal/llm/provider"
16 "github.com/kujtimiihoxha/termai/internal/llm/tools"
17 "github.com/kujtimiihoxha/termai/internal/message"
18 "github.com/kujtimiihoxha/termai/internal/pubsub"
19 "github.com/kujtimiihoxha/termai/internal/tui/util"
20)
21
// Agent is the entry point for running a user prompt against a session.
type Agent interface {
	// Generate handles the prompt content for the session identified by
	// sessionID and reports any error encountered.
	// NOTE(review): the exported implementation is not in this section;
	// presumably it delegates to the unexported generate method — confirm.
	Generate(sessionID, content string) error
}
25
26type agent struct {
27 *app.App
28 model models.Model
29 tools []tools.BaseTool
30 agent provider.Provider
31 titleGenerator provider.Provider
32}
33
34func (c *agent) handleTitleGeneration(sessionID, content string) {
35 response, err := c.titleGenerator.SendMessages(
36 c.Context,
37 []message.Message{
38 {
39 Role: message.User,
40 Parts: []message.ContentPart{
41 message.TextContent{
42 Text: content,
43 },
44 },
45 },
46 },
47 nil,
48 )
49 if err != nil {
50 return
51 }
52
53 session, err := c.Sessions.Get(sessionID)
54 if err != nil {
55 return
56 }
57 if response.Content != "" {
58 session.Title = response.Content
59 session.Title = strings.TrimSpace(session.Title)
60 session.Title = strings.ReplaceAll(session.Title, "\n", " ")
61 c.Sessions.Save(session)
62 }
63}
64
65func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
66 session, err := c.Sessions.Get(sessionID)
67 if err != nil {
68 return err
69 }
70
71 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
72 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
73 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
74 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
75
76 session.Cost += cost
77 session.CompletionTokens += usage.OutputTokens
78 session.PromptTokens += usage.InputTokens
79
80 _, err = c.Sessions.Save(session)
81 return err
82}
83
84func (c *agent) processEvent(
85 sessionID string,
86 assistantMsg *message.Message,
87 event provider.ProviderEvent,
88) error {
89 switch event.Type {
90 case provider.EventThinkingDelta:
91 assistantMsg.AppendReasoningContent(event.Content)
92 return c.Messages.Update(*assistantMsg)
93 case provider.EventContentDelta:
94 assistantMsg.AppendContent(event.Content)
95 return c.Messages.Update(*assistantMsg)
96 case provider.EventError:
97 // TODO: remove when realease
98 log.Println("error", event.Error)
99 c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
100 Type: util.InfoTypeError,
101 Msg: event.Error.Error(),
102 })
103 return event.Error
104 case provider.EventWarning:
105 c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
106 Type: util.InfoTypeWarn,
107 Msg: event.Info,
108 })
109 return nil
110 case provider.EventInfo:
111 c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
112 Type: util.InfoTypeInfo,
113 Msg: event.Info,
114 })
115 case provider.EventComplete:
116 assistantMsg.SetToolCalls(event.Response.ToolCalls)
117 assistantMsg.AddFinish(event.Response.FinishReason)
118 err := c.Messages.Update(*assistantMsg)
119 if err != nil {
120 return err
121 }
122 return c.TrackUsage(sessionID, c.model, event.Response.Usage)
123 }
124
125 return nil
126}
127
128func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
129 var wg sync.WaitGroup
130 toolResults := make([]message.ToolResult, len(toolCalls))
131 mutex := &sync.Mutex{}
132
133 for i, tc := range toolCalls {
134 wg.Add(1)
135 go func(index int, toolCall message.ToolCall) {
136 defer wg.Done()
137
138 response := ""
139 isError := false
140 found := false
141
142 for _, tool := range tls {
143 if tool.Info().Name == toolCall.Name {
144 found = true
145 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
146 ID: toolCall.ID,
147 Name: toolCall.Name,
148 Input: toolCall.Input,
149 })
150 if toolErr != nil {
151 response = fmt.Sprintf("error running tool: %s", toolErr)
152 isError = true
153 } else {
154 response = toolResult.Content
155 isError = toolResult.IsError
156 }
157 break
158 }
159 }
160
161 if !found {
162 response = fmt.Sprintf("tool not found: %s", toolCall.Name)
163 isError = true
164 }
165
166 mutex.Lock()
167 defer mutex.Unlock()
168
169 toolResults[index] = message.ToolResult{
170 ToolCallID: toolCall.ID,
171 Content: response,
172 IsError: isError,
173 }
174 }(i, tc)
175 }
176
177 wg.Wait()
178 return toolResults, nil
179}
180
181func (c *agent) handleToolExecution(
182 ctx context.Context,
183 assistantMsg message.Message,
184) (*message.Message, error) {
185 if len(assistantMsg.ToolCalls()) == 0 {
186 return nil, nil
187 }
188
189 toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
190 if err != nil {
191 return nil, err
192 }
193 parts := make([]message.ContentPart, 0)
194 for _, toolResult := range toolResults {
195 parts = append(parts, toolResult)
196 }
197 msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
198 Role: message.Tool,
199 Parts: parts,
200 })
201
202 return &msg, err
203}
204
205func (c *agent) generate(sessionID string, content string) error {
206 messages, err := c.Messages.List(sessionID)
207 if err != nil {
208 return err
209 }
210
211 if len(messages) == 0 {
212 go c.handleTitleGeneration(sessionID, content)
213 }
214
215 userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
216 Role: message.User,
217 Parts: []message.ContentPart{
218 message.TextContent{
219 Text: content,
220 },
221 },
222 })
223 if err != nil {
224 return err
225 }
226
227 messages = append(messages, userMsg)
228 for {
229
230 eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
231 if err != nil {
232 return err
233 }
234
235 assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
236 Role: message.Assistant,
237 Parts: []message.ContentPart{},
238 })
239 if err != nil {
240 return err
241 }
242 for event := range eventChan {
243 err = c.processEvent(sessionID, &assistantMsg, event)
244 if err != nil {
245 assistantMsg.AddFinish("error:" + err.Error())
246 c.Messages.Update(assistantMsg)
247 return err
248 }
249 }
250
251 msg, err := c.handleToolExecution(c.Context, assistantMsg)
252
253 c.Messages.Update(assistantMsg)
254 if err != nil {
255 return err
256 }
257
258 if len(assistantMsg.ToolCalls()) == 0 {
259 break
260 }
261
262 messages = append(messages, assistantMsg)
263 if msg != nil {
264 messages = append(messages, *msg)
265 }
266 }
267 return nil
268}
269
270func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
271 maxTokens := config.Get().Model.CoderMaxTokens
272
273 providerConfig, ok := config.Get().Providers[model.Provider]
274 if !ok || !providerConfig.Enabled {
275 return nil, nil, errors.New("provider is not enabled")
276 }
277 var agentProvider provider.Provider
278 var titleGenerator provider.Provider
279
280 switch model.Provider {
281 case models.ProviderOpenAI:
282 var err error
283 agentProvider, err = provider.NewOpenAIProvider(
284 provider.WithOpenAISystemMessage(
285 prompt.CoderOpenAISystemPrompt(),
286 ),
287 provider.WithOpenAIMaxTokens(maxTokens),
288 provider.WithOpenAIModel(model),
289 provider.WithOpenAIKey(providerConfig.APIKey),
290 )
291 if err != nil {
292 return nil, nil, err
293 }
294 titleGenerator, err = provider.NewOpenAIProvider(
295 provider.WithOpenAISystemMessage(
296 prompt.TitlePrompt(),
297 ),
298 provider.WithOpenAIMaxTokens(80),
299 provider.WithOpenAIModel(model),
300 provider.WithOpenAIKey(providerConfig.APIKey),
301 )
302 if err != nil {
303 return nil, nil, err
304 }
305 case models.ProviderAnthropic:
306 var err error
307 agentProvider, err = provider.NewAnthropicProvider(
308 provider.WithAnthropicSystemMessage(
309 prompt.CoderAnthropicSystemPrompt(),
310 ),
311 provider.WithAnthropicMaxTokens(maxTokens),
312 provider.WithAnthropicKey(providerConfig.APIKey),
313 provider.WithAnthropicModel(model),
314 )
315 if err != nil {
316 return nil, nil, err
317 }
318 titleGenerator, err = provider.NewAnthropicProvider(
319 provider.WithAnthropicSystemMessage(
320 prompt.TitlePrompt(),
321 ),
322 provider.WithAnthropicMaxTokens(80),
323 provider.WithAnthropicKey(providerConfig.APIKey),
324 provider.WithAnthropicModel(model),
325 )
326 if err != nil {
327 return nil, nil, err
328 }
329
330 case models.ProviderGemini:
331 var err error
332 agentProvider, err = provider.NewGeminiProvider(
333 ctx,
334 provider.WithGeminiSystemMessage(
335 prompt.CoderOpenAISystemPrompt(),
336 ),
337 provider.WithGeminiMaxTokens(int32(maxTokens)),
338 provider.WithGeminiKey(providerConfig.APIKey),
339 provider.WithGeminiModel(model),
340 )
341 if err != nil {
342 return nil, nil, err
343 }
344 titleGenerator, err = provider.NewGeminiProvider(
345 ctx,
346 provider.WithGeminiSystemMessage(
347 prompt.TitlePrompt(),
348 ),
349 provider.WithGeminiMaxTokens(80),
350 provider.WithGeminiKey(providerConfig.APIKey),
351 provider.WithGeminiModel(model),
352 )
353 if err != nil {
354 return nil, nil, err
355 }
356 case models.ProviderGROQ:
357 var err error
358 agentProvider, err = provider.NewOpenAIProvider(
359 provider.WithOpenAISystemMessage(
360 prompt.CoderAnthropicSystemPrompt(),
361 ),
362 provider.WithOpenAIMaxTokens(maxTokens),
363 provider.WithOpenAIModel(model),
364 provider.WithOpenAIKey(providerConfig.APIKey),
365 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
366 )
367 if err != nil {
368 return nil, nil, err
369 }
370 titleGenerator, err = provider.NewOpenAIProvider(
371 provider.WithOpenAISystemMessage(
372 prompt.TitlePrompt(),
373 ),
374 provider.WithOpenAIMaxTokens(80),
375 provider.WithOpenAIModel(model),
376 provider.WithOpenAIKey(providerConfig.APIKey),
377 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
378 )
379 if err != nil {
380 return nil, nil, err
381 }
382
383 case models.ProviderBedrock:
384 var err error
385 agentProvider, err = provider.NewBedrockProvider(
386 provider.WithBedrockSystemMessage(
387 prompt.CoderAnthropicSystemPrompt(),
388 ),
389 provider.WithBedrockMaxTokens(maxTokens),
390 provider.WithBedrockModel(model),
391 )
392 if err != nil {
393 return nil, nil, err
394 }
395 titleGenerator, err = provider.NewBedrockProvider(
396 provider.WithBedrockSystemMessage(
397 prompt.TitlePrompt(),
398 ),
399 provider.WithBedrockMaxTokens(maxTokens),
400 provider.WithBedrockModel(model),
401 )
402 if err != nil {
403 return nil, nil, err
404 }
405
406 }
407
408 return agentProvider, titleGenerator, nil
409}