1package message
2
3import (
4 "context"
5 "database/sql"
6 "time"
7
8 "github.com/charmbracelet/crush/internal/db"
9 "github.com/charmbracelet/crush/internal/proto"
10 "github.com/charmbracelet/crush/internal/pubsub"
11 "github.com/google/uuid"
12)
13
14type (
15 CreateMessageParams = proto.CreateMessageParams
16 Message = proto.Message
17 Attachment = proto.Attachment
18 ToolCall = proto.ToolCall
19 ToolResult = proto.ToolResult
20 ContentPart = proto.ContentPart
21 TextContent = proto.TextContent
22 BinaryContent = proto.BinaryContent
23 FinishReason = proto.FinishReason
24 Finish = proto.Finish
25)
26
27const (
28 Assistant = proto.Assistant
29 User = proto.User
30 System = proto.System
31 Tool = proto.Tool
32
33 FinishReasonEndTurn = proto.FinishReasonEndTurn
34 FinishReasonMaxTokens = proto.FinishReasonMaxTokens
35 FinishReasonToolUse = proto.FinishReasonToolUse
36 FinishReasonCanceled = proto.FinishReasonCanceled
37 FinishReasonError = proto.FinishReasonError
38 FinishReasonPermissionDenied = proto.FinishReasonPermissionDenied
39
40 FinishReasonUnknown = proto.FinishReasonUnknown
41)
42
43type Service interface {
44 pubsub.Suscriber[Message]
45 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
46 Update(ctx context.Context, message Message) error
47 Get(ctx context.Context, id string) (Message, error)
48 List(ctx context.Context, sessionID string) ([]Message, error)
49 Delete(ctx context.Context, id string) error
50 DeleteSessionMessages(ctx context.Context, sessionID string) error
51}
52
53type service struct {
54 *pubsub.Broker[Message]
55 q db.Querier
56}
57
58func NewService(q db.Querier) Service {
59 return &service{
60 Broker: pubsub.NewBroker[Message](),
61 q: q,
62 }
63}
64
65func (s *service) Delete(ctx context.Context, id string) error {
66 message, err := s.Get(ctx, id)
67 if err != nil {
68 return err
69 }
70 err = s.q.DeleteMessage(ctx, message.ID)
71 if err != nil {
72 return err
73 }
74 s.Publish(pubsub.DeletedEvent, message)
75 return nil
76}
77
78func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
79 if params.Role != proto.Assistant {
80 params.Parts = append(params.Parts, proto.Finish{
81 Reason: "stop",
82 })
83 }
84 partsJSON, err := proto.MarshallParts(params.Parts)
85 if err != nil {
86 return Message{}, err
87 }
88 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
89 ID: uuid.New().String(),
90 SessionID: sessionID,
91 Role: string(params.Role),
92 Parts: string(partsJSON),
93 Model: sql.NullString{String: string(params.Model), Valid: true},
94 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
95 })
96 if err != nil {
97 return Message{}, err
98 }
99 message, err := s.fromDBItem(dbMessage)
100 if err != nil {
101 return Message{}, err
102 }
103 s.Publish(pubsub.CreatedEvent, message)
104 return message, nil
105}
106
107func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
108 messages, err := s.List(ctx, sessionID)
109 if err != nil {
110 return err
111 }
112 for _, message := range messages {
113 if message.SessionID == sessionID {
114 err = s.Delete(ctx, message.ID)
115 if err != nil {
116 return err
117 }
118 }
119 }
120 return nil
121}
122
123func (s *service) Update(ctx context.Context, message Message) error {
124 parts, err := proto.MarshallParts(message.Parts)
125 if err != nil {
126 return err
127 }
128 finishedAt := sql.NullInt64{}
129 if f := message.FinishPart(); f != nil {
130 finishedAt.Int64 = f.Time
131 finishedAt.Valid = true
132 }
133 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
134 ID: message.ID,
135 Parts: string(parts),
136 FinishedAt: finishedAt,
137 })
138 if err != nil {
139 return err
140 }
141 message.UpdatedAt = time.Now().Unix()
142 s.Publish(pubsub.UpdatedEvent, message)
143 return nil
144}
145
146func (s *service) Get(ctx context.Context, id string) (Message, error) {
147 dbMessage, err := s.q.GetMessage(ctx, id)
148 if err != nil {
149 return Message{}, err
150 }
151 return s.fromDBItem(dbMessage)
152}
153
154func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
155 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
156 if err != nil {
157 return nil, err
158 }
159 messages := make([]Message, len(dbMessages))
160 for i, dbMessage := range dbMessages {
161 messages[i], err = s.fromDBItem(dbMessage)
162 if err != nil {
163 return nil, err
164 }
165 }
166 return messages, nil
167}
168
169func (s *service) fromDBItem(item db.Message) (Message, error) {
170 parts, err := proto.UnmarshallParts([]byte(item.Parts))
171 if err != nil {
172 return Message{}, err
173 }
174 return Message{
175 ID: item.ID,
176 SessionID: item.SessionID,
177 Role: proto.MessageRole(item.Role),
178 Parts: parts,
179 Model: item.Model.String,
180 Provider: item.Provider.String,
181 CreatedAt: item.CreatedAt,
182 UpdatedAt: item.UpdatedAt,
183 }, nil
184}