1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// CreateMessageParams holds the caller-supplied fields for Service.Create.
// The message ID and session are not part of it: this package's Create
// generates the ID itself and takes the session ID as a separate argument.
type CreateMessageParams struct {
	// Role of the message author. In this package's Create, any role other
	// than Assistant gets a Finish part (reason "stop") appended to Parts.
	Role MessageRole
	// Parts is the message content; it is stored JSON-encoded.
	Parts []ContentPart
	// Model is recorded with the message.
	Model string
	// Provider is recorded with the message; an empty value is stored as NULL.
	Provider string
	// IsSummaryMessage marks the message as a summary; persisted as 0/1.
	IsSummaryMessage bool
	// HookOutputs are stored JSON-encoded alongside the message.
	HookOutputs []HookOutput
}
23
// Service manages persisted chat messages. Through the embedded
// pubsub.Suscriber[Message] it also lets callers observe changes; this
// package's implementation publishes created/updated/deleted events from the
// corresponding methods.
type Service interface {
	pubsub.Suscriber[Message]
	// Create stores a new message in the given session and returns it as stored.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the message's parts, finish time and hook outputs.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages belonging to a session.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes the message with the given ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message belonging to the session.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
33
// service is the Service implementation backed by a db.Querier. The embedded
// broker supplies Publish (called by the mutating methods) and, presumably,
// the Suscriber methods that Service requires.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
}
38
39func NewService(q db.Querier) Service {
40 return &service{
41 Broker: pubsub.NewBroker[Message](),
42 q: q,
43 }
44}
45
46func (s *service) Delete(ctx context.Context, id string) error {
47 message, err := s.Get(ctx, id)
48 if err != nil {
49 return err
50 }
51 err = s.q.DeleteMessage(ctx, message.ID)
52 if err != nil {
53 return err
54 }
55 s.Publish(pubsub.DeletedEvent, message)
56 return nil
57}
58
59func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
60 if params.Role != Assistant {
61 params.Parts = append(params.Parts, Finish{
62 Reason: "stop",
63 })
64 }
65 partsJSON, err := marshallParts(params.Parts)
66 if err != nil {
67 return Message{}, err
68 }
69 isSummary := int64(0)
70 if params.IsSummaryMessage {
71 isSummary = 1
72 }
73 hookOutputsJSON, err := json.Marshal(params.HookOutputs)
74 if err != nil {
75 return Message{}, err
76 }
77 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
78 ID: uuid.New().String(),
79 SessionID: sessionID,
80 Role: string(params.Role),
81 Parts: string(partsJSON),
82 Model: sql.NullString{String: string(params.Model), Valid: true},
83 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
84 IsSummaryMessage: isSummary,
85 HookOutputs: string(hookOutputsJSON),
86 })
87 if err != nil {
88 return Message{}, err
89 }
90 message, err := s.fromDBItem(dbMessage)
91 if err != nil {
92 return Message{}, err
93 }
94 s.Publish(pubsub.CreatedEvent, message)
95 return message, nil
96}
97
98func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
99 messages, err := s.List(ctx, sessionID)
100 if err != nil {
101 return err
102 }
103 for _, message := range messages {
104 if message.SessionID == sessionID {
105 err = s.Delete(ctx, message.ID)
106 if err != nil {
107 return err
108 }
109 }
110 }
111 return nil
112}
113
114func (s *service) Update(ctx context.Context, message Message) error {
115 parts, err := marshallParts(message.Parts)
116 if err != nil {
117 return err
118 }
119 finishedAt := sql.NullInt64{}
120 if f := message.FinishPart(); f != nil {
121 finishedAt.Int64 = f.Time
122 finishedAt.Valid = true
123 }
124 hookOutputsJSON, err := json.Marshal(message.HookOutputs)
125 if err != nil {
126 return err
127 }
128 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
129 ID: message.ID,
130 Parts: string(parts),
131 FinishedAt: finishedAt,
132 HookOutputs: string(hookOutputsJSON),
133 })
134 if err != nil {
135 return err
136 }
137 message.UpdatedAt = time.Now().Unix()
138 s.Publish(pubsub.UpdatedEvent, message)
139 return nil
140}
141
142func (s *service) Get(ctx context.Context, id string) (Message, error) {
143 dbMessage, err := s.q.GetMessage(ctx, id)
144 if err != nil {
145 return Message{}, err
146 }
147 return s.fromDBItem(dbMessage)
148}
149
150func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
151 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
152 if err != nil {
153 return nil, err
154 }
155 messages := make([]Message, len(dbMessages))
156 for i, dbMessage := range dbMessages {
157 messages[i], err = s.fromDBItem(dbMessage)
158 if err != nil {
159 return nil, err
160 }
161 }
162 return messages, nil
163}
164
165func (s *service) fromDBItem(item db.Message) (Message, error) {
166 parts, err := unmarshallParts([]byte(item.Parts))
167 if err != nil {
168 return Message{}, err
169 }
170 var hookOutputs []HookOutput
171 if item.HookOutputs != "" {
172 if err := json.Unmarshal([]byte(item.HookOutputs), &hookOutputs); err != nil {
173 return Message{}, err
174 }
175 }
176 return Message{
177 ID: item.ID,
178 SessionID: item.SessionID,
179 Role: MessageRole(item.Role),
180 Parts: parts,
181 Model: item.Model.String,
182 Provider: item.Provider.String,
183 CreatedAt: item.CreatedAt,
184 UpdatedAt: item.UpdatedAt,
185 IsSummaryMessage: item.IsSummaryMessage != 0,
186 HookOutputs: hookOutputs,
187 }, nil
188}
189
// partType is the discriminator stored in each serialized part envelope
// (partWrapper.Type) to identify the concrete ContentPart type.
type partType string
191
// Known part discriminators. marshallParts writes these into stored message
// JSON and unmarshallParts rejects any other value. Changing a value would make
// previously stored parts fail to decode with "unknown part type".
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
201
// partWrapper is the JSON envelope for one ContentPart, encoded as
// {"type": ..., "data": ...}. Type records the concrete part kind so the
// decoder can choose the right Go type for Data.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
206
207func marshallParts(parts []ContentPart) ([]byte, error) {
208 wrappedParts := make([]partWrapper, len(parts))
209
210 for i, part := range parts {
211 var typ partType
212
213 switch part.(type) {
214 case ReasoningContent:
215 typ = reasoningType
216 case TextContent:
217 typ = textType
218 case ImageURLContent:
219 typ = imageURLType
220 case BinaryContent:
221 typ = binaryType
222 case ToolCall:
223 typ = toolCallType
224 case ToolResult:
225 typ = toolResultType
226 case Finish:
227 typ = finishType
228 default:
229 return nil, fmt.Errorf("unknown part type: %T", part)
230 }
231
232 wrappedParts[i] = partWrapper{
233 Type: typ,
234 Data: part,
235 }
236 }
237 return json.Marshal(wrappedParts)
238}
239
240func unmarshallParts(data []byte) ([]ContentPart, error) {
241 temp := []json.RawMessage{}
242
243 if err := json.Unmarshal(data, &temp); err != nil {
244 return nil, err
245 }
246
247 parts := make([]ContentPart, 0)
248
249 for _, rawPart := range temp {
250 var wrapper struct {
251 Type partType `json:"type"`
252 Data json.RawMessage `json:"data"`
253 }
254
255 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
256 return nil, err
257 }
258
259 switch wrapper.Type {
260 case reasoningType:
261 part := ReasoningContent{}
262 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
263 return nil, err
264 }
265 parts = append(parts, part)
266 case textType:
267 part := TextContent{}
268 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
269 return nil, err
270 }
271 parts = append(parts, part)
272 case imageURLType:
273 part := ImageURLContent{}
274 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
275 return nil, err
276 }
277 parts = append(parts, part)
278 case binaryType:
279 part := BinaryContent{}
280 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
281 return nil, err
282 }
283 parts = append(parts, part)
284 case toolCallType:
285 part := ToolCall{}
286 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
287 return nil, err
288 }
289 parts = append(parts, part)
290 case toolResultType:
291 part := ToolResult{}
292 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
293 return nil, err
294 }
295 parts = append(parts, part)
296 case finishType:
297 part := Finish{}
298 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
299 return nil, err
300 }
301 parts = append(parts, part)
302 default:
303 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
304 }
305 }
306
307 return parts, nil
308}