llm.go

  1package llm
  2
  3import (
  4	"context"
  5	"log"
  6	"sync"
  7	"time"
  8
  9	"github.com/cloudwego/eino/callbacks"
 10	"github.com/cloudwego/eino/compose"
 11	"github.com/cloudwego/eino/schema"
 12	"github.com/google/uuid"
 13	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 14	"github.com/kujtimiihoxha/termai/internal/logging"
 15	"github.com/kujtimiihoxha/termai/internal/message"
 16	"github.com/kujtimiihoxha/termai/internal/pubsub"
 17	"github.com/kujtimiihoxha/termai/internal/session"
 18
 19	eModel "github.com/cloudwego/eino/components/model"
 20	enioAgent "github.com/cloudwego/eino/flow/agent"
 21	"github.com/spf13/viper"
 22)
 23
 24const (
 25	AgentRequestoEvent pubsub.EventType = "agent_request"
 26	AgentErrorEvent    pubsub.EventType = "agent_error"
 27	AgentResponseEvent pubsub.EventType = "agent_response"
 28)
 29
 30type AgentMessageType int
 31
 32const (
 33	AgentMessageTypeNewUserMessage AgentMessageType = iota
 34	AgentMessageTypeAgentResponse
 35	AgentMessageTypeError
 36)
 37
 38type agentID string
 39
 40const (
 41	RootAgent agentID = "root"
 42	TaskAgent agentID = "task"
 43)
 44
// AgentEvent is the payload delivered to subscribers of the agent broker.
// Field order and tags determine its JSON encoding.
type AgentEvent struct {
	// ID is the request identifier; SendRequest generates it as a UUID string.
	ID string `json:"id"`
	// Type says what kind of payload this event carries.
	Type AgentMessageType `json:"type"`
	// AgentID names the agent the event is attributed to.
	AgentID agentID `json:"agent_id"`
	// MessageID is left empty by every event published in this file.
	MessageID string `json:"message_id"`
	// SessionID identifies the session the request belongs to.
	SessionID string `json:"session_id"`
	// Content is the event text; for error events published in this file it
	// is the error's message.
	Content string `json:"content"`
}
 53
 54type Service interface {
 55	pubsub.Suscriber[AgentEvent]
 56
 57	SendRequest(sessionID string, content string)
 58}
 59type service struct {
 60	*pubsub.Broker[AgentEvent]
 61	Requests       sync.Map
 62	ctx            context.Context
 63	activeRequests sync.Map
 64	messages       message.Service
 65	sessions       session.Service
 66	logger         logging.Interface
 67}
 68
 69func (s *service) handleRequest(id string, sessionID string, content string) {
 70	cancel, ok := s.activeRequests.Load(id)
 71	if !ok {
 72		return
 73	}
 74	defer cancel.(context.CancelFunc)()
 75	defer s.activeRequests.Delete(id)
 76
 77	history, err := s.messages.List(sessionID)
 78	if err != nil {
 79		s.Publish(AgentErrorEvent, AgentEvent{
 80			ID:        id,
 81			Type:      AgentMessageTypeError,
 82			AgentID:   RootAgent,
 83			MessageID: "",
 84			SessionID: sessionID,
 85			Content:   err.Error(),
 86		})
 87		return
 88	}
 89
 90	log.Printf("Request: %s", content)
 91	agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
 92	if err != nil {
 93		s.Publish(AgentErrorEvent, AgentEvent{
 94			ID:        id,
 95			Type:      AgentMessageTypeError,
 96			AgentID:   RootAgent,
 97			MessageID: "",
 98			SessionID: sessionID,
 99			Content:   err.Error(),
100		})
101		return
102	}
103
104	messages := []*schema.Message{
105		{
106			Role:    schema.System,
107			Content: systemMessage,
108		},
109	}
110	for _, m := range history {
111		messages = append(messages, &m.MessageData)
112	}
113	builder := callbacks.NewHandlerBuilder()
114	builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
115		i, ok := input.(*eModel.CallbackInput)
116		if info.Component == "ChatModel" && ok {
117			if len(messages) < len(i.Messages) {
118				// find new messages
119				newMessages := i.Messages[len(messages):]
120				for _, m := range newMessages {
121					_, err = s.messages.Create(sessionID, *m)
122					if err != nil {
123						s.Publish(AgentErrorEvent, AgentEvent{
124							ID:        id,
125							Type:      AgentMessageTypeError,
126							AgentID:   RootAgent,
127							MessageID: "",
128							SessionID: sessionID,
129							Content:   err.Error(),
130						})
131					}
132					messages = append(messages, m)
133				}
134			}
135		}
136
137		return ctx
138	})
139	builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
140		return ctx
141	})
142
143	out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
144	if err != nil {
145		s.Publish(AgentErrorEvent, AgentEvent{
146			ID:        id,
147			Type:      AgentMessageTypeError,
148			AgentID:   RootAgent,
149			MessageID: "",
150			SessionID: sessionID,
151			Content:   err.Error(),
152		})
153		return
154	}
155	usage := out.ResponseMeta.Usage
156	if usage != nil {
157		log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
158		session, err := s.sessions.Get(sessionID)
159		if err != nil {
160			s.Publish(AgentErrorEvent, AgentEvent{
161				ID:        id,
162				Type:      AgentMessageTypeError,
163				AgentID:   RootAgent,
164				MessageID: "",
165				SessionID: sessionID,
166				Content:   err.Error(),
167			})
168			return
169		}
170		session.PromptTokens += int64(usage.PromptTokens)
171		session.CompletionTokens += int64(usage.CompletionTokens)
172		// TODO: calculate cost
173		_, err = s.sessions.Save(session)
174		if err != nil {
175			s.Publish(AgentErrorEvent, AgentEvent{
176				ID:        id,
177				Type:      AgentMessageTypeError,
178				AgentID:   RootAgent,
179				MessageID: "",
180				SessionID: sessionID,
181				Content:   err.Error(),
182			})
183			return
184		}
185	}
186	s.messages.Create(sessionID, *out)
187}
188
189func (s *service) SendRequest(sessionID string, content string) {
190	id := uuid.New().String()
191
192	_, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
193	s.activeRequests.Store(id, cancel)
194	log.Printf("Request: %s", content)
195	go s.handleRequest(id, sessionID, content)
196}
197
198func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
199	return &service{
200		Broker:   pubsub.NewBroker[AgentEvent](),
201		ctx:      ctx,
202		sessions: sessions,
203		messages: messages,
204		logger:   logger,
205	}
206}