1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
15type CreateMessageParams struct {
16 Role MessageRole
17 Parts []ContentPart
18 Model string
19 Provider string
20 IsSummaryMessage bool
21}
22
23type Service interface {
24 pubsub.Suscriber[Message]
25 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
26 Update(ctx context.Context, message Message) error
27 Get(ctx context.Context, id string) (Message, error)
28 List(ctx context.Context, sessionID string) ([]Message, error)
29 FullList(ctx context.Context) ([]Message, error)
30 Delete(ctx context.Context, id string) error
31 DeleteSessionMessages(ctx context.Context, sessionID string) error
32}
33
34type service struct {
35 *pubsub.Broker[Message]
36 q db.Querier
37}
38
39func NewService(q db.Querier) Service {
40 return &service{
41 Broker: pubsub.NewBroker[Message](),
42 q: q,
43 }
44}
45
46func (s *service) Delete(ctx context.Context, id string) error {
47 message, err := s.Get(ctx, id)
48 if err != nil {
49 return err
50 }
51 err = s.q.DeleteMessage(ctx, message.ID)
52 if err != nil {
53 return err
54 }
55 s.Publish(pubsub.DeletedEvent, message)
56 return nil
57}
58
59func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
60 if params.Role != Assistant {
61 params.Parts = append(params.Parts, Finish{
62 Reason: "stop",
63 })
64 }
65 partsJSON, err := marshallParts(params.Parts)
66 if err != nil {
67 return Message{}, err
68 }
69 isSummary := int64(0)
70 if params.IsSummaryMessage {
71 isSummary = 1
72 }
73 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
74 ID: uuid.New().String(),
75 SessionID: sessionID,
76 Role: string(params.Role),
77 Parts: string(partsJSON),
78 Model: sql.NullString{String: string(params.Model), Valid: true},
79 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
80 IsSummaryMessage: isSummary,
81 })
82 if err != nil {
83 return Message{}, err
84 }
85 message, err := fromDBItem(dbMessage)
86 if err != nil {
87 return Message{}, err
88 }
89 s.Publish(pubsub.CreatedEvent, message)
90 return message, nil
91}
92
93func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
94 messages, err := s.List(ctx, sessionID)
95 if err != nil {
96 return err
97 }
98 for _, message := range messages {
99 if message.SessionID == sessionID {
100 err = s.Delete(ctx, message.ID)
101 if err != nil {
102 return err
103 }
104 }
105 }
106 return nil
107}
108
109func (s *service) Update(ctx context.Context, message Message) error {
110 parts, err := marshallParts(message.Parts)
111 if err != nil {
112 return err
113 }
114 finishedAt := sql.NullInt64{}
115 if f := message.FinishPart(); f != nil {
116 finishedAt.Int64 = f.Time
117 finishedAt.Valid = true
118 }
119 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
120 ID: message.ID,
121 Parts: string(parts),
122 FinishedAt: finishedAt,
123 })
124 if err != nil {
125 return err
126 }
127 message.UpdatedAt = time.Now().Unix()
128 s.Publish(pubsub.UpdatedEvent, message)
129 return nil
130}
131
132func (s *service) Get(ctx context.Context, id string) (Message, error) {
133 dbMessage, err := s.q.GetMessage(ctx, id)
134 if err != nil {
135 return Message{}, err
136 }
137 return fromDBItem(dbMessage)
138}
139
140func convertDBMessagesToMessages(dbMessages []db.Message, err error) ([]Message, error) {
141 if err != nil {
142 return nil, err
143 }
144 messages := make([]Message, len(dbMessages))
145 for i, dbMessage := range dbMessages {
146 messages[i], err = fromDBItem(dbMessage)
147 if err != nil {
148 return nil, err
149 }
150 }
151 return messages, nil
152}
153
154func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
155 return convertDBMessagesToMessages(s.q.ListMessagesBySession(ctx, sessionID))
156}
157
158func (s *service) FullList(ctx context.Context) ([]Message, error) {
159 return convertDBMessagesToMessages(s.q.ListAllMessages(ctx))
160}
161
162func fromDBItem(item db.Message) (Message, error) {
163 parts, err := unmarshallParts([]byte(item.Parts))
164 if err != nil {
165 return Message{}, err
166 }
167 return Message{
168 ID: item.ID,
169 SessionID: item.SessionID,
170 Role: MessageRole(item.Role),
171 Parts: parts,
172 Model: item.Model.String,
173 Provider: item.Provider.String,
174 CreatedAt: item.CreatedAt,
175 UpdatedAt: item.UpdatedAt,
176 IsSummaryMessage: item.IsSummaryMessage != 0,
177 }, nil
178}
179
// partType is the value of the "type" field that tags each serialized
// content part with the concrete type of its payload.
type partType string
181
182const (
183 reasoningType partType = "reasoning"
184 textType partType = "text"
185 imageURLType partType = "image_url"
186 binaryType partType = "binary"
187 toolCallType partType = "tool_call"
188 toolResultType partType = "tool_result"
189 finishType partType = "finish"
190)
191
192type partWrapper struct {
193 Type partType `json:"type"`
194 Data ContentPart `json:"data"`
195}
196
197func marshallParts(parts []ContentPart) ([]byte, error) {
198 wrappedParts := make([]partWrapper, len(parts))
199
200 for i, part := range parts {
201 var typ partType
202
203 switch part.(type) {
204 case ReasoningContent:
205 typ = reasoningType
206 case TextContent:
207 typ = textType
208 case ImageURLContent:
209 typ = imageURLType
210 case BinaryContent:
211 typ = binaryType
212 case ToolCall:
213 typ = toolCallType
214 case ToolResult:
215 typ = toolResultType
216 case Finish:
217 typ = finishType
218 default:
219 return nil, fmt.Errorf("unknown part type: %T", part)
220 }
221
222 wrappedParts[i] = partWrapper{
223 Type: typ,
224 Data: part,
225 }
226 }
227 return json.Marshal(wrappedParts)
228}
229
230func unmarshallParts(data []byte) ([]ContentPart, error) {
231 temp := []json.RawMessage{}
232
233 if err := json.Unmarshal(data, &temp); err != nil {
234 return nil, err
235 }
236
237 parts := make([]ContentPart, 0)
238
239 for _, rawPart := range temp {
240 var wrapper struct {
241 Type partType `json:"type"`
242 Data json.RawMessage `json:"data"`
243 }
244
245 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
246 return nil, err
247 }
248
249 switch wrapper.Type {
250 case reasoningType:
251 part := ReasoningContent{}
252 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
253 return nil, err
254 }
255 parts = append(parts, part)
256 case textType:
257 part := TextContent{}
258 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
259 return nil, err
260 }
261 parts = append(parts, part)
262 case imageURLType:
263 part := ImageURLContent{}
264 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
265 return nil, err
266 }
267 parts = append(parts, part)
268 case binaryType:
269 part := BinaryContent{}
270 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
271 return nil, err
272 }
273 parts = append(parts, part)
274 case toolCallType:
275 part := ToolCall{}
276 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
277 return nil, err
278 }
279 parts = append(parts, part)
280 case toolResultType:
281 part := ToolResult{}
282 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
283 return nil, err
284 }
285 parts = append(parts, part)
286 case finishType:
287 part := Finish{}
288 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
289 return nil, err
290 }
291 parts = append(parts, part)
292 default:
293 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
294 }
295 }
296
297 return parts, nil
298}