llm.go

  1package llm
  2
  3import (
  4	"context"
  5	"log"
  6	"sync"
  7	"time"
  8
  9	"github.com/cloudwego/eino/callbacks"
 10	"github.com/cloudwego/eino/compose"
 11	"github.com/cloudwego/eino/schema"
 12	"github.com/google/uuid"
 13	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 14	"github.com/kujtimiihoxha/termai/internal/llm/models"
 15	"github.com/kujtimiihoxha/termai/internal/logging"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17	"github.com/kujtimiihoxha/termai/internal/pubsub"
 18	"github.com/kujtimiihoxha/termai/internal/session"
 19
 20	eModel "github.com/cloudwego/eino/components/model"
 21	enioAgent "github.com/cloudwego/eino/flow/agent"
 22	"github.com/spf13/viper"
 23)
 24
 25const (
 26	AgentRequestoEvent pubsub.EventType = "agent_request"
 27	AgentErrorEvent    pubsub.EventType = "agent_error"
 28	AgentResponseEvent pubsub.EventType = "agent_response"
 29)
 30
 31type AgentMessageType int
 32
 33const (
 34	AgentMessageTypeNewUserMessage AgentMessageType = iota
 35	AgentMessageTypeAgentResponse
 36	AgentMessageTypeError
 37)
 38
 39type agentID string
 40
 41const (
 42	RootAgent agentID = "root"
 43	TaskAgent agentID = "task"
 44)
 45
// AgentEvent is the payload type carried by the agent service's broker.
type AgentEvent struct {
	// ID identifies the request the event belongs to; the error events in
	// this file set it to the UUID generated by SendRequest.
	ID        string           `json:"id"`
	// Type classifies the event (new user message, agent response, error).
	Type      AgentMessageType `json:"type"`
	// AgentID is the agent the event is attributed to.
	AgentID   agentID          `json:"agent_id"`
	// MessageID refers to a related message, if any; error events in this
	// file leave it empty.
	MessageID string           `json:"message_id"`
	// SessionID is the session the request was issued for.
	SessionID string           `json:"session_id"`
	// Content is the event text; for error events it is err.Error().
	Content   string           `json:"content"`
}
 54
 55type Service interface {
 56	pubsub.Suscriber[AgentEvent]
 57
 58	SendRequest(sessionID string, content string)
 59}
// service is the implementation of Service returned by NewService.
type service struct {
	// Broker delivers AgentEvents to subscribers (see Publish calls).
	*pubsub.Broker[AgentEvent]
	// NOTE(review): Requests is not referenced in this file; it may be
	// unused — confirm against the rest of the package before removing.
	Requests       sync.Map
	// ctx is the service-wide context that request handling runs under.
	ctx            context.Context
	// activeRequests maps a request id (string) to the context.CancelFunc
	// stored for that in-flight request.
	activeRequests sync.Map
	messages       message.Service
	sessions       session.Service
	// NOTE(review): logger is stored but not used in this file.
	logger         logging.Interface
}
 69
 70func (s *service) handleRequest(id string, sessionID string, content string) {
 71	cancel, ok := s.activeRequests.Load(id)
 72	if !ok {
 73		return
 74	}
 75	defer cancel.(context.CancelFunc)()
 76	defer s.activeRequests.Delete(id)
 77
 78	history, err := s.messages.List(sessionID)
 79	if err != nil {
 80		s.Publish(AgentErrorEvent, AgentEvent{
 81			ID:        id,
 82			Type:      AgentMessageTypeError,
 83			AgentID:   RootAgent,
 84			MessageID: "",
 85			SessionID: sessionID,
 86			Content:   err.Error(),
 87		})
 88		return
 89	}
 90
 91	currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
 92	if err != nil {
 93		s.Publish(AgentErrorEvent, AgentEvent{
 94			ID:        id,
 95			Type:      AgentMessageTypeError,
 96			AgentID:   RootAgent,
 97			MessageID: "",
 98			SessionID: sessionID,
 99			Content:   err.Error(),
100		})
101		return
102	}
103
104	messages := []*schema.Message{
105		{
106			Role:    schema.System,
107			Content: systemMessage,
108		},
109	}
110	for _, m := range history {
111		messages = append(messages, &m.MessageData)
112	}
113
114	builder := callbacks.NewHandlerBuilder()
115	builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
116		i, ok := input.(*eModel.CallbackInput)
117		if info.Component == "ChatModel" && ok {
118			if len(messages) < len(i.Messages) {
119				// find new messages
120				newMessages := i.Messages[len(messages):]
121				for _, m := range newMessages {
122					_, err = s.messages.Create(sessionID, *m)
123					if err != nil {
124						s.Publish(AgentErrorEvent, AgentEvent{
125							ID:        id,
126							Type:      AgentMessageTypeError,
127							AgentID:   RootAgent,
128							MessageID: "",
129							SessionID: sessionID,
130							Content:   err.Error(),
131						})
132					}
133					messages = append(messages, m)
134				}
135			}
136		}
137
138		return ctx
139	})
140	builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
141		return ctx
142	})
143
144	out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
145	if err != nil {
146		s.Publish(AgentErrorEvent, AgentEvent{
147			ID:        id,
148			Type:      AgentMessageTypeError,
149			AgentID:   RootAgent,
150			MessageID: "",
151			SessionID: sessionID,
152			Content:   err.Error(),
153		})
154		return
155	}
156	usage := out.ResponseMeta.Usage
157	s.messages.Create(sessionID, *out)
158	if usage != nil {
159		log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
160		session, err := s.sessions.Get(sessionID)
161		if err != nil {
162			s.Publish(AgentErrorEvent, AgentEvent{
163				ID:        id,
164				Type:      AgentMessageTypeError,
165				AgentID:   RootAgent,
166				MessageID: "",
167				SessionID: sessionID,
168				Content:   err.Error(),
169			})
170			return
171		}
172		session.PromptTokens += int64(usage.PromptTokens)
173		session.CompletionTokens += int64(usage.CompletionTokens)
174		model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
175		session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
176			float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
177		var newTitle string
178		if len(history) == 1 {
179			// first message generate the title
180			newTitle, err = agent.GenerateTitle(s.ctx, content)
181			if err != nil {
182				s.Publish(AgentErrorEvent, AgentEvent{
183					ID:        id,
184					Type:      AgentMessageTypeError,
185					AgentID:   RootAgent,
186					MessageID: "",
187					SessionID: sessionID,
188					Content:   err.Error(),
189				})
190				return
191			}
192		}
193		if newTitle != "" {
194			session.Title = newTitle
195		}
196
197		_, err = s.sessions.Save(session)
198		if err != nil {
199			s.Publish(AgentErrorEvent, AgentEvent{
200				ID:        id,
201				Type:      AgentMessageTypeError,
202				AgentID:   RootAgent,
203				MessageID: "",
204				SessionID: sessionID,
205				Content:   err.Error(),
206			})
207			return
208		}
209	}
210}
211
212func (s *service) SendRequest(sessionID string, content string) {
213	id := uuid.New().String()
214
215	_, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
216	s.activeRequests.Store(id, cancel)
217	log.Printf("Request: %s", content)
218	go s.handleRequest(id, sessionID, content)
219}
220
221func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
222	return &service{
223		Broker:   pubsub.NewBroker[AgentEvent](),
224		ctx:      ctx,
225		sessions: sessions,
226		messages: messages,
227		logger:   logger,
228	}
229}