1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"time"
  7
  8	"github.com/charmbracelet/crush/internal/db"
  9	"github.com/charmbracelet/crush/internal/proto"
 10	"github.com/charmbracelet/crush/internal/pubsub"
 11	"github.com/google/uuid"
 12)
 13
 14type (
 15	CreateMessageParams = proto.CreateMessageParams
 16	Message             = proto.Message
 17	Attachment          = proto.Attachment
 18	ToolCall            = proto.ToolCall
 19	ToolResult          = proto.ToolResult
 20	ContentPart         = proto.ContentPart
 21	TextContent         = proto.TextContent
 22	BinaryContent       = proto.BinaryContent
 23	FinishReason        = proto.FinishReason
 24	Finish              = proto.Finish
 25)
 26
 27const (
 28	Assistant = proto.Assistant
 29	User      = proto.User
 30	System    = proto.System
 31	Tool      = proto.Tool
 32
 33	FinishReasonEndTurn          = proto.FinishReasonEndTurn
 34	FinishReasonMaxTokens        = proto.FinishReasonMaxTokens
 35	FinishReasonToolUse          = proto.FinishReasonToolUse
 36	FinishReasonCanceled         = proto.FinishReasonCanceled
 37	FinishReasonError            = proto.FinishReasonError
 38	FinishReasonPermissionDenied = proto.FinishReasonPermissionDenied
 39
 40	FinishReasonUnknown = proto.FinishReasonUnknown
 41)
 42
 43type Service interface {
 44	pubsub.Suscriber[Message]
 45	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
 46	Update(ctx context.Context, message Message) error
 47	Get(ctx context.Context, id string) (Message, error)
 48	List(ctx context.Context, sessionID string) ([]Message, error)
 49	Delete(ctx context.Context, id string) error
 50	DeleteSessionMessages(ctx context.Context, sessionID string) error
 51}
 52
 53type service struct {
 54	*pubsub.Broker[Message]
 55	q db.Querier
 56}
 57
 58func NewService(q db.Querier) Service {
 59	return &service{
 60		Broker: pubsub.NewBroker[Message](),
 61		q:      q,
 62	}
 63}
 64
 65func (s *service) Delete(ctx context.Context, id string) error {
 66	message, err := s.Get(ctx, id)
 67	if err != nil {
 68		return err
 69	}
 70	err = s.q.DeleteMessage(ctx, message.ID)
 71	if err != nil {
 72		return err
 73	}
 74	s.Publish(pubsub.DeletedEvent, message)
 75	return nil
 76}
 77
 78func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 79	if params.Role != proto.Assistant {
 80		params.Parts = append(params.Parts, proto.Finish{
 81			Reason: "stop",
 82		})
 83	}
 84	partsJSON, err := proto.MarshallParts(params.Parts)
 85	if err != nil {
 86		return Message{}, err
 87	}
 88	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 89		ID:        uuid.New().String(),
 90		SessionID: sessionID,
 91		Role:      string(params.Role),
 92		Parts:     string(partsJSON),
 93		Model:     sql.NullString{String: string(params.Model), Valid: true},
 94		Provider:  sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 95	})
 96	if err != nil {
 97		return Message{}, err
 98	}
 99	message, err := s.fromDBItem(dbMessage)
100	if err != nil {
101		return Message{}, err
102	}
103	s.Publish(pubsub.CreatedEvent, message)
104	return message, nil
105}
106
// DeleteSessionMessages deletes every message listed for sessionID. Each one
// goes through Delete, so every removal publishes its own pubsub.DeletedEvent.
//
// It stops at the first failure and returns that error. Messages deleted
// before the failure are not restored.
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
	msgs, err := s.List(ctx, sessionID)
	if err != nil {
		return err
	}
	for i := range msgs {
		// List already queries by session, so this guard is presumably
		// redundant; it is kept as a defensive check.
		if msgs[i].SessionID != sessionID {
			continue
		}
		if err := s.Delete(ctx, msgs[i].ID); err != nil {
			return err
		}
	}
	return nil
}
122
123func (s *service) Update(ctx context.Context, message Message) error {
124	parts, err := proto.MarshallParts(message.Parts)
125	if err != nil {
126		return err
127	}
128	finishedAt := sql.NullInt64{}
129	if f := message.FinishPart(); f != nil {
130		finishedAt.Int64 = f.Time
131		finishedAt.Valid = true
132	}
133	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
134		ID:         message.ID,
135		Parts:      string(parts),
136		FinishedAt: finishedAt,
137	})
138	if err != nil {
139		return err
140	}
141	message.UpdatedAt = time.Now().Unix()
142	s.Publish(pubsub.UpdatedEvent, message)
143	return nil
144}
145
// Get loads the message stored under id and converts it with fromDBItem.
// Query and decode errors are returned unchanged with a zero Message.
func (s *service) Get(ctx context.Context, id string) (Message, error) {
	row, queryErr := s.q.GetMessage(ctx, id)
	if queryErr != nil {
		return Message{}, queryErr
	}
	return s.fromDBItem(row)
}
153
// List returns the messages stored for sessionID, decoded, in the order the
// query yields them. If any row fails to decode, the whole call fails with a
// nil slice and that error. A session with no messages yields an empty,
// non-nil slice.
func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
	rows, err := s.q.ListMessagesBySession(ctx, sessionID)
	if err != nil {
		return nil, err
	}
	out := make([]Message, 0, len(rows))
	for _, row := range rows {
		msg, convErr := s.fromDBItem(row)
		if convErr != nil {
			return nil, convErr
		}
		out = append(out, msg)
	}
	return out, nil
}
168
169func (s *service) fromDBItem(item db.Message) (Message, error) {
170	parts, err := proto.UnmarshallParts([]byte(item.Parts))
171	if err != nil {
172		return Message{}, err
173	}
174	return Message{
175		ID:        item.ID,
176		SessionID: item.SessionID,
177		Role:      proto.MessageRole(item.Role),
178		Parts:     parts,
179		Model:     item.Model.String,
180		Provider:  item.Provider.String,
181		CreatedAt: item.CreatedAt,
182		UpdatedAt: item.UpdatedAt,
183	}, nil
184}