1package llm
2
3import (
4 "context"
5 "log"
6 "sync"
7 "time"
8
9 "github.com/cloudwego/eino/callbacks"
10 "github.com/cloudwego/eino/compose"
11 "github.com/cloudwego/eino/schema"
12 "github.com/google/uuid"
13 "github.com/kujtimiihoxha/termai/internal/llm/agent"
14 "github.com/kujtimiihoxha/termai/internal/llm/models"
15 "github.com/kujtimiihoxha/termai/internal/logging"
16 "github.com/kujtimiihoxha/termai/internal/message"
17 "github.com/kujtimiihoxha/termai/internal/pubsub"
18 "github.com/kujtimiihoxha/termai/internal/session"
19
20 eModel "github.com/cloudwego/eino/components/model"
21 enioAgent "github.com/cloudwego/eino/flow/agent"
22 "github.com/spf13/viper"
23)
24
25const (
26 AgentRequestoEvent pubsub.EventType = "agent_request"
27 AgentErrorEvent pubsub.EventType = "agent_error"
28 AgentResponseEvent pubsub.EventType = "agent_response"
29)
30
// AgentMessageType classifies an AgentEvent.
type AgentMessageType int

// AgentMessageType values. They are numbered explicitly because the numeric
// value is what is serialized into an AgentEvent's JSON "type" field.
const (
	AgentMessageTypeNewUserMessage AgentMessageType = 0
	AgentMessageTypeAgentResponse  AgentMessageType = 1
	AgentMessageTypeError          AgentMessageType = 2
)
38
// agentID identifies which agent an AgentEvent comes from.
type agentID string

// Known agent identifiers.
const (
	RootAgent = agentID("root")
	TaskAgent = agentID("task")
)
45
// AgentEvent is the payload carried by the agent service's pubsub events.
type AgentEvent struct {
	// ID identifies the request the event belongs to; events published in
	// this package set it to the request ID generated by SendRequest.
	ID string `json:"id"`

	// Type classifies the event (see the AgentMessageType constants).
	Type AgentMessageType `json:"type"`

	// AgentID names the emitting agent (RootAgent or TaskAgent).
	AgentID agentID `json:"agent_id"`

	// MessageID presumably refers to a related stored message; the error
	// events published in this file leave it empty.
	MessageID string `json:"message_id"`

	// SessionID is the session the request was made for.
	SessionID string `json:"session_id"`

	// Content is the event body; for error events it is the error text.
	Content string `json:"content"`
}
54
55type Service interface {
56 pubsub.Suscriber[AgentEvent]
57
58 SendRequest(sessionID string, content string)
59}
// service is the Service implementation in this package. It publishes
// AgentEvents through the embedded broker and handles each request in its own
// goroutine.
type service struct {
	*pubsub.Broker[AgentEvent]

	// NOTE(review): Requests is not referenced anywhere in this file and may
	// be unused — confirm against the rest of the package before removing.
	Requests sync.Map

	// ctx is the parent context for request processing.
	ctx context.Context

	// activeRequests maps in-flight request IDs (strings) to the
	// context.CancelFunc that SendRequest registered for them.
	activeRequests sync.Map

	messages message.Service
	sessions session.Service

	// logger is stored by NewService but not used in this file.
	logger logging.Interface
}
69
70func (s *service) handleRequest(id string, sessionID string, content string) {
71 cancel, ok := s.activeRequests.Load(id)
72 if !ok {
73 return
74 }
75 defer cancel.(context.CancelFunc)()
76 defer s.activeRequests.Delete(id)
77
78 history, err := s.messages.List(sessionID)
79 if err != nil {
80 s.Publish(AgentErrorEvent, AgentEvent{
81 ID: id,
82 Type: AgentMessageTypeError,
83 AgentID: RootAgent,
84 MessageID: "",
85 SessionID: sessionID,
86 Content: err.Error(),
87 })
88 return
89 }
90
91 log.Printf("Request: %s", content)
92 currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
93 if err != nil {
94 s.Publish(AgentErrorEvent, AgentEvent{
95 ID: id,
96 Type: AgentMessageTypeError,
97 AgentID: RootAgent,
98 MessageID: "",
99 SessionID: sessionID,
100 Content: err.Error(),
101 })
102 return
103 }
104
105 messages := []*schema.Message{
106 {
107 Role: schema.System,
108 Content: systemMessage,
109 },
110 }
111 for _, m := range history {
112 messages = append(messages, &m.MessageData)
113 }
114
115 builder := callbacks.NewHandlerBuilder()
116 builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
117 i, ok := input.(*eModel.CallbackInput)
118 if info.Component == "ChatModel" && ok {
119 if len(messages) < len(i.Messages) {
120 // find new messages
121 newMessages := i.Messages[len(messages):]
122 for _, m := range newMessages {
123 _, err = s.messages.Create(sessionID, *m)
124 if err != nil {
125 s.Publish(AgentErrorEvent, AgentEvent{
126 ID: id,
127 Type: AgentMessageTypeError,
128 AgentID: RootAgent,
129 MessageID: "",
130 SessionID: sessionID,
131 Content: err.Error(),
132 })
133 }
134 messages = append(messages, m)
135 }
136 }
137 }
138
139 return ctx
140 })
141 builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
142 return ctx
143 })
144
145 out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
146 if err != nil {
147 s.Publish(AgentErrorEvent, AgentEvent{
148 ID: id,
149 Type: AgentMessageTypeError,
150 AgentID: RootAgent,
151 MessageID: "",
152 SessionID: sessionID,
153 Content: err.Error(),
154 })
155 return
156 }
157 usage := out.ResponseMeta.Usage
158 s.messages.Create(sessionID, *out)
159 if usage != nil {
160 log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
161 session, err := s.sessions.Get(sessionID)
162 if err != nil {
163 s.Publish(AgentErrorEvent, AgentEvent{
164 ID: id,
165 Type: AgentMessageTypeError,
166 AgentID: RootAgent,
167 MessageID: "",
168 SessionID: sessionID,
169 Content: err.Error(),
170 })
171 return
172 }
173 session.PromptTokens += int64(usage.PromptTokens)
174 session.CompletionTokens += int64(usage.CompletionTokens)
175 // TODO: calculate cost
176 model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
177 session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
178 float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
179 var newTitle string
180 if len(history) == 1 {
181 // first message generate the title
182 newTitle, err = agent.GenerateTitle(s.ctx, content)
183 if err != nil {
184 s.Publish(AgentErrorEvent, AgentEvent{
185 ID: id,
186 Type: AgentMessageTypeError,
187 AgentID: RootAgent,
188 MessageID: "",
189 SessionID: sessionID,
190 Content: err.Error(),
191 })
192 return
193 }
194 }
195 if newTitle != "" {
196 session.Title = newTitle
197 }
198
199 _, err = s.sessions.Save(session)
200 if err != nil {
201 s.Publish(AgentErrorEvent, AgentEvent{
202 ID: id,
203 Type: AgentMessageTypeError,
204 AgentID: RootAgent,
205 MessageID: "",
206 SessionID: sessionID,
207 Content: err.Error(),
208 })
209 return
210 }
211 }
212}
213
214func (s *service) SendRequest(sessionID string, content string) {
215 id := uuid.New().String()
216
217 _, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
218 s.activeRequests.Store(id, cancel)
219 log.Printf("Request: %s", content)
220 go s.handleRequest(id, sessionID, content)
221}
222
223func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
224 return &service{
225 Broker: pubsub.NewBroker[AgentEvent](),
226 ctx: ctx,
227 sessions: sessions,
228 messages: messages,
229 logger: logger,
230 }
231}