session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6
  7	"github.com/google/uuid"
  8	"github.com/kujtimiihoxha/termai/internal/db"
  9	"github.com/kujtimiihoxha/termai/internal/pubsub"
 10)
 11
 12type Session struct {
 13	ID               string
 14	ParentSessionID  string
 15	Title            string
 16	MessageCount     int64
 17	PromptTokens     int64
 18	CompletionTokens int64
 19	Cost             float64
 20	CreatedAt        int64
 21	UpdatedAt        int64
 22}
 23
 24type Service interface {
 25	pubsub.Suscriber[Session]
 26	Create(ctx context.Context, title string) (Session, error)
 27	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 28	Get(ctx context.Context, id string) (Session, error)
 29	List(ctx context.Context) ([]Session, error)
 30	Save(ctx context.Context, session Session) (Session, error)
 31	Delete(ctx context.Context, id string) error
 32}
 33
 34type service struct {
 35	*pubsub.Broker[Session]
 36	q db.Querier
 37}
 38
 39func (s *service) Create(ctx context.Context, title string) (Session, error) {
 40	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 41		ID:    uuid.New().String(),
 42		Title: title,
 43	})
 44	if err != nil {
 45		return Session{}, err
 46	}
 47	session := s.fromDBItem(dbSession)
 48	s.Publish(pubsub.CreatedEvent, session)
 49	return session, nil
 50}
 51
 52func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 53	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 54		ID:              toolCallID,
 55		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 56		Title:           title,
 57	})
 58	if err != nil {
 59		return Session{}, err
 60	}
 61	session := s.fromDBItem(dbSession)
 62	s.Publish(pubsub.CreatedEvent, session)
 63	return session, nil
 64}
 65
 66func (s *service) Delete(ctx context.Context, id string) error {
 67	session, err := s.Get(ctx, id)
 68	if err != nil {
 69		return err
 70	}
 71	err = s.q.DeleteSession(ctx, session.ID)
 72	if err != nil {
 73		return err
 74	}
 75	s.Publish(pubsub.DeletedEvent, session)
 76	return nil
 77}
 78
 79func (s *service) Get(ctx context.Context, id string) (Session, error) {
 80	dbSession, err := s.q.GetSessionByID(ctx, id)
 81	if err != nil {
 82		return Session{}, err
 83	}
 84	return s.fromDBItem(dbSession), nil
 85}
 86
 87func (s *service) Save(ctx context.Context, session Session) (Session, error) {
 88	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
 89		ID:               session.ID,
 90		Title:            session.Title,
 91		PromptTokens:     session.PromptTokens,
 92		CompletionTokens: session.CompletionTokens,
 93		Cost:             session.Cost,
 94	})
 95	if err != nil {
 96		return Session{}, err
 97	}
 98	session = s.fromDBItem(dbSession)
 99	s.Publish(pubsub.UpdatedEvent, session)
100	return session, nil
101}
102
103func (s *service) List(ctx context.Context) ([]Session, error) {
104	dbSessions, err := s.q.ListSessions(ctx)
105	if err != nil {
106		return nil, err
107	}
108	sessions := make([]Session, len(dbSessions))
109	for i, dbSession := range dbSessions {
110		sessions[i] = s.fromDBItem(dbSession)
111	}
112	return sessions, nil
113}
114
115func (s service) fromDBItem(item db.Session) Session {
116	return Session{
117		ID:               item.ID,
118		ParentSessionID:  item.ParentSessionID.String,
119		Title:            item.Title,
120		MessageCount:     item.MessageCount,
121		PromptTokens:     item.PromptTokens,
122		CompletionTokens: item.CompletionTokens,
123		Cost:             item.Cost,
124		CreatedAt:        item.CreatedAt,
125		UpdatedAt:        item.UpdatedAt,
126	}
127}
128
129func NewService(q db.Querier) Service {
130	broker := pubsub.NewBroker[Session]()
131	return &service{
132		broker,
133		q,
134	}
135}