llm.go

  1package llm
  2
  3import (
  4	"context"
  5	"log"
  6	"sync"
  7	"time"
  8
  9	"github.com/cloudwego/eino/callbacks"
 10	"github.com/cloudwego/eino/compose"
 11	"github.com/cloudwego/eino/schema"
 12	"github.com/google/uuid"
 13	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 14	"github.com/kujtimiihoxha/termai/internal/llm/models"
 15	"github.com/kujtimiihoxha/termai/internal/logging"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17	"github.com/kujtimiihoxha/termai/internal/pubsub"
 18	"github.com/kujtimiihoxha/termai/internal/session"
 19
 20	eModel "github.com/cloudwego/eino/components/model"
 21	enioAgent "github.com/cloudwego/eino/flow/agent"
 22	"github.com/spf13/viper"
 23)
 24
 25const (
 26	AgentRequestoEvent pubsub.EventType = "agent_request"
 27	AgentErrorEvent    pubsub.EventType = "agent_error"
 28	AgentResponseEvent pubsub.EventType = "agent_response"
 29)
 30
 // AgentMessageType tells subscribers what kind of payload an AgentEvent
// carries.
type AgentMessageType int

// AgentMessageType values. They are spelled out explicitly because the number
// itself is what ends up in AgentEvent's JSON "type" field, so existing values
// should not be renumbered.
const (
	// AgentMessageTypeNewUserMessage marks an event about a new user message.
	AgentMessageTypeNewUserMessage AgentMessageType = 0
	// AgentMessageTypeAgentResponse marks an event carrying an agent response.
	AgentMessageTypeAgentResponse AgentMessageType = 1
	// AgentMessageTypeError marks an event whose Content is an error message.
	AgentMessageTypeError AgentMessageType = 2
)
 38
 // agentID identifies which agent an AgentEvent is attributed to.
type agentID string

// Agent identifiers.
const (
	// RootAgent is the "root" agent identifier; request errors raised by this
	// package are attributed to it.
	RootAgent agentID = "root"

	// TaskAgent is the "task" agent identifier (not referenced in this file).
	TaskAgent agentID = "task"
)
 45
 46type AgentEvent struct {
 47	ID        string           `json:"id"`
 48	Type      AgentMessageType `json:"type"`
 49	AgentID   agentID          `json:"agent_id"`
 50	MessageID string           `json:"message_id"`
 51	SessionID string           `json:"session_id"`
 52	Content   string           `json:"content"`
 53}
 54
 55type Service interface {
 56	pubsub.Suscriber[AgentEvent]
 57
 58	SendRequest(sessionID string, content string)
 59}
 60type service struct {
 61	*pubsub.Broker[AgentEvent]
 62	Requests       sync.Map
 63	ctx            context.Context
 64	activeRequests sync.Map
 65	messages       message.Service
 66	sessions       session.Service
 67	logger         logging.Interface
 68}
 69
 70func (s *service) handleRequest(id string, sessionID string, content string) {
 71	cancel, ok := s.activeRequests.Load(id)
 72	if !ok {
 73		return
 74	}
 75	defer cancel.(context.CancelFunc)()
 76	defer s.activeRequests.Delete(id)
 77
 78	history, err := s.messages.List(sessionID)
 79	if err != nil {
 80		s.Publish(AgentErrorEvent, AgentEvent{
 81			ID:        id,
 82			Type:      AgentMessageTypeError,
 83			AgentID:   RootAgent,
 84			MessageID: "",
 85			SessionID: sessionID,
 86			Content:   err.Error(),
 87		})
 88		return
 89	}
 90
 91	log.Printf("Request: %s", content)
 92	currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
 93	if err != nil {
 94		s.Publish(AgentErrorEvent, AgentEvent{
 95			ID:        id,
 96			Type:      AgentMessageTypeError,
 97			AgentID:   RootAgent,
 98			MessageID: "",
 99			SessionID: sessionID,
100			Content:   err.Error(),
101		})
102		return
103	}
104
105	messages := []*schema.Message{
106		{
107			Role:    schema.System,
108			Content: systemMessage,
109		},
110	}
111	for _, m := range history {
112		messages = append(messages, &m.MessageData)
113	}
114
115	builder := callbacks.NewHandlerBuilder()
116	builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
117		i, ok := input.(*eModel.CallbackInput)
118		if info.Component == "ChatModel" && ok {
119			if len(messages) < len(i.Messages) {
120				// find new messages
121				newMessages := i.Messages[len(messages):]
122				for _, m := range newMessages {
123					_, err = s.messages.Create(sessionID, *m)
124					if err != nil {
125						s.Publish(AgentErrorEvent, AgentEvent{
126							ID:        id,
127							Type:      AgentMessageTypeError,
128							AgentID:   RootAgent,
129							MessageID: "",
130							SessionID: sessionID,
131							Content:   err.Error(),
132						})
133					}
134					messages = append(messages, m)
135				}
136			}
137		}
138
139		return ctx
140	})
141	builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
142		return ctx
143	})
144
145	out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
146	if err != nil {
147		s.Publish(AgentErrorEvent, AgentEvent{
148			ID:        id,
149			Type:      AgentMessageTypeError,
150			AgentID:   RootAgent,
151			MessageID: "",
152			SessionID: sessionID,
153			Content:   err.Error(),
154		})
155		return
156	}
157	usage := out.ResponseMeta.Usage
158	s.messages.Create(sessionID, *out)
159	if usage != nil {
160		log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
161		session, err := s.sessions.Get(sessionID)
162		if err != nil {
163			s.Publish(AgentErrorEvent, AgentEvent{
164				ID:        id,
165				Type:      AgentMessageTypeError,
166				AgentID:   RootAgent,
167				MessageID: "",
168				SessionID: sessionID,
169				Content:   err.Error(),
170			})
171			return
172		}
173		session.PromptTokens += int64(usage.PromptTokens)
174		session.CompletionTokens += int64(usage.CompletionTokens)
175		// TODO: calculate cost
176		model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
177		session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
178			float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
179		var newTitle string
180		if len(history) == 1 {
181			// first message generate the title
182			newTitle, err = agent.GenerateTitle(s.ctx, content)
183			if err != nil {
184				s.Publish(AgentErrorEvent, AgentEvent{
185					ID:        id,
186					Type:      AgentMessageTypeError,
187					AgentID:   RootAgent,
188					MessageID: "",
189					SessionID: sessionID,
190					Content:   err.Error(),
191				})
192				return
193			}
194		}
195		if newTitle != "" {
196			session.Title = newTitle
197		}
198
199		_, err = s.sessions.Save(session)
200		if err != nil {
201			s.Publish(AgentErrorEvent, AgentEvent{
202				ID:        id,
203				Type:      AgentMessageTypeError,
204				AgentID:   RootAgent,
205				MessageID: "",
206				SessionID: sessionID,
207				Content:   err.Error(),
208			})
209			return
210		}
211	}
212}
213
214func (s *service) SendRequest(sessionID string, content string) {
215	id := uuid.New().String()
216
217	_, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
218	s.activeRequests.Store(id, cancel)
219	log.Printf("Request: %s", content)
220	go s.handleRequest(id, sessionID, content)
221}
222
223func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
224	return &service{
225		Broker:   pubsub.NewBroker[AgentEvent](),
226		ctx:      ctx,
227		sessions: sessions,
228		messages: messages,
229		logger:   logger,
230	}
231}