1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// CreateMessageParams carries the caller-chosen fields of a new message
// passed to Service.Create:
//
//   - Role and Parts form the message content. Create appends a Finish part
//     with Reason "stop" for every Role other than Assistant.
//   - Model and Provider are stored alongside the message. Model is always
//     written as a non-NULL value (possibly empty); an empty Provider is
//     written as NULL.
//   - IsSummaryMessage is stored as the integer 1 when true and 0 otherwise.
type CreateMessageParams struct {
	Role             MessageRole
	Parts            []ContentPart
	Model            string
	Provider         string
	IsSummaryMessage bool
}
22
23type Service interface {
24 pubsub.Subscriber[Message]
25 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
26 Update(ctx context.Context, message Message) error
27 Get(ctx context.Context, id string) (Message, error)
28 List(ctx context.Context, sessionID string) ([]Message, error)
29 Delete(ctx context.Context, id string) error
30 DeleteSessionMessages(ctx context.Context, sessionID string) error
31}
32
// service is the Service implementation in this file. Rows are read and
// written through the db.Querier q. The embedded Broker is used to Publish
// created, updated and deleted messages; it presumably also supplies the
// pubsub.Subscriber methods that Service requires.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
}
37
38func NewService(q db.Querier) Service {
39 return &service{
40 Broker: pubsub.NewBroker[Message](),
41 q: q,
42 }
43}
44
45func (s *service) Delete(ctx context.Context, id string) error {
46 message, err := s.Get(ctx, id)
47 if err != nil {
48 return err
49 }
50 err = s.q.DeleteMessage(ctx, message.ID)
51 if err != nil {
52 return err
53 }
54 // Clone the message before publishing to avoid race conditions with
55 // concurrent modifications to the Parts slice.
56 s.Publish(pubsub.DeletedEvent, message.Clone())
57 return nil
58}
59
60func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
61 if params.Role != Assistant {
62 params.Parts = append(params.Parts, Finish{
63 Reason: "stop",
64 })
65 }
66 partsJSON, err := marshallParts(params.Parts)
67 if err != nil {
68 return Message{}, err
69 }
70 isSummary := int64(0)
71 if params.IsSummaryMessage {
72 isSummary = 1
73 }
74 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
75 ID: uuid.New().String(),
76 SessionID: sessionID,
77 Role: string(params.Role),
78 Parts: string(partsJSON),
79 Model: sql.NullString{String: string(params.Model), Valid: true},
80 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
81 IsSummaryMessage: isSummary,
82 })
83 if err != nil {
84 return Message{}, err
85 }
86 message, err := s.fromDBItem(dbMessage)
87 if err != nil {
88 return Message{}, err
89 }
90 // Clone the message before publishing to avoid race conditions with
91 // concurrent modifications to the Parts slice.
92 s.Publish(pubsub.CreatedEvent, message.Clone())
93 return message, nil
94}
95
96func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
97 messages, err := s.List(ctx, sessionID)
98 if err != nil {
99 return err
100 }
101 for _, message := range messages {
102 if message.SessionID == sessionID {
103 err = s.Delete(ctx, message.ID)
104 if err != nil {
105 return err
106 }
107 }
108 }
109 return nil
110}
111
112func (s *service) Update(ctx context.Context, message Message) error {
113 parts, err := marshallParts(message.Parts)
114 if err != nil {
115 return err
116 }
117 finishedAt := sql.NullInt64{}
118 if f := message.FinishPart(); f != nil {
119 finishedAt.Int64 = f.Time
120 finishedAt.Valid = true
121 }
122 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
123 ID: message.ID,
124 Parts: string(parts),
125 FinishedAt: finishedAt,
126 })
127 if err != nil {
128 return err
129 }
130 message.UpdatedAt = time.Now().Unix()
131 // Clone the message before publishing to avoid race conditions with
132 // concurrent modifications to the Parts slice.
133 s.Publish(pubsub.UpdatedEvent, message.Clone())
134 return nil
135}
136
137func (s *service) Get(ctx context.Context, id string) (Message, error) {
138 dbMessage, err := s.q.GetMessage(ctx, id)
139 if err != nil {
140 return Message{}, err
141 }
142 return s.fromDBItem(dbMessage)
143}
144
145func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
146 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
147 if err != nil {
148 return nil, err
149 }
150 messages := make([]Message, len(dbMessages))
151 for i, dbMessage := range dbMessages {
152 messages[i], err = s.fromDBItem(dbMessage)
153 if err != nil {
154 return nil, err
155 }
156 }
157 return messages, nil
158}
159
160func (s *service) fromDBItem(item db.Message) (Message, error) {
161 parts, err := unmarshallParts([]byte(item.Parts))
162 if err != nil {
163 return Message{}, err
164 }
165 return Message{
166 ID: item.ID,
167 SessionID: item.SessionID,
168 Role: MessageRole(item.Role),
169 Parts: parts,
170 Model: item.Model.String,
171 Provider: item.Provider.String,
172 CreatedAt: item.CreatedAt,
173 UpdatedAt: item.UpdatedAt,
174 IsSummaryMessage: item.IsSummaryMessage != 0,
175 }, nil
176}
177
// partType is the discriminator written to the "type" field of each encoded
// part. unmarshallParts uses it to choose which concrete ContentPart to
// decode.
type partType string

// Tags for every ContentPart implementation handled by marshallParts and
// unmarshallParts. These strings end up in persisted message parts, and
// unmarshallParts rejects unknown tags. Renaming one would therefore make
// previously stored messages undecodable.
const (
	// Content parts.
	textType      partType = "text"
	reasoningType partType = "reasoning"
	imageURLType  partType = "image_url"
	binaryType    partType = "binary"

	// Tool interaction parts.
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"

	// Finish part.
	finishType partType = "finish"
)
189
// partWrapper is the JSON envelope for one ContentPart, encoded as
// {"type": <partType>, "data": <part>}. marshallParts produces a JSON array
// of these, and that array is what gets stored as a message's Parts. Field
// order fixes the key order of the encoded JSON.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
194
195func marshallParts(parts []ContentPart) ([]byte, error) {
196 wrappedParts := make([]partWrapper, len(parts))
197
198 for i, part := range parts {
199 var typ partType
200
201 switch part.(type) {
202 case ReasoningContent:
203 typ = reasoningType
204 case TextContent:
205 typ = textType
206 case ImageURLContent:
207 typ = imageURLType
208 case BinaryContent:
209 typ = binaryType
210 case ToolCall:
211 typ = toolCallType
212 case ToolResult:
213 typ = toolResultType
214 case Finish:
215 typ = finishType
216 default:
217 return nil, fmt.Errorf("unknown part type: %T", part)
218 }
219
220 wrappedParts[i] = partWrapper{
221 Type: typ,
222 Data: part,
223 }
224 }
225 return json.Marshal(wrappedParts)
226}
227
228func unmarshallParts(data []byte) ([]ContentPart, error) {
229 temp := []json.RawMessage{}
230
231 if err := json.Unmarshal(data, &temp); err != nil {
232 return nil, err
233 }
234
235 parts := make([]ContentPart, 0)
236
237 for _, rawPart := range temp {
238 var wrapper struct {
239 Type partType `json:"type"`
240 Data json.RawMessage `json:"data"`
241 }
242
243 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
244 return nil, err
245 }
246
247 switch wrapper.Type {
248 case reasoningType:
249 part := ReasoningContent{}
250 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
251 return nil, err
252 }
253 parts = append(parts, part)
254 case textType:
255 part := TextContent{}
256 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
257 return nil, err
258 }
259 parts = append(parts, part)
260 case imageURLType:
261 part := ImageURLContent{}
262 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
263 return nil, err
264 }
265 parts = append(parts, part)
266 case binaryType:
267 part := BinaryContent{}
268 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
269 return nil, err
270 }
271 parts = append(parts, part)
272 case toolCallType:
273 part := ToolCall{}
274 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
275 return nil, err
276 }
277 parts = append(parts, part)
278 case toolResultType:
279 part := ToolResult{}
280 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
281 return nil, err
282 }
283 parts = append(parts, part)
284 case finishType:
285 part := Finish{}
286 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
287 return nil, err
288 }
289 parts = append(parts, part)
290 default:
291 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
292 }
293 }
294
295 return parts, nil
296}