1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
15type CreateMessageParams struct {
16 Role MessageRole
17 Parts []ContentPart
18 Model string
19 Provider string
20 IsSummaryMessage bool
21}
22
23type Service interface {
24 pubsub.Suscriber[Message]
25 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
26 Update(ctx context.Context, message Message) error
27 Get(ctx context.Context, id string) (Message, error)
28 List(ctx context.Context, sessionID string) ([]Message, error)
29 Delete(ctx context.Context, id string) error
30 DeleteSessionMessages(ctx context.Context, sessionID string) error
31}
32
33type service struct {
34 *pubsub.Broker[Message]
35 q db.Querier
36}
37
38func NewService(q db.Querier) Service {
39 return &service{
40 Broker: pubsub.NewBroker[Message](),
41 q: q,
42 }
43}
44
45func (s *service) Delete(ctx context.Context, id string) error {
46 message, err := s.Get(ctx, id)
47 if err != nil {
48 return err
49 }
50 err = s.q.DeleteMessage(ctx, message.ID)
51 if err != nil {
52 return err
53 }
54 s.Publish(pubsub.DeletedEvent, message)
55 return nil
56}
57
58func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
59 if params.Role != Assistant {
60 params.Parts = append(params.Parts, Finish{
61 Reason: "stop",
62 })
63 }
64 partsJSON, err := marshallParts(params.Parts)
65 if err != nil {
66 return Message{}, err
67 }
68 isSummary := int64(0)
69 if params.IsSummaryMessage {
70 isSummary = 1
71 }
72 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
73 ID: uuid.New().String(),
74 SessionID: sessionID,
75 Role: string(params.Role),
76 Parts: string(partsJSON),
77 Model: sql.NullString{String: string(params.Model), Valid: true},
78 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
79 IsSummaryMessage: isSummary,
80 })
81 if err != nil {
82 return Message{}, err
83 }
84 message, err := s.fromDBItem(dbMessage)
85 if err != nil {
86 return Message{}, err
87 }
88 s.Publish(pubsub.CreatedEvent, message)
89 return message, nil
90}
91
92func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
93 messages, err := s.List(ctx, sessionID)
94 if err != nil {
95 return err
96 }
97 for _, message := range messages {
98 if message.SessionID == sessionID {
99 err = s.Delete(ctx, message.ID)
100 if err != nil {
101 return err
102 }
103 }
104 }
105 return nil
106}
107
108func (s *service) Update(ctx context.Context, message Message) error {
109 parts, err := marshallParts(message.Parts)
110 if err != nil {
111 return err
112 }
113 finishedAt := sql.NullInt64{}
114 if f := message.FinishPart(); f != nil {
115 finishedAt.Int64 = f.Time
116 finishedAt.Valid = true
117 }
118 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
119 ID: message.ID,
120 Parts: string(parts),
121 FinishedAt: finishedAt,
122 })
123 if err != nil {
124 return err
125 }
126 message.UpdatedAt = time.Now().Unix()
127 s.Publish(pubsub.UpdatedEvent, message)
128 return nil
129}
130
131func (s *service) Get(ctx context.Context, id string) (Message, error) {
132 dbMessage, err := s.q.GetMessage(ctx, id)
133 if err != nil {
134 return Message{}, err
135 }
136 return s.fromDBItem(dbMessage)
137}
138
139func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
140 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
141 if err != nil {
142 return nil, err
143 }
144 messages := make([]Message, len(dbMessages))
145 for i, dbMessage := range dbMessages {
146 messages[i], err = s.fromDBItem(dbMessage)
147 if err != nil {
148 return nil, err
149 }
150 }
151 return messages, nil
152}
153
154func (s *service) fromDBItem(item db.Message) (Message, error) {
155 parts, err := unmarshallParts([]byte(item.Parts))
156 if err != nil {
157 return Message{}, err
158 }
159 return Message{
160 ID: item.ID,
161 SessionID: item.SessionID,
162 Role: MessageRole(item.Role),
163 Parts: parts,
164 Model: item.Model.String,
165 Provider: item.Provider.String,
166 CreatedAt: item.CreatedAt,
167 UpdatedAt: item.UpdatedAt,
168 IsSummaryMessage: item.IsSummaryMessage != 0,
169 }, nil
170}
171
// partType is the discriminator written to the "type" field of a serialized
// content part. It tells the decoder which concrete type to rebuild.
type partType string

// Discriminator values, one per concrete content part type handled by
// marshallParts and unmarshallParts. These strings are part of the stored
// format.
const (
	reasoningType  = partType("reasoning")
	textType       = partType("text")
	imageURLType   = partType("image_url")
	binaryType     = partType("binary")
	toolCallType   = partType("tool_call")
	toolResultType = partType("tool_result")
	finishType     = partType("finish")
)
183
184type partWrapper struct {
185 Type partType `json:"type"`
186 Data ContentPart `json:"data"`
187}
188
189func marshallParts(parts []ContentPart) ([]byte, error) {
190 wrappedParts := make([]partWrapper, len(parts))
191
192 for i, part := range parts {
193 var typ partType
194
195 switch part.(type) {
196 case ReasoningContent:
197 typ = reasoningType
198 case TextContent:
199 typ = textType
200 case ImageURLContent:
201 typ = imageURLType
202 case BinaryContent:
203 typ = binaryType
204 case ToolCall:
205 typ = toolCallType
206 case ToolResult:
207 typ = toolResultType
208 case Finish:
209 typ = finishType
210 default:
211 return nil, fmt.Errorf("unknown part type: %T", part)
212 }
213
214 wrappedParts[i] = partWrapper{
215 Type: typ,
216 Data: part,
217 }
218 }
219 return json.Marshal(wrappedParts)
220}
221
222func unmarshallParts(data []byte) ([]ContentPart, error) {
223 temp := []json.RawMessage{}
224
225 if err := json.Unmarshal(data, &temp); err != nil {
226 return nil, err
227 }
228
229 parts := make([]ContentPart, 0)
230
231 for _, rawPart := range temp {
232 var wrapper struct {
233 Type partType `json:"type"`
234 Data json.RawMessage `json:"data"`
235 }
236
237 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
238 return nil, err
239 }
240
241 switch wrapper.Type {
242 case reasoningType:
243 part := ReasoningContent{}
244 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
245 return nil, err
246 }
247 parts = append(parts, part)
248 case textType:
249 part := TextContent{}
250 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
251 return nil, err
252 }
253 parts = append(parts, part)
254 case imageURLType:
255 part := ImageURLContent{}
256 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
257 return nil, err
258 }
259 case binaryType:
260 part := BinaryContent{}
261 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
262 return nil, err
263 }
264 parts = append(parts, part)
265 case toolCallType:
266 part := ToolCall{}
267 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
268 return nil, err
269 }
270 parts = append(parts, part)
271 case toolResultType:
272 part := ToolResult{}
273 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
274 return nil, err
275 }
276 parts = append(parts, part)
277 case finishType:
278 part := Finish{}
279 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
280 return nil, err
281 }
282 parts = append(parts, part)
283 default:
284 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
285 }
286 }
287
288 return parts, nil
289}