1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// CreateMessageParams carries the caller-supplied fields for Service.Create.
type CreateMessageParams struct {
	// Role is the author role of the message; it is stored in its string form.
	Role MessageRole
	// Parts is the message content; it is serialized to JSON for storage.
	Parts []ContentPart
	// Model is always stored as a non-NULL value, even when it is empty.
	Model string
	// Provider is stored as NULL when it is the empty string.
	Provider string
	// IsSummaryMessage is stored as 1 when true and 0 when false.
	IsSummaryMessage bool
}
22
// Service is the message store API. It combines create/read/update/delete
// operations on messages with the embedded pubsub.Subscriber[Message]
// interface.
type Service interface {
	pubsub.Subscriber[Message]
	// Create stores a new message in the session identified by sessionID and
	// returns the stored message.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the parts (and finish time, if any) of an existing
	// message identified by message.ID.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given id.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages of the session identified by sessionID.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// ListUserMessages returns the user messages of the session identified by
	// sessionID.
	ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
	// ListAllUserMessages returns user messages without restricting them to a
	// single session.
	ListAllUserMessages(ctx context.Context) ([]Message, error)
	// Delete removes the message with the given id.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes the messages of the session identified by
	// sessionID.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
34
// service is the Service implementation in this file. It persists messages
// through a db.Querier and publishes message events through the embedded
// broker.
type service struct {
	// Broker provides Publish, used to emit created/updated/deleted events.
	*pubsub.Broker[Message]
	// q executes the message queries.
	q db.Querier
}
39
40func NewService(q db.Querier) Service {
41 return &service{
42 Broker: pubsub.NewBroker[Message](),
43 q: q,
44 }
45}
46
47func (s *service) Delete(ctx context.Context, id string) error {
48 message, err := s.Get(ctx, id)
49 if err != nil {
50 return err
51 }
52 err = s.q.DeleteMessage(ctx, message.ID)
53 if err != nil {
54 return err
55 }
56 // Clone the message before publishing to avoid race conditions with
57 // concurrent modifications to the Parts slice.
58 s.Publish(pubsub.DeletedEvent, message.Clone())
59 return nil
60}
61
62func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
63 if params.Role != Assistant {
64 params.Parts = append(params.Parts, Finish{
65 Reason: "stop",
66 })
67 }
68 partsJSON, err := marshalParts(params.Parts)
69 if err != nil {
70 return Message{}, err
71 }
72 isSummary := int64(0)
73 if params.IsSummaryMessage {
74 isSummary = 1
75 }
76 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
77 ID: uuid.New().String(),
78 SessionID: sessionID,
79 Role: string(params.Role),
80 Parts: string(partsJSON),
81 Model: sql.NullString{String: string(params.Model), Valid: true},
82 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
83 IsSummaryMessage: isSummary,
84 })
85 if err != nil {
86 return Message{}, err
87 }
88 message, err := s.fromDBItem(dbMessage)
89 if err != nil {
90 return Message{}, err
91 }
92 // Clone the message before publishing to avoid race conditions with
93 // concurrent modifications to the Parts slice.
94 s.Publish(pubsub.CreatedEvent, message.Clone())
95 return message, nil
96}
97
98func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
99 messages, err := s.List(ctx, sessionID)
100 if err != nil {
101 return err
102 }
103 for _, message := range messages {
104 if message.SessionID == sessionID {
105 err = s.Delete(ctx, message.ID)
106 if err != nil {
107 return err
108 }
109 }
110 }
111 return nil
112}
113
114func (s *service) Update(ctx context.Context, message Message) error {
115 parts, err := marshalParts(message.Parts)
116 if err != nil {
117 return err
118 }
119 finishedAt := sql.NullInt64{}
120 if f := message.FinishPart(); f != nil {
121 finishedAt.Int64 = f.Time
122 finishedAt.Valid = true
123 }
124 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
125 ID: message.ID,
126 Parts: string(parts),
127 FinishedAt: finishedAt,
128 })
129 if err != nil {
130 return err
131 }
132 message.UpdatedAt = time.Now().Unix()
133 // Clone the message before publishing to avoid race conditions with
134 // concurrent modifications to the Parts slice.
135 s.Publish(pubsub.UpdatedEvent, message.Clone())
136 return nil
137}
138
139func (s *service) Get(ctx context.Context, id string) (Message, error) {
140 dbMessage, err := s.q.GetMessage(ctx, id)
141 if err != nil {
142 return Message{}, err
143 }
144 return s.fromDBItem(dbMessage)
145}
146
147func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
148 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
149 if err != nil {
150 return nil, err
151 }
152 messages := make([]Message, len(dbMessages))
153 for i, dbMessage := range dbMessages {
154 messages[i], err = s.fromDBItem(dbMessage)
155 if err != nil {
156 return nil, err
157 }
158 }
159 return messages, nil
160}
161
162func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
163 dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
164 if err != nil {
165 return nil, err
166 }
167 messages := make([]Message, len(dbMessages))
168 for i, dbMessage := range dbMessages {
169 messages[i], err = s.fromDBItem(dbMessage)
170 if err != nil {
171 return nil, err
172 }
173 }
174 return messages, nil
175}
176
177func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
178 dbMessages, err := s.q.ListAllUserMessages(ctx)
179 if err != nil {
180 return nil, err
181 }
182 messages := make([]Message, len(dbMessages))
183 for i, dbMessage := range dbMessages {
184 messages[i], err = s.fromDBItem(dbMessage)
185 if err != nil {
186 return nil, err
187 }
188 }
189 return messages, nil
190}
191
192func (s *service) fromDBItem(item db.Message) (Message, error) {
193 parts, err := unmarshalParts([]byte(item.Parts))
194 if err != nil {
195 return Message{}, err
196 }
197 return Message{
198 ID: item.ID,
199 SessionID: item.SessionID,
200 Role: MessageRole(item.Role),
201 Parts: parts,
202 Model: item.Model.String,
203 Provider: item.Provider.String,
204 CreatedAt: item.CreatedAt,
205 UpdatedAt: item.UpdatedAt,
206 IsSummaryMessage: item.IsSummaryMessage != 0,
207 }, nil
208}
209
// partType is the discriminator written to the "type" field of every
// serialized part; unmarshalParts uses it to pick the concrete Go type to
// decode into.
type partType string

// Discriminator values, one per concrete ContentPart type handled by
// marshalParts and unmarshalParts.
const (
	reasoningType  partType = "reasoning"   // ReasoningContent
	textType       partType = "text"        // TextContent
	imageURLType   partType = "image_url"   // ImageURLContent
	binaryType     partType = "binary"      // BinaryContent
	toolCallType   partType = "tool_call"   // ToolCall
	toolResultType partType = "tool_result" // ToolResult
	finishType     partType = "finish"      // Finish
)
221
// partWrapper is the JSON envelope marshalParts writes for each part: the
// discriminator under "type" and the part itself under "data".
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
226
227func marshalParts(parts []ContentPart) ([]byte, error) {
228 wrappedParts := make([]partWrapper, len(parts))
229
230 for i, part := range parts {
231 var typ partType
232
233 switch part.(type) {
234 case ReasoningContent:
235 typ = reasoningType
236 case TextContent:
237 typ = textType
238 case ImageURLContent:
239 typ = imageURLType
240 case BinaryContent:
241 typ = binaryType
242 case ToolCall:
243 typ = toolCallType
244 case ToolResult:
245 typ = toolResultType
246 case Finish:
247 typ = finishType
248 default:
249 return nil, fmt.Errorf("unknown part type: %T", part)
250 }
251
252 wrappedParts[i] = partWrapper{
253 Type: typ,
254 Data: part,
255 }
256 }
257 return json.Marshal(wrappedParts)
258}
259
260func unmarshalParts(data []byte) ([]ContentPart, error) {
261 temp := []json.RawMessage{}
262
263 if err := json.Unmarshal(data, &temp); err != nil {
264 return nil, err
265 }
266
267 parts := make([]ContentPart, 0)
268
269 for _, rawPart := range temp {
270 var wrapper struct {
271 Type partType `json:"type"`
272 Data json.RawMessage `json:"data"`
273 }
274
275 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
276 return nil, err
277 }
278
279 switch wrapper.Type {
280 case reasoningType:
281 part := ReasoningContent{}
282 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
283 return nil, err
284 }
285 parts = append(parts, part)
286 case textType:
287 part := TextContent{}
288 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
289 return nil, err
290 }
291 parts = append(parts, part)
292 case imageURLType:
293 part := ImageURLContent{}
294 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
295 return nil, err
296 }
297 parts = append(parts, part)
298 case binaryType:
299 part := BinaryContent{}
300 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
301 return nil, err
302 }
303 parts = append(parts, part)
304 case toolCallType:
305 part := ToolCall{}
306 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
307 return nil, err
308 }
309 parts = append(parts, part)
310 case toolResultType:
311 part := ToolResult{}
312 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
313 return nil, err
314 }
315 parts = append(parts, part)
316 case finishType:
317 part := Finish{}
318 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
319 return nil, err
320 }
321 parts = append(parts, part)
322 default:
323 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
324 }
325 }
326
327 return parts, nil
328}