1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// CreateMessageParams holds the caller-supplied fields for a new message.
// The message ID is generated by the service's Create method, and the session
// ID is passed to Create as a separate argument.
type CreateMessageParams struct {
	Role MessageRole `json:"role"`
	// Parts is the message content. For any role other than Assistant, Create
	// appends a Finish part (reason "stop") before storing.
	Parts []ContentPart `json:"parts"`
	Model string        `json:"model"`
	// Provider is optional; an empty value is stored as NULL.
	Provider string `json:"provider,omitempty"`
}
21
// Service manages persisted chat messages. It also lets callers subscribe to
// message events through the embedded pubsub.Suscriber.
type Service interface {
	pubsub.Suscriber[Message]
	// Create stores a new message in the given session and returns it.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists changes to an existing message.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages belonging to the given session.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes the message with the given ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message in the given session.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
31
// service is the database-backed implementation of Service. The embedded
// broker is used to publish Created/Updated/Deleted message events.
// NOTE(review): presumably the broker also supplies the Suscriber half of
// Service — confirm in the pubsub package.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
}
36
37func NewService(q db.Querier) Service {
38 return &service{
39 Broker: pubsub.NewBroker[Message](),
40 q: q,
41 }
42}
43
44func (s *service) Delete(ctx context.Context, id string) error {
45 message, err := s.Get(ctx, id)
46 if err != nil {
47 return err
48 }
49 err = s.q.DeleteMessage(ctx, message.ID)
50 if err != nil {
51 return err
52 }
53 s.Publish(pubsub.DeletedEvent, message)
54 return nil
55}
56
57func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
58 if params.Role != Assistant {
59 params.Parts = append(params.Parts, Finish{
60 Reason: "stop",
61 })
62 }
63 partsJSON, err := marshallParts(params.Parts)
64 if err != nil {
65 return Message{}, err
66 }
67 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
68 ID: uuid.New().String(),
69 SessionID: sessionID,
70 Role: string(params.Role),
71 Parts: string(partsJSON),
72 Model: sql.NullString{String: string(params.Model), Valid: true},
73 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
74 })
75 if err != nil {
76 return Message{}, err
77 }
78 message, err := s.fromDBItem(dbMessage)
79 if err != nil {
80 return Message{}, err
81 }
82 s.Publish(pubsub.CreatedEvent, message)
83 return message, nil
84}
85
86func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
87 messages, err := s.List(ctx, sessionID)
88 if err != nil {
89 return err
90 }
91 for _, message := range messages {
92 if message.SessionID == sessionID {
93 err = s.Delete(ctx, message.ID)
94 if err != nil {
95 return err
96 }
97 }
98 }
99 return nil
100}
101
102func (s *service) Update(ctx context.Context, message Message) error {
103 parts, err := marshallParts(message.Parts)
104 if err != nil {
105 return err
106 }
107 finishedAt := sql.NullInt64{}
108 if f := message.FinishPart(); f != nil {
109 finishedAt.Int64 = f.Time
110 finishedAt.Valid = true
111 }
112 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
113 ID: message.ID,
114 Parts: string(parts),
115 FinishedAt: finishedAt,
116 })
117 if err != nil {
118 return err
119 }
120 message.UpdatedAt = time.Now().Unix()
121 s.Publish(pubsub.UpdatedEvent, message)
122 return nil
123}
124
125func (s *service) Get(ctx context.Context, id string) (Message, error) {
126 dbMessage, err := s.q.GetMessage(ctx, id)
127 if err != nil {
128 return Message{}, err
129 }
130 return s.fromDBItem(dbMessage)
131}
132
133func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
134 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
135 if err != nil {
136 return nil, err
137 }
138 messages := make([]Message, len(dbMessages))
139 for i, dbMessage := range dbMessages {
140 messages[i], err = s.fromDBItem(dbMessage)
141 if err != nil {
142 return nil, err
143 }
144 }
145 return messages, nil
146}
147
148func (s *service) fromDBItem(item db.Message) (Message, error) {
149 parts, err := unmarshallParts([]byte(item.Parts))
150 if err != nil {
151 return Message{}, err
152 }
153 return Message{
154 ID: item.ID,
155 SessionID: item.SessionID,
156 Role: MessageRole(item.Role),
157 Parts: parts,
158 Model: item.Model.String,
159 Provider: item.Provider.String,
160 CreatedAt: item.CreatedAt,
161 UpdatedAt: item.UpdatedAt,
162 }, nil
163}
164
// partType is the discriminator written to the "type" field of each
// serialized content part. unmarshallParts uses it to pick the concrete Go
// type that the part's "data" is decoded into.
type partType string

// Known part discriminators, each followed by the Go type it maps to. These
// strings are persisted with stored messages, so renaming one would make
// previously stored parts fail to decode as an "unknown part type".
const (
	reasoningType  partType = "reasoning"   // ReasoningContent
	textType       partType = "text"        // TextContent
	imageURLType   partType = "image_url"   // ImageURLContent
	binaryType     partType = "binary"      // BinaryContent
	toolCallType   partType = "tool_call"   // ToolCall
	toolResultType partType = "tool_result" // ToolResult
	finishType     partType = "finish"      // Finish
)
176
// partWrapper is the JSON envelope marshallParts emits for a single content
// part: {"type": <partType>, "data": <the part itself>}.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
181
182func marshallParts(parts []ContentPart) ([]byte, error) {
183 wrappedParts := make([]partWrapper, len(parts))
184
185 for i, part := range parts {
186 var typ partType
187
188 switch part.(type) {
189 case ReasoningContent:
190 typ = reasoningType
191 case TextContent:
192 typ = textType
193 case ImageURLContent:
194 typ = imageURLType
195 case BinaryContent:
196 typ = binaryType
197 case ToolCall:
198 typ = toolCallType
199 case ToolResult:
200 typ = toolResultType
201 case Finish:
202 typ = finishType
203 default:
204 return nil, fmt.Errorf("unknown part type: %T", part)
205 }
206
207 wrappedParts[i] = partWrapper{
208 Type: typ,
209 Data: part,
210 }
211 }
212 return json.Marshal(wrappedParts)
213}
214
215func unmarshallParts(data []byte) ([]ContentPart, error) {
216 temp := []json.RawMessage{}
217
218 if err := json.Unmarshal(data, &temp); err != nil {
219 return nil, err
220 }
221
222 parts := make([]ContentPart, 0)
223
224 for _, rawPart := range temp {
225 var wrapper struct {
226 Type partType `json:"type"`
227 Data json.RawMessage `json:"data"`
228 }
229
230 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
231 return nil, err
232 }
233
234 switch wrapper.Type {
235 case reasoningType:
236 part := ReasoningContent{}
237 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
238 return nil, err
239 }
240 parts = append(parts, part)
241 case textType:
242 part := TextContent{}
243 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
244 return nil, err
245 }
246 parts = append(parts, part)
247 case imageURLType:
248 part := ImageURLContent{}
249 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
250 return nil, err
251 }
252 case binaryType:
253 part := BinaryContent{}
254 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
255 return nil, err
256 }
257 parts = append(parts, part)
258 case toolCallType:
259 part := ToolCall{}
260 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
261 return nil, err
262 }
263 parts = append(parts, part)
264 case toolResultType:
265 part := ToolResult{}
266 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
267 return nil, err
268 }
269 parts = append(parts, part)
270 case finishType:
271 part := Finish{}
272 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
273 return nil, err
274 }
275 parts = append(parts, part)
276 default:
277 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
278 }
279 }
280
281 return parts, nil
282}