1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/llm/models"
12 "github.com/charmbracelet/crush/internal/pubsub"
13 "github.com/google/uuid"
14)
15
16type CreateMessageParams struct {
17 Role MessageRole
18 Parts []ContentPart
19 Model models.ModelID
20}
21
22type Service interface {
23 pubsub.Suscriber[Message]
24 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
25 Update(ctx context.Context, message Message) error
26 Get(ctx context.Context, id string) (Message, error)
27 List(ctx context.Context, sessionID string) ([]Message, error)
28 Delete(ctx context.Context, id string) error
29 DeleteSessionMessages(ctx context.Context, sessionID string) error
30}
31
32type service struct {
33 *pubsub.Broker[Message]
34 q db.Querier
35}
36
37func NewService(q db.Querier) Service {
38 return &service{
39 Broker: pubsub.NewBroker[Message](),
40 q: q,
41 }
42}
43
44func (s *service) Delete(ctx context.Context, id string) error {
45 message, err := s.Get(ctx, id)
46 if err != nil {
47 return err
48 }
49 err = s.q.DeleteMessage(ctx, message.ID)
50 if err != nil {
51 return err
52 }
53 s.Publish(pubsub.DeletedEvent, message)
54 return nil
55}
56
57func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
58 if params.Role != Assistant {
59 params.Parts = append(params.Parts, Finish{
60 Reason: "stop",
61 })
62 }
63 partsJSON, err := marshallParts(params.Parts)
64 if err != nil {
65 return Message{}, err
66 }
67 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
68 ID: uuid.New().String(),
69 SessionID: sessionID,
70 Role: string(params.Role),
71 Parts: string(partsJSON),
72 Model: sql.NullString{String: string(params.Model), Valid: true},
73 })
74 if err != nil {
75 return Message{}, err
76 }
77 message, err := s.fromDBItem(dbMessage)
78 if err != nil {
79 return Message{}, err
80 }
81 s.Publish(pubsub.CreatedEvent, message)
82 return message, nil
83}
84
85func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
86 messages, err := s.List(ctx, sessionID)
87 if err != nil {
88 return err
89 }
90 for _, message := range messages {
91 if message.SessionID == sessionID {
92 err = s.Delete(ctx, message.ID)
93 if err != nil {
94 return err
95 }
96 }
97 }
98 return nil
99}
100
101func (s *service) Update(ctx context.Context, message Message) error {
102 parts, err := marshallParts(message.Parts)
103 if err != nil {
104 return err
105 }
106 finishedAt := sql.NullInt64{}
107 if f := message.FinishPart(); f != nil {
108 finishedAt.Int64 = f.Time
109 finishedAt.Valid = true
110 }
111 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
112 ID: message.ID,
113 Parts: string(parts),
114 FinishedAt: finishedAt,
115 })
116 if err != nil {
117 return err
118 }
119 message.UpdatedAt = time.Now().Unix()
120 s.Publish(pubsub.UpdatedEvent, message)
121 return nil
122}
123
124func (s *service) Get(ctx context.Context, id string) (Message, error) {
125 dbMessage, err := s.q.GetMessage(ctx, id)
126 if err != nil {
127 return Message{}, err
128 }
129 return s.fromDBItem(dbMessage)
130}
131
132func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
133 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
134 if err != nil {
135 return nil, err
136 }
137 messages := make([]Message, len(dbMessages))
138 for i, dbMessage := range dbMessages {
139 messages[i], err = s.fromDBItem(dbMessage)
140 if err != nil {
141 return nil, err
142 }
143 }
144 return messages, nil
145}
146
147func (s *service) fromDBItem(item db.Message) (Message, error) {
148 parts, err := unmarshallParts([]byte(item.Parts))
149 if err != nil {
150 return Message{}, err
151 }
152 return Message{
153 ID: item.ID,
154 SessionID: item.SessionID,
155 Role: MessageRole(item.Role),
156 Parts: parts,
157 Model: models.ModelID(item.Model.String),
158 CreatedAt: item.CreatedAt,
159 UpdatedAt: item.UpdatedAt,
160 }, nil
161}
162
// partType is the tag stored next to each serialized ContentPart, so the part
// can be decoded back into its concrete type. These strings end up in the
// persisted parts JSON and must never change.
type partType string

const (
	finishType     partType = "finish"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
)
174
175type partWrapper struct {
176 Type partType `json:"type"`
177 Data ContentPart `json:"data"`
178}
179
180func marshallParts(parts []ContentPart) ([]byte, error) {
181 wrappedParts := make([]partWrapper, len(parts))
182
183 for i, part := range parts {
184 var typ partType
185
186 switch part.(type) {
187 case ReasoningContent:
188 typ = reasoningType
189 case TextContent:
190 typ = textType
191 case ImageURLContent:
192 typ = imageURLType
193 case BinaryContent:
194 typ = binaryType
195 case ToolCall:
196 typ = toolCallType
197 case ToolResult:
198 typ = toolResultType
199 case Finish:
200 typ = finishType
201 default:
202 return nil, fmt.Errorf("unknown part type: %T", part)
203 }
204
205 wrappedParts[i] = partWrapper{
206 Type: typ,
207 Data: part,
208 }
209 }
210 return json.Marshal(wrappedParts)
211}
212
213func unmarshallParts(data []byte) ([]ContentPart, error) {
214 temp := []json.RawMessage{}
215
216 if err := json.Unmarshal(data, &temp); err != nil {
217 return nil, err
218 }
219
220 parts := make([]ContentPart, 0)
221
222 for _, rawPart := range temp {
223 var wrapper struct {
224 Type partType `json:"type"`
225 Data json.RawMessage `json:"data"`
226 }
227
228 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
229 return nil, err
230 }
231
232 switch wrapper.Type {
233 case reasoningType:
234 part := ReasoningContent{}
235 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
236 return nil, err
237 }
238 parts = append(parts, part)
239 case textType:
240 part := TextContent{}
241 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
242 return nil, err
243 }
244 parts = append(parts, part)
245 case imageURLType:
246 part := ImageURLContent{}
247 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
248 return nil, err
249 }
250 case binaryType:
251 part := BinaryContent{}
252 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
253 return nil, err
254 }
255 parts = append(parts, part)
256 case toolCallType:
257 part := ToolCall{}
258 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
259 return nil, err
260 }
261 parts = append(parts, part)
262 case toolResultType:
263 part := ToolResult{}
264 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
265 return nil, err
266 }
267 parts = append(parts, part)
268 case finishType:
269 part := Finish{}
270 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
271 return nil, err
272 }
273 parts = append(parts, part)
274 default:
275 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
276 }
277 }
278
279 return parts, nil
280}