1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
15type CreateMessageParams struct {
16 Role MessageRole
17 Parts []ContentPart
18 Model string
19 Provider string
20 IsSummaryMessage bool
21}
22
23type Service interface {
24 pubsub.Subscriber[Message]
25 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
26 Update(ctx context.Context, message Message) error
27 Get(ctx context.Context, id string) (Message, error)
28 List(ctx context.Context, sessionID string) ([]Message, error)
29 ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
30 ListAllUserMessages(ctx context.Context) ([]Message, error)
31 Copy(ctx context.Context, sessionID string, message Message) (Message, error)
32 Delete(ctx context.Context, id string) error
33 DeleteSessionMessages(ctx context.Context, sessionID string) error
34 DeleteMessagesFrom(ctx context.Context, sessionID, messageID string) error
35}
36
37type service struct {
38 *pubsub.Broker[Message]
39 q db.Querier
40}
41
42func NewService(q db.Querier) Service {
43 return &service{
44 Broker: pubsub.NewBroker[Message](),
45 q: q,
46 }
47}
48
49func (s *service) Delete(ctx context.Context, id string) error {
50 message, err := s.Get(ctx, id)
51 if err != nil {
52 return err
53 }
54 err = s.q.DeleteMessage(ctx, message.ID)
55 if err != nil {
56 return err
57 }
58 // Clone the message before publishing to avoid race conditions with
59 // concurrent modifications to the Parts slice.
60 s.Publish(pubsub.DeletedEvent, message.Clone())
61 return nil
62}
63
64func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
65 if params.Role != Assistant {
66 params.Parts = append(params.Parts, Finish{
67 Reason: "stop",
68 })
69 }
70 partsJSON, err := marshalParts(params.Parts)
71 if err != nil {
72 return Message{}, err
73 }
74 isSummary := int64(0)
75 if params.IsSummaryMessage {
76 isSummary = 1
77 }
78 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
79 ID: uuid.New().String(),
80 SessionID: sessionID,
81 Role: string(params.Role),
82 Parts: string(partsJSON),
83 Model: sql.NullString{String: string(params.Model), Valid: true},
84 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
85 IsSummaryMessage: isSummary,
86 })
87 if err != nil {
88 return Message{}, err
89 }
90 message, err := s.fromDBItem(dbMessage)
91 if err != nil {
92 return Message{}, err
93 }
94 // Clone the message before publishing to avoid race conditions with
95 // concurrent modifications to the Parts slice.
96 s.Publish(pubsub.CreatedEvent, message.Clone())
97 return message, nil
98}
99
100func (s *service) Copy(ctx context.Context, sessionID string, message Message) (Message, error) {
101 params := CreateMessageParams{
102 Role: message.Role,
103 Parts: message.Parts,
104 Model: message.Model,
105 Provider: message.Provider,
106 IsSummaryMessage: message.IsSummaryMessage,
107 }
108 return s.Create(ctx, sessionID, params)
109}
110
111func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
112 messages, err := s.List(ctx, sessionID)
113 if err != nil {
114 return err
115 }
116 for _, message := range messages {
117 if message.SessionID == sessionID {
118 err = s.Delete(ctx, message.ID)
119 if err != nil {
120 return err
121 }
122 }
123 }
124 return nil
125}
126
127func (s *service) DeleteMessagesFrom(ctx context.Context, sessionID, messageID string) error {
128 allMessages, err := s.List(ctx, sessionID)
129 if err != nil {
130 return err
131 }
132 targetIndex := -1
133 for i, msg := range allMessages {
134 if msg.ID == messageID {
135 targetIndex = i
136 break
137 }
138 }
139 if targetIndex == -1 {
140 return fmt.Errorf("message not found: %s", messageID)
141 }
142 for i := len(allMessages) - 1; i >= targetIndex; i-- {
143 if err := s.Delete(ctx, allMessages[i].ID); err != nil {
144 return err
145 }
146 }
147 return nil
148}
149
150func (s *service) Update(ctx context.Context, message Message) error {
151 parts, err := marshalParts(message.Parts)
152 if err != nil {
153 return err
154 }
155 finishedAt := sql.NullInt64{}
156 if f := message.FinishPart(); f != nil {
157 finishedAt.Int64 = f.Time
158 finishedAt.Valid = true
159 }
160 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
161 ID: message.ID,
162 Parts: string(parts),
163 FinishedAt: finishedAt,
164 })
165 if err != nil {
166 return err
167 }
168 message.UpdatedAt = time.Now().Unix()
169 // Clone the message before publishing to avoid race conditions with
170 // concurrent modifications to the Parts slice.
171 s.Publish(pubsub.UpdatedEvent, message.Clone())
172 return nil
173}
174
175func (s *service) Get(ctx context.Context, id string) (Message, error) {
176 dbMessage, err := s.q.GetMessage(ctx, id)
177 if err != nil {
178 return Message{}, err
179 }
180 return s.fromDBItem(dbMessage)
181}
182
183func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
184 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
185 if err != nil {
186 return nil, err
187 }
188 messages := make([]Message, len(dbMessages))
189 for i, dbMessage := range dbMessages {
190 messages[i], err = s.fromDBItem(dbMessage)
191 if err != nil {
192 return nil, err
193 }
194 }
195 return messages, nil
196}
197
198func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
199 dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
200 if err != nil {
201 return nil, err
202 }
203 messages := make([]Message, len(dbMessages))
204 for i, dbMessage := range dbMessages {
205 messages[i], err = s.fromDBItem(dbMessage)
206 if err != nil {
207 return nil, err
208 }
209 }
210 return messages, nil
211}
212
213func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
214 dbMessages, err := s.q.ListAllUserMessages(ctx)
215 if err != nil {
216 return nil, err
217 }
218 messages := make([]Message, len(dbMessages))
219 for i, dbMessage := range dbMessages {
220 messages[i], err = s.fromDBItem(dbMessage)
221 if err != nil {
222 return nil, err
223 }
224 }
225 return messages, nil
226}
227
228func (s *service) fromDBItem(item db.Message) (Message, error) {
229 parts, err := unmarshalParts([]byte(item.Parts))
230 if err != nil {
231 return Message{}, err
232 }
233 return Message{
234 ID: item.ID,
235 SessionID: item.SessionID,
236 Role: MessageRole(item.Role),
237 Parts: parts,
238 Model: item.Model.String,
239 Provider: item.Provider.String,
240 CreatedAt: item.CreatedAt,
241 UpdatedAt: item.UpdatedAt,
242 IsSummaryMessage: item.IsSummaryMessage != 0,
243 }, nil
244}
245
// partType is the discriminator written to the "type" field of each
// serialized content part by marshalParts and read back by unmarshalParts.
type partType string

// Discriminator values, one per ContentPart implementation (noted on the
// right). They end up in stored message rows, so renaming one breaks
// decoding of existing data.
const (
	reasoningType  partType = "reasoning"   // ReasoningContent
	textType       partType = "text"        // TextContent
	imageURLType   partType = "image_url"   // ImageURLContent
	binaryType     partType = "binary"      // BinaryContent
	toolCallType   partType = "tool_call"   // ToolCall
	toolResultType partType = "tool_result" // ToolResult
	finishType     partType = "finish"      // Finish
)
257
258type partWrapper struct {
259 Type partType `json:"type"`
260 Data ContentPart `json:"data"`
261}
262
263func marshalParts(parts []ContentPart) ([]byte, error) {
264 wrappedParts := make([]partWrapper, len(parts))
265
266 for i, part := range parts {
267 var typ partType
268
269 switch part.(type) {
270 case ReasoningContent:
271 typ = reasoningType
272 case TextContent:
273 typ = textType
274 case ImageURLContent:
275 typ = imageURLType
276 case BinaryContent:
277 typ = binaryType
278 case ToolCall:
279 typ = toolCallType
280 case ToolResult:
281 typ = toolResultType
282 case Finish:
283 typ = finishType
284 default:
285 return nil, fmt.Errorf("unknown part type: %T", part)
286 }
287
288 wrappedParts[i] = partWrapper{
289 Type: typ,
290 Data: part,
291 }
292 }
293 return json.Marshal(wrappedParts)
294}
295
296func unmarshalParts(data []byte) ([]ContentPart, error) {
297 temp := []json.RawMessage{}
298
299 if err := json.Unmarshal(data, &temp); err != nil {
300 return nil, err
301 }
302
303 parts := make([]ContentPart, 0)
304
305 for _, rawPart := range temp {
306 var wrapper struct {
307 Type partType `json:"type"`
308 Data json.RawMessage `json:"data"`
309 }
310
311 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
312 return nil, err
313 }
314
315 switch wrapper.Type {
316 case reasoningType:
317 part := ReasoningContent{}
318 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
319 return nil, err
320 }
321 parts = append(parts, part)
322 case textType:
323 part := TextContent{}
324 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
325 return nil, err
326 }
327 parts = append(parts, part)
328 case imageURLType:
329 part := ImageURLContent{}
330 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
331 return nil, err
332 }
333 parts = append(parts, part)
334 case binaryType:
335 part := BinaryContent{}
336 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
337 return nil, err
338 }
339 parts = append(parts, part)
340 case toolCallType:
341 part := ToolCall{}
342 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
343 return nil, err
344 }
345 parts = append(parts, part)
346 case toolResultType:
347 part := ToolResult{}
348 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
349 return nil, err
350 }
351 parts = append(parts, part)
352 case finishType:
353 part := Finish{}
354 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
355 return nil, err
356 }
357 parts = append(parts, part)
358 default:
359 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
360 }
361 }
362
363 return parts, nil
364}