agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log"
  8	"strings"
  9	"sync"
 10
 11	"github.com/kujtimiihoxha/termai/internal/app"
 12	"github.com/kujtimiihoxha/termai/internal/config"
 13	"github.com/kujtimiihoxha/termai/internal/llm/models"
 14	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 15	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 16	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 17	"github.com/kujtimiihoxha/termai/internal/message"
 18	"github.com/kujtimiihoxha/termai/internal/pubsub"
 19	"github.com/kujtimiihoxha/termai/internal/tui/util"
 20)
 21
 22type Agent interface {
 23	Generate(sessionID string, content string) error
 24}
 25
 26type agent struct {
 27	*app.App
 28	model          models.Model
 29	tools          []tools.BaseTool
 30	agent          provider.Provider
 31	titleGenerator provider.Provider
 32}
 33
 34func (c *agent) handleTitleGeneration(sessionID, content string) {
 35	response, err := c.titleGenerator.SendMessages(
 36		c.Context,
 37		[]message.Message{
 38			{
 39				Role: message.User,
 40				Parts: []message.ContentPart{
 41					message.TextContent{
 42						Text: content,
 43					},
 44				},
 45			},
 46		},
 47		nil,
 48	)
 49	if err != nil {
 50		return
 51	}
 52
 53	session, err := c.Sessions.Get(sessionID)
 54	if err != nil {
 55		return
 56	}
 57	if response.Content != "" {
 58		session.Title = response.Content
 59		session.Title = strings.TrimSpace(session.Title)
 60		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
 61		c.Sessions.Save(session)
 62	}
 63}
 64
 65func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
 66	session, err := c.Sessions.Get(sessionID)
 67	if err != nil {
 68		return err
 69	}
 70
 71	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 72		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 73		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 74		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 75
 76	session.Cost += cost
 77	session.CompletionTokens += usage.OutputTokens
 78	session.PromptTokens += usage.InputTokens
 79
 80	_, err = c.Sessions.Save(session)
 81	return err
 82}
 83
 84func (c *agent) processEvent(
 85	sessionID string,
 86	assistantMsg *message.Message,
 87	event provider.ProviderEvent,
 88) error {
 89	switch event.Type {
 90	case provider.EventThinkingDelta:
 91		assistantMsg.AppendReasoningContent(event.Content)
 92		return c.Messages.Update(*assistantMsg)
 93	case provider.EventContentDelta:
 94		assistantMsg.AppendContent(event.Content)
 95		return c.Messages.Update(*assistantMsg)
 96	case provider.EventError:
 97		// TODO: remove when realease
 98		log.Println("error", event.Error)
 99		c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
100			Type: util.InfoTypeError,
101			Msg:  event.Error.Error(),
102		})
103		return event.Error
104	case provider.EventWarning:
105		c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
106			Type: util.InfoTypeWarn,
107			Msg:  event.Info,
108		})
109		return nil
110	case provider.EventInfo:
111		c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
112			Type: util.InfoTypeInfo,
113			Msg:  event.Info,
114		})
115	case provider.EventComplete:
116		assistantMsg.SetToolCalls(event.Response.ToolCalls)
117		assistantMsg.AddFinish(event.Response.FinishReason)
118		err := c.Messages.Update(*assistantMsg)
119		if err != nil {
120			return err
121		}
122		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
123	}
124
125	return nil
126}
127
128func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
129	var wg sync.WaitGroup
130	toolResults := make([]message.ToolResult, len(toolCalls))
131	mutex := &sync.Mutex{}
132
133	for i, tc := range toolCalls {
134		wg.Add(1)
135		go func(index int, toolCall message.ToolCall) {
136			defer wg.Done()
137
138			response := ""
139			isError := false
140			found := false
141
142			for _, tool := range tls {
143				if tool.Info().Name == toolCall.Name {
144					found = true
145					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
146						ID:    toolCall.ID,
147						Name:  toolCall.Name,
148						Input: toolCall.Input,
149					})
150					if toolErr != nil {
151						response = fmt.Sprintf("error running tool: %s", toolErr)
152						isError = true
153					} else {
154						response = toolResult.Content
155						isError = toolResult.IsError
156					}
157					break
158				}
159			}
160
161			if !found {
162				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
163				isError = true
164			}
165
166			mutex.Lock()
167			defer mutex.Unlock()
168
169			toolResults[index] = message.ToolResult{
170				ToolCallID: toolCall.ID,
171				Content:    response,
172				IsError:    isError,
173			}
174		}(i, tc)
175	}
176
177	wg.Wait()
178	return toolResults, nil
179}
180
181func (c *agent) handleToolExecution(
182	ctx context.Context,
183	assistantMsg message.Message,
184) (*message.Message, error) {
185	if len(assistantMsg.ToolCalls()) == 0 {
186		return nil, nil
187	}
188
189	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
190	if err != nil {
191		return nil, err
192	}
193	parts := make([]message.ContentPart, 0)
194	for _, toolResult := range toolResults {
195		parts = append(parts, toolResult)
196	}
197	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
198		Role:  message.Tool,
199		Parts: parts,
200	})
201
202	return &msg, err
203}
204
205func (c *agent) generate(sessionID string, content string) error {
206	messages, err := c.Messages.List(sessionID)
207	if err != nil {
208		return err
209	}
210
211	if len(messages) == 0 {
212		go c.handleTitleGeneration(sessionID, content)
213	}
214
215	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
216		Role: message.User,
217		Parts: []message.ContentPart{
218			message.TextContent{
219				Text: content,
220			},
221		},
222	})
223	if err != nil {
224		return err
225	}
226
227	messages = append(messages, userMsg)
228	for {
229
230		eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
231		if err != nil {
232			return err
233		}
234
235		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
236			Role:  message.Assistant,
237			Parts: []message.ContentPart{},
238		})
239		if err != nil {
240			return err
241		}
242		for event := range eventChan {
243			err = c.processEvent(sessionID, &assistantMsg, event)
244			if err != nil {
245				assistantMsg.AddFinish("error:" + err.Error())
246				c.Messages.Update(assistantMsg)
247				return err
248			}
249		}
250
251		msg, err := c.handleToolExecution(c.Context, assistantMsg)
252
253		c.Messages.Update(assistantMsg)
254		if err != nil {
255			return err
256		}
257
258		if len(assistantMsg.ToolCalls()) == 0 {
259			break
260		}
261
262		messages = append(messages, assistantMsg)
263		if msg != nil {
264			messages = append(messages, *msg)
265		}
266	}
267	return nil
268}
269
270func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
271	maxTokens := config.Get().Model.CoderMaxTokens
272
273	providerConfig, ok := config.Get().Providers[model.Provider]
274	if !ok || !providerConfig.Enabled {
275		return nil, nil, errors.New("provider is not enabled")
276	}
277	var agentProvider provider.Provider
278	var titleGenerator provider.Provider
279
280	switch model.Provider {
281	case models.ProviderOpenAI:
282		var err error
283		agentProvider, err = provider.NewOpenAIProvider(
284			provider.WithOpenAISystemMessage(
285				prompt.CoderOpenAISystemPrompt(),
286			),
287			provider.WithOpenAIMaxTokens(maxTokens),
288			provider.WithOpenAIModel(model),
289			provider.WithOpenAIKey(providerConfig.APIKey),
290		)
291		if err != nil {
292			return nil, nil, err
293		}
294		titleGenerator, err = provider.NewOpenAIProvider(
295			provider.WithOpenAISystemMessage(
296				prompt.TitlePrompt(),
297			),
298			provider.WithOpenAIMaxTokens(80),
299			provider.WithOpenAIModel(model),
300			provider.WithOpenAIKey(providerConfig.APIKey),
301		)
302		if err != nil {
303			return nil, nil, err
304		}
305	case models.ProviderAnthropic:
306		var err error
307		agentProvider, err = provider.NewAnthropicProvider(
308			provider.WithAnthropicSystemMessage(
309				prompt.CoderAnthropicSystemPrompt(),
310			),
311			provider.WithAnthropicMaxTokens(maxTokens),
312			provider.WithAnthropicKey(providerConfig.APIKey),
313			provider.WithAnthropicModel(model),
314		)
315		if err != nil {
316			return nil, nil, err
317		}
318		titleGenerator, err = provider.NewAnthropicProvider(
319			provider.WithAnthropicSystemMessage(
320				prompt.TitlePrompt(),
321			),
322			provider.WithAnthropicMaxTokens(80),
323			provider.WithAnthropicKey(providerConfig.APIKey),
324			provider.WithAnthropicModel(model),
325		)
326		if err != nil {
327			return nil, nil, err
328		}
329
330	case models.ProviderGemini:
331		var err error
332		agentProvider, err = provider.NewGeminiProvider(
333			ctx,
334			provider.WithGeminiSystemMessage(
335				prompt.CoderOpenAISystemPrompt(),
336			),
337			provider.WithGeminiMaxTokens(int32(maxTokens)),
338			provider.WithGeminiKey(providerConfig.APIKey),
339			provider.WithGeminiModel(model),
340		)
341		if err != nil {
342			return nil, nil, err
343		}
344		titleGenerator, err = provider.NewGeminiProvider(
345			ctx,
346			provider.WithGeminiSystemMessage(
347				prompt.TitlePrompt(),
348			),
349			provider.WithGeminiMaxTokens(80),
350			provider.WithGeminiKey(providerConfig.APIKey),
351			provider.WithGeminiModel(model),
352		)
353		if err != nil {
354			return nil, nil, err
355		}
356	case models.ProviderGROQ:
357		var err error
358		agentProvider, err = provider.NewOpenAIProvider(
359			provider.WithOpenAISystemMessage(
360				prompt.CoderAnthropicSystemPrompt(),
361			),
362			provider.WithOpenAIMaxTokens(maxTokens),
363			provider.WithOpenAIModel(model),
364			provider.WithOpenAIKey(providerConfig.APIKey),
365			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
366		)
367		if err != nil {
368			return nil, nil, err
369		}
370		titleGenerator, err = provider.NewOpenAIProvider(
371			provider.WithOpenAISystemMessage(
372				prompt.TitlePrompt(),
373			),
374			provider.WithOpenAIMaxTokens(80),
375			provider.WithOpenAIModel(model),
376			provider.WithOpenAIKey(providerConfig.APIKey),
377			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
378		)
379		if err != nil {
380			return nil, nil, err
381		}
382
383	case models.ProviderBedrock:
384		var err error
385		agentProvider, err = provider.NewBedrockProvider(
386			provider.WithBedrockSystemMessage(
387				prompt.CoderAnthropicSystemPrompt(),
388			),
389			provider.WithBedrockMaxTokens(maxTokens),
390			provider.WithBedrockModel(model),
391		)
392		if err != nil {
393			return nil, nil, err
394		}
395		titleGenerator, err = provider.NewBedrockProvider(
396			provider.WithBedrockSystemMessage(
397				prompt.TitlePrompt(),
398			),
399			provider.WithBedrockMaxTokens(maxTokens),
400			provider.WithBedrockModel(model),
401		)
402		if err != nil {
403			return nil, nil, err
404		}
405
406	}
407
408	return agentProvider, titleGenerator, nil
409}