1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
15type CreateMessageParams struct {
16 Role MessageRole
17 Parts []ContentPart
18 Model string
19 Provider string
20 IsSummaryMessage bool
21}
22
23type Service interface {
24 pubsub.Suscriber[Message]
25 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
26 Update(ctx context.Context, message Message) error
27 Get(ctx context.Context, id string) (Message, error)
28 List(ctx context.Context, sessionID string) ([]Message, error)
29 FullList(ctx context.Context) ([]Message, error)
30 Delete(ctx context.Context, id string) error
31 DeleteSessionMessages(ctx context.Context, sessionID string) error
32}
33
34type service struct {
35 *pubsub.Broker[Message]
36 q db.Querier
37}
38
39func NewService(q db.Querier) Service {
40 return &service{
41 Broker: pubsub.NewBroker[Message](),
42 q: q,
43 }
44}
45
46func (s *service) Delete(ctx context.Context, id string) error {
47 message, err := s.Get(ctx, id)
48 if err != nil {
49 return err
50 }
51 err = s.q.DeleteMessage(ctx, message.ID)
52 if err != nil {
53 return err
54 }
55 s.Publish(pubsub.DeletedEvent, message)
56 return nil
57}
58
59func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
60 if params.Role != Assistant {
61 params.Parts = append(params.Parts, Finish{
62 Reason: "stop",
63 })
64 }
65 partsJSON, err := marshallParts(params.Parts)
66 if err != nil {
67 return Message{}, err
68 }
69 isSummary := int64(0)
70 if params.IsSummaryMessage {
71 isSummary = 1
72 }
73 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
74 ID: uuid.New().String(),
75 SessionID: sessionID,
76 Role: string(params.Role),
77 Parts: string(partsJSON),
78 Model: sql.NullString{String: string(params.Model), Valid: true},
79 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
80 IsSummaryMessage: isSummary,
81 })
82 if err != nil {
83 return Message{}, err
84 }
85 message, err := s.fromDBItem(dbMessage)
86 if err != nil {
87 return Message{}, err
88 }
89 s.Publish(pubsub.CreatedEvent, message)
90 return message, nil
91}
92
93func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
94 messages, err := s.List(ctx, sessionID)
95 if err != nil {
96 return err
97 }
98 for _, message := range messages {
99 if message.SessionID == sessionID {
100 err = s.Delete(ctx, message.ID)
101 if err != nil {
102 return err
103 }
104 }
105 }
106 return nil
107}
108
109func (s *service) Update(ctx context.Context, message Message) error {
110 parts, err := marshallParts(message.Parts)
111 if err != nil {
112 return err
113 }
114 finishedAt := sql.NullInt64{}
115 if f := message.FinishPart(); f != nil {
116 finishedAt.Int64 = f.Time
117 finishedAt.Valid = true
118 }
119 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
120 ID: message.ID,
121 Parts: string(parts),
122 FinishedAt: finishedAt,
123 })
124 if err != nil {
125 return err
126 }
127 message.UpdatedAt = time.Now().Unix()
128 s.Publish(pubsub.UpdatedEvent, message)
129 return nil
130}
131
132func (s *service) Get(ctx context.Context, id string) (Message, error) {
133 dbMessage, err := s.q.GetMessage(ctx, id)
134 if err != nil {
135 return Message{}, err
136 }
137 return s.fromDBItem(dbMessage)
138}
139
140func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
141 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
142 if err != nil {
143 return nil, err
144 }
145 messages := make([]Message, len(dbMessages))
146 for i, dbMessage := range dbMessages {
147 messages[i], err = s.fromDBItem(dbMessage)
148 if err != nil {
149 return nil, err
150 }
151 }
152 return messages, nil
153}
154
155func (s *service) FullList(ctx context.Context) ([]Message, error) {
156 dbMessages, err := s.q.ListAllMessages(ctx)
157 return nil, nil
158}
159
160func (s *service) fromDBItem(item db.Message) (Message, error) {
161 parts, err := unmarshallParts([]byte(item.Parts))
162 if err != nil {
163 return Message{}, err
164 }
165 return Message{
166 ID: item.ID,
167 SessionID: item.SessionID,
168 Role: MessageRole(item.Role),
169 Parts: parts,
170 Model: item.Model.String,
171 Provider: item.Provider.String,
172 CreatedAt: item.CreatedAt,
173 UpdatedAt: item.UpdatedAt,
174 IsSummaryMessage: item.IsSummaryMessage != 0,
175 }, nil
176}
177
// partType is the discriminator written to the "type" field of each
// serialized content part.
type partType string

// Discriminator values, one per concrete ContentPart handled by
// marshallParts and unmarshallParts. The strings are part of the stored
// format.
const (
	reasoningType partType = "reasoning"
	textType      partType = "text"

	imageURLType partType = "image_url"
	binaryType   partType = "binary"

	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"

	finishType partType = "finish"
)
189
190type partWrapper struct {
191 Type partType `json:"type"`
192 Data ContentPart `json:"data"`
193}
194
195func marshallParts(parts []ContentPart) ([]byte, error) {
196 wrappedParts := make([]partWrapper, len(parts))
197
198 for i, part := range parts {
199 var typ partType
200
201 switch part.(type) {
202 case ReasoningContent:
203 typ = reasoningType
204 case TextContent:
205 typ = textType
206 case ImageURLContent:
207 typ = imageURLType
208 case BinaryContent:
209 typ = binaryType
210 case ToolCall:
211 typ = toolCallType
212 case ToolResult:
213 typ = toolResultType
214 case Finish:
215 typ = finishType
216 default:
217 return nil, fmt.Errorf("unknown part type: %T", part)
218 }
219
220 wrappedParts[i] = partWrapper{
221 Type: typ,
222 Data: part,
223 }
224 }
225 return json.Marshal(wrappedParts)
226}
227
228func unmarshallParts(data []byte) ([]ContentPart, error) {
229 temp := []json.RawMessage{}
230
231 if err := json.Unmarshal(data, &temp); err != nil {
232 return nil, err
233 }
234
235 parts := make([]ContentPart, 0)
236
237 for _, rawPart := range temp {
238 var wrapper struct {
239 Type partType `json:"type"`
240 Data json.RawMessage `json:"data"`
241 }
242
243 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
244 return nil, err
245 }
246
247 switch wrapper.Type {
248 case reasoningType:
249 part := ReasoningContent{}
250 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
251 return nil, err
252 }
253 parts = append(parts, part)
254 case textType:
255 part := TextContent{}
256 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
257 return nil, err
258 }
259 parts = append(parts, part)
260 case imageURLType:
261 part := ImageURLContent{}
262 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
263 return nil, err
264 }
265 parts = append(parts, part)
266 case binaryType:
267 part := BinaryContent{}
268 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
269 return nil, err
270 }
271 parts = append(parts, part)
272 case toolCallType:
273 part := ToolCall{}
274 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
275 return nil, err
276 }
277 parts = append(parts, part)
278 case toolResultType:
279 part := ToolResult{}
280 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
281 return nil, err
282 }
283 parts = append(parts, part)
284 case finishType:
285 part := Finish{}
286 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
287 return nil, err
288 }
289 parts = append(parts, part)
290 default:
291 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
292 }
293 }
294
295 return parts, nil
296}