1package llm
2
3import (
4 "context"
5 "log"
6 "sync"
7 "time"
8
9 "github.com/cloudwego/eino/callbacks"
10 "github.com/cloudwego/eino/compose"
11 "github.com/cloudwego/eino/schema"
12 "github.com/google/uuid"
13 "github.com/kujtimiihoxha/termai/internal/llm/agent"
14 "github.com/kujtimiihoxha/termai/internal/logging"
15 "github.com/kujtimiihoxha/termai/internal/message"
16 "github.com/kujtimiihoxha/termai/internal/pubsub"
17 "github.com/kujtimiihoxha/termai/internal/session"
18
19 eModel "github.com/cloudwego/eino/components/model"
20 enioAgent "github.com/cloudwego/eino/flow/agent"
21 "github.com/spf13/viper"
22)
23
24const (
25 AgentRequestoEvent pubsub.EventType = "agent_request"
26 AgentErrorEvent pubsub.EventType = "agent_error"
27 AgentResponseEvent pubsub.EventType = "agent_response"
28)
29
// AgentMessageType tags what kind of payload an AgentEvent carries.
//
// AgentEvent.Type is encoded as this integer (json:"type"), so each value is
// written out explicitly instead of depending on declaration order.
type AgentMessageType int

const (
	// AgentMessageTypeNewUserMessage tags an event about a new user message.
	AgentMessageTypeNewUserMessage AgentMessageType = 0
	// AgentMessageTypeAgentResponse tags an event carrying an agent response.
	AgentMessageTypeAgentResponse AgentMessageType = 1
	// AgentMessageTypeError tags an event whose Content is an error message.
	AgentMessageTypeError AgentMessageType = 2
)
37
// agentID names the agent an AgentEvent is attributed to.
type agentID string

// RootAgent is the ID of the top-level ("root") agent.
const RootAgent agentID = "root"

// TaskAgent is the ID of the "task" agent.
const TaskAgent agentID = "task"
44
45type AgentEvent struct {
46 ID string `json:"id"`
47 Type AgentMessageType `json:"type"`
48 AgentID agentID `json:"agent_id"`
49 MessageID string `json:"message_id"`
50 SessionID string `json:"session_id"`
51 Content string `json:"content"`
52}
53
54type Service interface {
55 pubsub.Suscriber[AgentEvent]
56
57 SendRequest(sessionID string, content string)
58}
59type service struct {
60 *pubsub.Broker[AgentEvent]
61 Requests sync.Map
62 ctx context.Context
63 activeRequests sync.Map
64 messages message.Service
65 sessions session.Service
66 logger logging.Interface
67}
68
69func (s *service) handleRequest(id string, sessionID string, content string) {
70 cancel, ok := s.activeRequests.Load(id)
71 if !ok {
72 return
73 }
74 defer cancel.(context.CancelFunc)()
75 defer s.activeRequests.Delete(id)
76
77 history, err := s.messages.List(sessionID)
78 if err != nil {
79 s.Publish(AgentErrorEvent, AgentEvent{
80 ID: id,
81 Type: AgentMessageTypeError,
82 AgentID: RootAgent,
83 MessageID: "",
84 SessionID: sessionID,
85 Content: err.Error(),
86 })
87 return
88 }
89
90 log.Printf("Request: %s", content)
91 agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
92 if err != nil {
93 s.Publish(AgentErrorEvent, AgentEvent{
94 ID: id,
95 Type: AgentMessageTypeError,
96 AgentID: RootAgent,
97 MessageID: "",
98 SessionID: sessionID,
99 Content: err.Error(),
100 })
101 return
102 }
103
104 messages := []*schema.Message{
105 {
106 Role: schema.System,
107 Content: systemMessage,
108 },
109 }
110 for _, m := range history {
111 messages = append(messages, &m.MessageData)
112 }
113 builder := callbacks.NewHandlerBuilder()
114 builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
115 i, ok := input.(*eModel.CallbackInput)
116 if info.Component == "ChatModel" && ok {
117 if len(messages) < len(i.Messages) {
118 // find new messages
119 newMessages := i.Messages[len(messages):]
120 for _, m := range newMessages {
121 _, err = s.messages.Create(sessionID, *m)
122 if err != nil {
123 s.Publish(AgentErrorEvent, AgentEvent{
124 ID: id,
125 Type: AgentMessageTypeError,
126 AgentID: RootAgent,
127 MessageID: "",
128 SessionID: sessionID,
129 Content: err.Error(),
130 })
131 }
132 messages = append(messages, m)
133 }
134 }
135 }
136
137 return ctx
138 })
139 builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
140 return ctx
141 })
142
143 out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
144 if err != nil {
145 s.Publish(AgentErrorEvent, AgentEvent{
146 ID: id,
147 Type: AgentMessageTypeError,
148 AgentID: RootAgent,
149 MessageID: "",
150 SessionID: sessionID,
151 Content: err.Error(),
152 })
153 return
154 }
155 usage := out.ResponseMeta.Usage
156 if usage != nil {
157 log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
158 session, err := s.sessions.Get(sessionID)
159 if err != nil {
160 s.Publish(AgentErrorEvent, AgentEvent{
161 ID: id,
162 Type: AgentMessageTypeError,
163 AgentID: RootAgent,
164 MessageID: "",
165 SessionID: sessionID,
166 Content: err.Error(),
167 })
168 return
169 }
170 session.PromptTokens += int64(usage.PromptTokens)
171 session.CompletionTokens += int64(usage.CompletionTokens)
172 // TODO: calculate cost
173 _, err = s.sessions.Save(session)
174 if err != nil {
175 s.Publish(AgentErrorEvent, AgentEvent{
176 ID: id,
177 Type: AgentMessageTypeError,
178 AgentID: RootAgent,
179 MessageID: "",
180 SessionID: sessionID,
181 Content: err.Error(),
182 })
183 return
184 }
185 }
186 s.messages.Create(sessionID, *out)
187}
188
189func (s *service) SendRequest(sessionID string, content string) {
190 id := uuid.New().String()
191
192 _, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
193 s.activeRequests.Store(id, cancel)
194 log.Printf("Request: %s", content)
195 go s.handleRequest(id, sessionID, content)
196}
197
198func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
199 return &service{
200 Broker: pubsub.NewBroker[AgentEvent](),
201 ctx: ctx,
202 sessions: sessions,
203 messages: messages,
204 logger: logger,
205 }
206}