1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// CreateMessageParams carries the caller-supplied fields for Service.Create.
type CreateMessageParams struct {
	// Role is stored as its string form. For every role other than Assistant,
	// Create appends a Finish part with reason "stop" to Parts.
	Role MessageRole
	// Parts is the message content; it is serialized to JSON before storage.
	Parts []ContentPart
	// Model is always stored as a valid (non-NULL) string, even when empty.
	Model string
	// Provider is stored as NULL when it is the empty string.
	Provider string
	// IsSummaryMessage is persisted as the integer 1 (true) or 0 (false).
	IsSummaryMessage bool
	// Metadata is serialized with encoding/json before storage.
	Metadata MessageMetadata
}
23
// Service is the message store API: create, read, update and delete
// operations on messages, plus the embedded pubsub subscriber interface for
// Message events.
type Service interface {
	// Suscriber (spelled as declared in the pubsub package) is embedded here;
	// presumably it lets callers subscribe to Message events — confirm against
	// the pubsub package.
	pubsub.Suscriber[Message]
	// Create persists a new message in the given session and returns the
	// stored message.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the parts, metadata and finish time of an existing
	// message.
	Update(ctx context.Context, message Message) error
	// Get loads a single message by id.
	Get(ctx context.Context, id string) (Message, error)
	// List loads the messages of a session.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes a single message by id.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes the messages of a session one by one.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
33
// service is the db.Querier-backed implementation of Service.
type service struct {
	// Broker is embedded so its methods (Publish is the one used in this
	// file) are promoted onto service.
	*pubsub.Broker[Message]
	// q executes the message queries.
	q db.Querier
}
38
39func NewService(q db.Querier) Service {
40 return &service{
41 Broker: pubsub.NewBroker[Message](),
42 q: q,
43 }
44}
45
46func (s *service) Delete(ctx context.Context, id string) error {
47 message, err := s.Get(ctx, id)
48 if err != nil {
49 return err
50 }
51 err = s.q.DeleteMessage(ctx, message.ID)
52 if err != nil {
53 return err
54 }
55 s.Publish(pubsub.DeletedEvent, message)
56 return nil
57}
58
59func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
60 if params.Role != Assistant {
61 params.Parts = append(params.Parts, Finish{
62 Reason: "stop",
63 })
64 }
65 partsJSON, err := marshallParts(params.Parts)
66 if err != nil {
67 return Message{}, err
68 }
69 metadataJSON, err := json.Marshal(params.Metadata)
70 if err != nil {
71 return Message{}, err
72 }
73 isSummary := int64(0)
74 if params.IsSummaryMessage {
75 isSummary = 1
76 }
77 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
78 ID: uuid.New().String(),
79 SessionID: sessionID,
80 Role: string(params.Role),
81 Parts: string(partsJSON),
82 Model: sql.NullString{String: string(params.Model), Valid: true},
83 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
84 IsSummaryMessage: isSummary,
85 Metadata: string(metadataJSON),
86 })
87 if err != nil {
88 return Message{}, err
89 }
90 message, err := s.fromDBItem(dbMessage)
91 if err != nil {
92 return Message{}, err
93 }
94 s.Publish(pubsub.CreatedEvent, message)
95 return message, nil
96}
97
98func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
99 messages, err := s.List(ctx, sessionID)
100 if err != nil {
101 return err
102 }
103 for _, message := range messages {
104 if message.SessionID == sessionID {
105 err = s.Delete(ctx, message.ID)
106 if err != nil {
107 return err
108 }
109 }
110 }
111 return nil
112}
113
114func (s *service) Update(ctx context.Context, message Message) error {
115 parts, err := marshallParts(message.Parts)
116 if err != nil {
117 return err
118 }
119 metadata, err := json.Marshal(message.Metadata)
120 if err != nil {
121 return err
122 }
123 finishedAt := sql.NullInt64{}
124 if f := message.FinishPart(); f != nil {
125 finishedAt.Int64 = f.Time
126 finishedAt.Valid = true
127 }
128 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
129 ID: message.ID,
130 Parts: string(parts),
131 Metadata: string(metadata),
132 FinishedAt: finishedAt,
133 })
134 if err != nil {
135 return err
136 }
137 message.UpdatedAt = time.Now().Unix()
138 s.Publish(pubsub.UpdatedEvent, message)
139 return nil
140}
141
142func (s *service) Get(ctx context.Context, id string) (Message, error) {
143 dbMessage, err := s.q.GetMessage(ctx, id)
144 if err != nil {
145 return Message{}, err
146 }
147 return s.fromDBItem(dbMessage)
148}
149
150func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
151 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
152 if err != nil {
153 return nil, err
154 }
155 messages := make([]Message, len(dbMessages))
156 for i, dbMessage := range dbMessages {
157 messages[i], err = s.fromDBItem(dbMessage)
158 if err != nil {
159 return nil, err
160 }
161 }
162 return messages, nil
163}
164
165func (s *service) fromDBItem(item db.Message) (Message, error) {
166 parts, err := unmarshallParts([]byte(item.Parts))
167 if err != nil {
168 return Message{}, err
169 }
170 var metadata MessageMetadata
171 if err := json.Unmarshal([]byte(item.Metadata), &metadata); err != nil {
172 return Message{}, err
173 }
174 return Message{
175 ID: item.ID,
176 SessionID: item.SessionID,
177 Role: MessageRole(item.Role),
178 Parts: parts,
179 Model: item.Model.String,
180 Provider: item.Provider.String,
181 CreatedAt: item.CreatedAt,
182 UpdatedAt: item.UpdatedAt,
183 IsSummaryMessage: item.IsSummaryMessage != 0,
184 Metadata: metadata,
185 }, nil
186}
187
// partType is the discriminator written to the "type" field of every
// serialized content part; unmarshallParts switches on it to pick the
// concrete type to decode into.
type partType string

// One tag per concrete ContentPart type handled by marshallParts and
// unmarshallParts. The string values are part of the stored JSON.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
199
// partWrapper is the JSON envelope marshallParts writes for each content
// part: the type tag next to the part's own encoding.
type partWrapper struct {
	// Type identifies the concrete type held in Data.
	Type partType `json:"type"`
	// Data is the part itself, encoded under the "data" key.
	Data ContentPart `json:"data"`
}
204
205func marshallParts(parts []ContentPart) ([]byte, error) {
206 wrappedParts := make([]partWrapper, len(parts))
207
208 for i, part := range parts {
209 var typ partType
210
211 switch part.(type) {
212 case ReasoningContent:
213 typ = reasoningType
214 case TextContent:
215 typ = textType
216 case ImageURLContent:
217 typ = imageURLType
218 case BinaryContent:
219 typ = binaryType
220 case ToolCall:
221 typ = toolCallType
222 case ToolResult:
223 typ = toolResultType
224 case Finish:
225 typ = finishType
226 default:
227 return nil, fmt.Errorf("unknown part type: %T", part)
228 }
229
230 wrappedParts[i] = partWrapper{
231 Type: typ,
232 Data: part,
233 }
234 }
235 return json.Marshal(wrappedParts)
236}
237
238func unmarshallParts(data []byte) ([]ContentPart, error) {
239 temp := []json.RawMessage{}
240
241 if err := json.Unmarshal(data, &temp); err != nil {
242 return nil, err
243 }
244
245 parts := make([]ContentPart, 0)
246
247 for _, rawPart := range temp {
248 var wrapper struct {
249 Type partType `json:"type"`
250 Data json.RawMessage `json:"data"`
251 }
252
253 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
254 return nil, err
255 }
256
257 switch wrapper.Type {
258 case reasoningType:
259 part := ReasoningContent{}
260 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
261 return nil, err
262 }
263 parts = append(parts, part)
264 case textType:
265 part := TextContent{}
266 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
267 return nil, err
268 }
269 parts = append(parts, part)
270 case imageURLType:
271 part := ImageURLContent{}
272 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
273 return nil, err
274 }
275 parts = append(parts, part)
276 case binaryType:
277 part := BinaryContent{}
278 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
279 return nil, err
280 }
281 parts = append(parts, part)
282 case toolCallType:
283 part := ToolCall{}
284 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
285 return nil, err
286 }
287 parts = append(parts, part)
288 case toolResultType:
289 part := ToolResult{}
290 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
291 return nil, err
292 }
293 parts = append(parts, part)
294 case finishType:
295 part := Finish{}
296 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
297 return nil, err
298 }
299 parts = append(parts, part)
300 default:
301 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
302 }
303 }
304
305 return parts, nil
306}