1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// CreateMessageParams holds the caller-supplied fields for a new message.
// The message ID and session ID are not part of it: Create generates the ID
// itself and receives the session ID as a separate argument.
type CreateMessageParams struct {
	// Role is the author role of the message. Create appends a Finish part
	// with reason "stop" for every role except Assistant.
	Role MessageRole
	// Parts is the ordered content of the message.
	Parts []ContentPart
	// Model is the model name stored with the message (stored even when empty).
	Model string
	// Provider is the provider name; an empty string is stored as NULL.
	Provider string
}
21
// Service stores chat messages and lets callers subscribe to changes.
// The implementation returned by NewService publishes a Created, Updated or
// Deleted event after each successful write.
type Service interface {
	pubsub.Suscriber[Message]
	// Create stores a new message in the given session and returns it.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the parts (and finish time, if any) of an existing message.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages of a session; order is whatever the underlying
	// query yields (NOTE(review): presumably creation order — confirm).
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes a single message by ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message that belongs to a session.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
31
// service is the db.Querier-backed implementation of Service. The embedded
// Broker supplies the Publish method used to broadcast message events (and,
// presumably, the subscription methods required by Service).
type service struct {
	*pubsub.Broker[Message]
	// q runs the generated SQL queries for the messages table.
	q db.Querier
}
36
37func NewService(q db.Querier) Service {
38 return &service{
39 Broker: pubsub.NewBroker[Message](),
40 q: q,
41 }
42}
43
44func (s *service) Delete(ctx context.Context, id string) error {
45 message, err := s.Get(ctx, id)
46 if err != nil {
47 return err
48 }
49 err = s.q.DeleteMessage(ctx, message.ID)
50 if err != nil {
51 return err
52 }
53 s.Publish(pubsub.DeletedEvent, message)
54 return nil
55}
56
57func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
58 if params.Role != Assistant {
59 params.Parts = append(params.Parts, Finish{
60 Reason: "stop",
61 })
62 }
63 partsJSON, err := marshallParts(params.Parts)
64 if err != nil {
65 return Message{}, err
66 }
67 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
68 ID: uuid.New().String(),
69 SessionID: sessionID,
70 Role: string(params.Role),
71 Parts: string(partsJSON),
72 Model: sql.NullString{String: string(params.Model), Valid: true},
73 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
74 })
75 if err != nil {
76 return Message{}, err
77 }
78 message, err := s.fromDBItem(dbMessage)
79 if err != nil {
80 return Message{}, err
81 }
82 s.Publish(pubsub.CreatedEvent, message)
83 return message, nil
84}
85
86func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
87 messages, err := s.List(ctx, sessionID)
88 if err != nil {
89 return err
90 }
91 for _, message := range messages {
92 if message.SessionID == sessionID {
93 err = s.Delete(ctx, message.ID)
94 if err != nil {
95 return err
96 }
97 }
98 }
99 return nil
100}
101
102func (s *service) Update(ctx context.Context, message Message) error {
103 parts, err := marshallParts(message.Parts)
104 if err != nil {
105 return err
106 }
107 finishedAt := sql.NullInt64{}
108 if f := message.FinishPart(); f != nil {
109 finishedAt.Int64 = f.Time
110 finishedAt.Valid = true
111 }
112 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
113 ID: message.ID,
114 Parts: string(parts),
115 FinishedAt: finishedAt,
116 })
117 if err != nil {
118 return err
119 }
120 message.UpdatedAt = time.Now().Unix()
121 s.Publish(pubsub.UpdatedEvent, message)
122 return nil
123}
124
125func (s *service) Get(ctx context.Context, id string) (Message, error) {
126 dbMessage, err := s.q.GetMessage(ctx, id)
127 if err != nil {
128 return Message{}, err
129 }
130 return s.fromDBItem(dbMessage)
131}
132
133func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
134 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
135 if err != nil {
136 return nil, err
137 }
138 messages := make([]Message, len(dbMessages))
139 for i, dbMessage := range dbMessages {
140 messages[i], err = s.fromDBItem(dbMessage)
141 if err != nil {
142 return nil, err
143 }
144 }
145 return messages, nil
146}
147
148func (s *service) fromDBItem(item db.Message) (Message, error) {
149 parts, err := unmarshallParts([]byte(item.Parts))
150 if err != nil {
151 return Message{}, err
152 }
153 return Message{
154 ID: item.ID,
155 SessionID: item.SessionID,
156 Role: MessageRole(item.Role),
157 Parts: parts,
158 Model: item.Model.String,
159 Provider: item.Provider.String,
160 CreatedAt: item.CreatedAt,
161 UpdatedAt: item.UpdatedAt,
162 }, nil
163}
164
// partType is the discriminator stored next to each serialized content part
// so that unmarshallParts knows which concrete type to decode into.
type partType string

// Discriminator values for each ContentPart implementation. These strings are
// written into the persisted parts JSON, so renaming one makes previously
// stored rows fail to decode ("unknown part type").
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
	retryType      partType = "retry"
)
177
// partWrapper is the serialized envelope for a single content part. Type
// names the concrete part type and Data holds the part itself, encoded with
// encoding/json's rules for its dynamic type.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
182
183func marshallParts(parts []ContentPart) ([]byte, error) {
184 wrappedParts := make([]partWrapper, len(parts))
185
186 for i, part := range parts {
187 var typ partType
188
189 switch part.(type) {
190 case ReasoningContent:
191 typ = reasoningType
192 case TextContent:
193 typ = textType
194 case ImageURLContent:
195 typ = imageURLType
196 case BinaryContent:
197 typ = binaryType
198 case ToolCall:
199 typ = toolCallType
200 case ToolResult:
201 typ = toolResultType
202 case Finish:
203 typ = finishType
204 case RetryContent:
205 typ = retryType
206 default:
207 return nil, fmt.Errorf("unknown part type: %T", part)
208 }
209
210 wrappedParts[i] = partWrapper{
211 Type: typ,
212 Data: part,
213 }
214 }
215 return json.Marshal(wrappedParts)
216}
217
218func unmarshallParts(data []byte) ([]ContentPart, error) {
219 temp := []json.RawMessage{}
220
221 if err := json.Unmarshal(data, &temp); err != nil {
222 return nil, err
223 }
224
225 parts := make([]ContentPart, 0)
226
227 for _, rawPart := range temp {
228 var wrapper struct {
229 Type partType `json:"type"`
230 Data json.RawMessage `json:"data"`
231 }
232
233 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
234 return nil, err
235 }
236
237 switch wrapper.Type {
238 case reasoningType:
239 part := ReasoningContent{}
240 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
241 return nil, err
242 }
243 parts = append(parts, part)
244 case textType:
245 part := TextContent{}
246 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
247 return nil, err
248 }
249 parts = append(parts, part)
250 case imageURLType:
251 part := ImageURLContent{}
252 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
253 return nil, err
254 }
255 case binaryType:
256 part := BinaryContent{}
257 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
258 return nil, err
259 }
260 parts = append(parts, part)
261 case toolCallType:
262 part := ToolCall{}
263 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
264 return nil, err
265 }
266 parts = append(parts, part)
267 case toolResultType:
268 part := ToolResult{}
269 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
270 return nil, err
271 }
272 parts = append(parts, part)
273 case finishType:
274 part := Finish{}
275 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
276 return nil, err
277 }
278 parts = append(parts, part)
279 case retryType:
280 part := RetryContent{}
281 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
282 return nil, err
283 }
284 parts = append(parts, part)
285 default:
286 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
287 }
288 }
289
290 return parts, nil
291}