1package llm
2
3import (
4 "context"
5 "log"
6 "sync"
7 "time"
8
9 "github.com/cloudwego/eino/callbacks"
10 "github.com/cloudwego/eino/compose"
11 "github.com/cloudwego/eino/schema"
12 "github.com/google/uuid"
13 "github.com/kujtimiihoxha/termai/internal/llm/agent"
14 "github.com/kujtimiihoxha/termai/internal/llm/models"
15 "github.com/kujtimiihoxha/termai/internal/logging"
16 "github.com/kujtimiihoxha/termai/internal/message"
17 "github.com/kujtimiihoxha/termai/internal/pubsub"
18 "github.com/kujtimiihoxha/termai/internal/session"
19
20 eModel "github.com/cloudwego/eino/components/model"
21 enioAgent "github.com/cloudwego/eino/flow/agent"
22 "github.com/spf13/viper"
23)
24
25const (
26 AgentRequestoEvent pubsub.EventType = "agent_request"
27 AgentErrorEvent pubsub.EventType = "agent_error"
28 AgentResponseEvent pubsub.EventType = "agent_response"
29)
30
// AgentMessageType identifies what kind of AgentEvent is being delivered.
type AgentMessageType int

// Values of AgentMessageType. They are numbered explicitly so the value
// carried in AgentEvent.Type (and its JSON encoding) is visible at a glance.
const (
	// AgentMessageTypeNewUserMessage marks an event about a new user message.
	AgentMessageTypeNewUserMessage AgentMessageType = 0
	// AgentMessageTypeAgentResponse marks an event carrying an agent response.
	AgentMessageTypeAgentResponse AgentMessageType = 1
	// AgentMessageTypeError marks an event that reports a failure.
	AgentMessageTypeError AgentMessageType = 2
)
38
// agentID names the agent an AgentEvent originates from.
type agentID string

// Known agent identifiers.
const (
	// RootAgent is the top-level agent; the request handler in this file
	// attributes every event it publishes to it.
	RootAgent = agentID("root")
	// TaskAgent identifies the task agent.
	// NOTE(review): not referenced elsewhere in this file.
	TaskAgent = agentID("task")
)
45
// AgentEvent is the payload published on the agent service's broker.
type AgentEvent struct {
	// ID is the request ID that produced this event.
	ID string `json:"id"`
	// Type says what kind of event this is (user message, response, error).
	Type AgentMessageType `json:"type"`
	// AgentID names the agent the event is attributed to.
	AgentID agentID `json:"agent_id"`
	// MessageID identifies a related stored message; the error events built
	// in this file leave it empty.
	MessageID string `json:"message_id"`
	// SessionID is the session the request belongs to.
	SessionID string `json:"session_id"`
	// Content is the event text; for error events it is err.Error().
	Content string `json:"content"`
}
54
55type Service interface {
56 pubsub.Suscriber[AgentEvent]
57
58 SendRequest(sessionID string, content string)
59}
// service is the Service implementation returned by NewService.
type service struct {
	// Broker supplies Publish (used to emit AgentEvents); presumably it also
	// satisfies pubsub.Suscriber for the Service interface — confirm.
	*pubsub.Broker[AgentEvent]
	// NOTE(review): Requests is not referenced in this file; possibly unused.
	Requests sync.Map
	// ctx is the base context supplied to NewService; per-request contexts
	// are derived from it in SendRequest.
	ctx context.Context
	// activeRequests maps an in-flight request ID to its context.CancelFunc;
	// the entry is removed when the request handler returns.
	activeRequests sync.Map
	// messages stores conversation history for sessions.
	messages message.Service
	// sessions loads and saves session metadata (token totals, cost, title).
	sessions session.Service
	// NOTE(review): logger is not used in this file; log.Printf is used instead.
	logger logging.Interface
}
69
70func (s *service) handleRequest(id string, sessionID string, content string) {
71 cancel, ok := s.activeRequests.Load(id)
72 if !ok {
73 return
74 }
75 defer cancel.(context.CancelFunc)()
76 defer s.activeRequests.Delete(id)
77
78 history, err := s.messages.List(sessionID)
79 if err != nil {
80 s.Publish(AgentErrorEvent, AgentEvent{
81 ID: id,
82 Type: AgentMessageTypeError,
83 AgentID: RootAgent,
84 MessageID: "",
85 SessionID: sessionID,
86 Content: err.Error(),
87 })
88 return
89 }
90
91 currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
92 if err != nil {
93 s.Publish(AgentErrorEvent, AgentEvent{
94 ID: id,
95 Type: AgentMessageTypeError,
96 AgentID: RootAgent,
97 MessageID: "",
98 SessionID: sessionID,
99 Content: err.Error(),
100 })
101 return
102 }
103
104 messages := []*schema.Message{
105 {
106 Role: schema.System,
107 Content: systemMessage,
108 },
109 }
110 for _, m := range history {
111 messages = append(messages, &m.MessageData)
112 }
113
114 builder := callbacks.NewHandlerBuilder()
115 builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
116 i, ok := input.(*eModel.CallbackInput)
117 if info.Component == "ChatModel" && ok {
118 if len(messages) < len(i.Messages) {
119 // find new messages
120 newMessages := i.Messages[len(messages):]
121 for _, m := range newMessages {
122 _, err = s.messages.Create(sessionID, *m)
123 if err != nil {
124 s.Publish(AgentErrorEvent, AgentEvent{
125 ID: id,
126 Type: AgentMessageTypeError,
127 AgentID: RootAgent,
128 MessageID: "",
129 SessionID: sessionID,
130 Content: err.Error(),
131 })
132 }
133 messages = append(messages, m)
134 }
135 }
136 }
137
138 return ctx
139 })
140 builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
141 return ctx
142 })
143
144 out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
145 if err != nil {
146 s.Publish(AgentErrorEvent, AgentEvent{
147 ID: id,
148 Type: AgentMessageTypeError,
149 AgentID: RootAgent,
150 MessageID: "",
151 SessionID: sessionID,
152 Content: err.Error(),
153 })
154 return
155 }
156 usage := out.ResponseMeta.Usage
157 s.messages.Create(sessionID, *out)
158 if usage != nil {
159 log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
160 session, err := s.sessions.Get(sessionID)
161 if err != nil {
162 s.Publish(AgentErrorEvent, AgentEvent{
163 ID: id,
164 Type: AgentMessageTypeError,
165 AgentID: RootAgent,
166 MessageID: "",
167 SessionID: sessionID,
168 Content: err.Error(),
169 })
170 return
171 }
172 session.PromptTokens += int64(usage.PromptTokens)
173 session.CompletionTokens += int64(usage.CompletionTokens)
174 model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
175 session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
176 float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
177 var newTitle string
178 if len(history) == 1 {
179 // first message generate the title
180 newTitle, err = agent.GenerateTitle(s.ctx, content)
181 if err != nil {
182 s.Publish(AgentErrorEvent, AgentEvent{
183 ID: id,
184 Type: AgentMessageTypeError,
185 AgentID: RootAgent,
186 MessageID: "",
187 SessionID: sessionID,
188 Content: err.Error(),
189 })
190 return
191 }
192 }
193 if newTitle != "" {
194 session.Title = newTitle
195 }
196
197 _, err = s.sessions.Save(session)
198 if err != nil {
199 s.Publish(AgentErrorEvent, AgentEvent{
200 ID: id,
201 Type: AgentMessageTypeError,
202 AgentID: RootAgent,
203 MessageID: "",
204 SessionID: sessionID,
205 Content: err.Error(),
206 })
207 return
208 }
209 }
210}
211
212func (s *service) SendRequest(sessionID string, content string) {
213 id := uuid.New().String()
214
215 _, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
216 s.activeRequests.Store(id, cancel)
217 log.Printf("Request: %s", content)
218 go s.handleRequest(id, sessionID, content)
219}
220
221func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
222 return &service{
223 Broker: pubsub.NewBroker[AgentEvent](),
224 ctx: ctx,
225 sessions: sessions,
226 messages: messages,
227 logger: logger,
228 }
229}