1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "sync"
9 "time"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/pubsub"
13 "github.com/google/uuid"
14)
15
// defaultUpdateDebounce is the default debounce window for [Service.Update]
// (33ms, roughly one tick at 30 Hz). Streaming deltas that arrive within
// the window are coalesced into a single SQL write and a single pubsub
// event. Terminal updates (finish, tool-call count change, tool-call
// finish, reasoning end) bypass the debounce and flush synchronously.
// Override with [WithDebounce].
const defaultUpdateDebounce = 33 * time.Millisecond
22
// CreateMessageParams describes a message to insert via [Service.Create].
type CreateMessageParams struct {
	// Role of the new message. For any role other than Assistant,
	// Create appends a Finish part with reason "stop" so the message
	// is stored already finished.
	Role MessageRole
	// Parts is the initial content; Create stores it JSON-encoded.
	Parts []ContentPart
	// Model is stored as a non-NULL value even when empty.
	Model string
	// Provider is stored as NULL when empty.
	Provider string
	// IsSummaryMessage is stored as 1 (true) or 0 (false).
	IsSummaryMessage bool
}
30
// Service is the public interface to the message store.
//
// [Service.Update] is eventually consistent: it accepts new state into
// an in-memory buffer and writes it to SQLite plus publishes a
// [pubsub.UpdatedEvent] on the next debounce tick (default
// [defaultUpdateDebounce]) or on the next terminal-state update,
// whichever comes first. Terminal-state updates — those that finish
// the message, change the number of tool calls or finish one, or end
// a reasoning section — flush synchronously before [Service.Update]
// returns.
//
// Callers that need stronger ordering (e.g. tests, shutdown,
// session-switch reads) must use [Service.Flush] or [Service.FlushAll]
// before reading via [Service.Get] / [Service.List]. Without an
// explicit flush, a read can race the debounce timer and miss the
// most recent in-memory state.
type Service interface {
	pubsub.Subscriber[Message]
	// Create inserts a new message into sessionID and publishes a
	// [pubsub.CreatedEvent]. Non-assistant messages get a Finish part
	// with reason "stop" appended before insertion.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update buffers or synchronously flushes message; see the type
	// comment for the consistency contract.
	Update(ctx context.Context, message Message) error
	// Get reads one message from the store. It does not consult
	// buffered updates; flush first if those must be visible.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the stored messages of sessionID. Buffered updates
	// are not consulted.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// ListUserMessages returns the stored user messages of sessionID.
	ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
	// ListAllUserMessages returns the stored user messages across all
	// sessions.
	ListAllUserMessages(ctx context.Context) ([]Message, error)
	// Delete removes a message, discards any buffered update for it,
	// and publishes a [pubsub.DeletedEvent].
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages deletes every message of sessionID via
	// Delete, stopping at the first error.
	DeleteSessionMessages(ctx context.Context, sessionID string) error

	// Flush synchronously drains any pending debounced state for the
	// given message ID, performs the SQL write, and publishes the
	// resulting [pubsub.UpdatedEvent]. Idempotent; cheap no-op if no
	// updates are pending. Use this before any read that must observe
	// the latest [Service.Update].
	Flush(ctx context.Context, id string) error

	// FlushAll synchronously drains pending debounced state for every
	// message known to the service. Intended for shutdown and
	// session-switch paths.
	FlushAll(ctx context.Context) error
}
69
// pendingState holds the in-memory coalescing buffer for a single
// message ID. All fields are guarded by service.mu. The flushing flag
// serializes concurrent flushers for the same ID so SQL writes never
// reorder.
type pendingState struct {
	// latest is the most recent [Message] passed to [Service.Update]
	// (as a clone owned by the service).
	latest Message

	// dirty is true when latest contains state that has not been
	// written to SQL since the last successful flush. A failed write
	// sets it back to true.
	dirty bool

	// flushing is true while a goroutine is performing the SQL write
	// for this ID. New updates are still accepted (and re-mark dirty)
	// but other flushers must back off.
	flushing bool

	// timer is the active debounce timer, or nil if no flush is
	// scheduled. Stopped and cleared when a terminal update or a flush
	// preempts the debounce window.
	timer *time.Timer

	// lastFlushed is the snapshot most recently written to SQL. Used
	// as the baseline for terminal-state detection.
	lastFlushed Message

	// hasFlushed is false until the first successful write for this
	// ID; until then lastFlushed is the zero value and must not be
	// treated as a real prior state.
	hasFlushed bool
}
102
// service is the default [Service] implementation. It persists through
// q, publishes change events through the embedded broker, and coalesces
// [Service.Update] calls per message ID in pending.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
	// debounce is the Update coalescing window; <= 0 disables
	// debouncing so every Update flushes synchronously.
	debounce time.Duration

	// mu guards pending and every pendingState reachable from it.
	mu sync.Mutex
	// pending maps message ID to its coalescing buffer. Entries are
	// removed only by Delete.
	pending map[string]*pendingState
}
111
// ServiceOption configures a [Service] at construction. [NewService]
// applies options in order, after the defaults have been set.
type ServiceOption func(*service)
114
115// WithDebounce overrides the debounce window for [Service.Update]. A
116// zero or negative value disables debouncing entirely (every update
117// flushes synchronously). Intended primarily for tests.
118func WithDebounce(d time.Duration) ServiceOption {
119 return func(s *service) {
120 s.debounce = d
121 }
122}
123
124func NewService(q db.Querier, opts ...ServiceOption) Service {
125 s := &service{
126 Broker: pubsub.NewBroker[Message](),
127 q: q,
128 debounce: defaultUpdateDebounce,
129 pending: make(map[string]*pendingState),
130 }
131 for _, opt := range opts {
132 opt(s)
133 }
134 return s
135}
136
137func (s *service) Delete(ctx context.Context, id string) error {
138 message, err := s.Get(ctx, id)
139 if err != nil {
140 return err
141 }
142 err = s.q.DeleteMessage(ctx, message.ID)
143 if err != nil {
144 return err
145 }
146 // Drop any pending coalesced state for this ID. We never want to
147 // flush back over a deleted row.
148 s.mu.Lock()
149 if p, ok := s.pending[id]; ok {
150 if p.timer != nil {
151 p.timer.Stop()
152 }
153 delete(s.pending, id)
154 }
155 s.mu.Unlock()
156 // Clone the message before publishing to avoid race conditions with
157 // concurrent modifications to the Parts slice.
158 s.Publish(pubsub.DeletedEvent, message.Clone())
159 return nil
160}
161
162func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
163 if params.Role != Assistant {
164 params.Parts = append(params.Parts, Finish{
165 Reason: "stop",
166 })
167 }
168 partsJSON, err := marshalParts(params.Parts)
169 if err != nil {
170 return Message{}, err
171 }
172 isSummary := int64(0)
173 if params.IsSummaryMessage {
174 isSummary = 1
175 }
176 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
177 ID: uuid.New().String(),
178 SessionID: sessionID,
179 Role: string(params.Role),
180 Parts: string(partsJSON),
181 Model: sql.NullString{String: string(params.Model), Valid: true},
182 Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
183 IsSummaryMessage: isSummary,
184 })
185 if err != nil {
186 return Message{}, err
187 }
188 message, err := s.fromDBItem(dbMessage)
189 if err != nil {
190 return Message{}, err
191 }
192 // Clone the message before publishing to avoid race conditions with
193 // concurrent modifications to the Parts slice.
194 s.Publish(pubsub.CreatedEvent, message.Clone())
195 return message, nil
196}
197
198func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
199 messages, err := s.List(ctx, sessionID)
200 if err != nil {
201 return err
202 }
203 for _, message := range messages {
204 if message.SessionID == sessionID {
205 err = s.Delete(ctx, message.ID)
206 if err != nil {
207 return err
208 }
209 }
210 }
211 return nil
212}
213
214// Update accepts a new state for a message and either flushes
215// synchronously (terminal updates, debounce <= 0) or buffers it until
216// the next debounce tick. See [Service] for the contract.
217func (s *service) Update(ctx context.Context, msg Message) error {
218 cloned := msg.Clone()
219
220 // Zero or negative debounce: flush every update synchronously. This
221 // preserves the pre-coalescing behaviour for tests and any caller
222 // that explicitly opted out via [WithDebounce].
223 if s.debounce <= 0 {
224 s.mu.Lock()
225 p, ok := s.pending[msg.ID]
226 if !ok {
227 p = &pendingState{}
228 s.pending[msg.ID] = p
229 }
230 p.latest = cloned
231 p.dirty = true
232 s.mu.Unlock()
233 return s.flushOne(ctx, msg.ID, true)
234 }
235
236 s.mu.Lock()
237 p, ok := s.pending[msg.ID]
238 if !ok {
239 p = &pendingState{}
240 s.pending[msg.ID] = p
241 }
242 p.latest = cloned
243 p.dirty = true
244
245 var prev *Message
246 if p.hasFlushed {
247 prev = &p.lastFlushed
248 }
249 terminal := shouldFlushNow(prev, &cloned)
250
251 if terminal {
252 if p.timer != nil {
253 p.timer.Stop()
254 p.timer = nil
255 }
256 s.mu.Unlock()
257 return s.flushOne(ctx, msg.ID, true)
258 }
259
260 // Debounce: schedule a single flush per pending state. If a flush
261 // is already running we let it finish; the trailing dirty bit will
262 // be picked up by the next Update or by Flush.
263 if p.timer == nil && !p.flushing {
264 id := msg.ID
265 p.timer = time.AfterFunc(s.debounce, func() {
266 // Detached from caller ctx so a cancelled stream context
267 // does not strand the buffered write.
268 _ = s.flushOne(context.Background(), id, false)
269 })
270 }
271 s.mu.Unlock()
272 return nil
273}
274
275// Flush implements [Service.Flush].
276func (s *service) Flush(ctx context.Context, id string) error {
277 return s.flushOne(ctx, id, true)
278}
279
280// FlushAll implements [Service.FlushAll]. It snapshots every ID with
281// outstanding work — either dirty buffered state or a flush already in
282// flight — then drains each one. Picking up in-flight IDs ensures
283// FlushAll cannot return while a timer-fired write is still mid-SQL,
284// which is what shutdown and session-switch callers rely on.
285func (s *service) FlushAll(ctx context.Context) error {
286 s.mu.Lock()
287 ids := make([]string, 0, len(s.pending))
288 for id, p := range s.pending {
289 if p.dirty || p.flushing {
290 ids = append(ids, id)
291 }
292 }
293 s.mu.Unlock()
294 var firstErr error
295 for _, id := range ids {
296 if err := s.flushOne(ctx, id, true); err != nil && firstErr == nil {
297 firstErr = err
298 }
299 }
300 return firstErr
301}
302
303// flushOne drains a single message ID. When syncCaller is true the
304// caller is willing to wait through a concurrent in-flight flush so
305// that, on return, lastFlushed equals latest at the moment of return.
306// When false (timer-fired path) we bail if another flusher is already
307// running; that flusher will pick up the trailing dirty bit.
308//
309// Order matters: a sync caller must wait for any in-flight flush to
310// drain even when the buffer is currently clean — that in-flight
311// write has not yet updated the SQL row, so returning early would
312// violate the contract that on success lastFlushed reflects the most
313// recent state.
314func (s *service) flushOne(ctx context.Context, id string, syncCaller bool) error {
315 for {
316 s.mu.Lock()
317 p, ok := s.pending[id]
318 if !ok {
319 s.mu.Unlock()
320 return nil
321 }
322 if p.flushing {
323 if !syncCaller {
324 s.mu.Unlock()
325 return nil
326 }
327 s.mu.Unlock()
328 // Brief yield; in-flight write should land in <1ms typical.
329 time.Sleep(time.Millisecond)
330 continue
331 }
332 if !p.dirty {
333 s.mu.Unlock()
334 return nil
335 }
336
337 if p.timer != nil {
338 p.timer.Stop()
339 p.timer = nil
340 }
341 snap := p.latest
342 // Decide whether this snapshot represents a terminal event
343 // against the prior baseline. We must do this before resetting
344 // dirty/flushing because shouldFlushNow looks at p.lastFlushed
345 // (which is what was on disk before this write).
346 var prev *Message
347 if p.hasFlushed {
348 prev = &p.lastFlushed
349 }
350 isTerminal := shouldFlushNow(prev, &snap)
351 p.flushing = true
352 p.dirty = false
353 s.mu.Unlock()
354
355 err := s.write(ctx, snap)
356
357 s.mu.Lock()
358 p.flushing = false
359 if err == nil {
360 p.lastFlushed = snap
361 p.hasFlushed = true
362 } else {
363 // Restore dirty so the next caller retries.
364 p.dirty = true
365 }
366 // If a delta arrived during the SQL write and we are a sync
367 // caller, the user expects that delta to land too.
368 wasDirty := p.dirty
369 s.mu.Unlock()
370
371 if err != nil {
372 return err
373 }
374
375 // Terminal events — message finished, tool call added or
376 // finished, reasoning ended — use the bounded must-deliver
377 // path so they never get dropped under channel contention.
378 if isTerminal {
379 s.PublishMustDeliver(ctx, pubsub.UpdatedEvent, snap)
380 } else {
381 s.Publish(pubsub.UpdatedEvent, snap)
382 }
383
384 if wasDirty && syncCaller {
385 continue
386 }
387 return nil
388 }
389}
390
391// write performs the unguarded SQL write + UpdatedAt stamp. Caller
392// owns publishing.
393func (s *service) write(ctx context.Context, msg Message) error {
394 parts, err := marshalParts(msg.Parts)
395 if err != nil {
396 return err
397 }
398 finishedAt := sql.NullInt64{}
399 if f := msg.FinishPart(); f != nil {
400 finishedAt.Int64 = f.Time
401 finishedAt.Valid = true
402 }
403 if err := s.q.UpdateMessage(ctx, db.UpdateMessageParams{
404 ID: msg.ID,
405 Parts: string(parts),
406 FinishedAt: finishedAt,
407 }); err != nil {
408 return err
409 }
410 return nil
411}
412
413// shouldFlushNow returns true when next represents a structural
414// change that must not be silently coalesced: the message just
415// finished, the tool-call set grew, a tool call transitioned to
416// finished, or reasoning just finished. prev is the last-flushed
417// snapshot (or nil if no write has landed yet).
418func shouldFlushNow(prev, next *Message) bool {
419 if next.IsFinished() {
420 return true
421 }
422
423 var prevCalls []ToolCall
424 var prevReasoningFinishedAt int64
425 if prev != nil {
426 prevCalls = prev.ToolCalls()
427 prevReasoningFinishedAt = prev.ReasoningContent().FinishedAt
428 }
429 nextCalls := next.ToolCalls()
430 if len(nextCalls) != len(prevCalls) {
431 return true
432 }
433 for i := range nextCalls {
434 // Bounds-safe: lengths are equal here.
435 if nextCalls[i].Finished != prevCalls[i].Finished {
436 return true
437 }
438 // A tool call's input only matters once it has landed (Finished
439 // flips true). Earlier deltas to Input are debounced with the
440 // rest of the streaming state.
441 }
442 if next.ReasoningContent().FinishedAt > 0 && prevReasoningFinishedAt == 0 {
443 return true
444 }
445 return false
446}
447
448func (s *service) Get(ctx context.Context, id string) (Message, error) {
449 dbMessage, err := s.q.GetMessage(ctx, id)
450 if err != nil {
451 return Message{}, err
452 }
453 return s.fromDBItem(dbMessage)
454}
455
456func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
457 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
458 if err != nil {
459 return nil, err
460 }
461 messages := make([]Message, len(dbMessages))
462 for i, dbMessage := range dbMessages {
463 messages[i], err = s.fromDBItem(dbMessage)
464 if err != nil {
465 return nil, err
466 }
467 }
468 return messages, nil
469}
470
471func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
472 dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
473 if err != nil {
474 return nil, err
475 }
476 messages := make([]Message, len(dbMessages))
477 for i, dbMessage := range dbMessages {
478 messages[i], err = s.fromDBItem(dbMessage)
479 if err != nil {
480 return nil, err
481 }
482 }
483 return messages, nil
484}
485
486func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
487 dbMessages, err := s.q.ListAllUserMessages(ctx)
488 if err != nil {
489 return nil, err
490 }
491 messages := make([]Message, len(dbMessages))
492 for i, dbMessage := range dbMessages {
493 messages[i], err = s.fromDBItem(dbMessage)
494 if err != nil {
495 return nil, err
496 }
497 }
498 return messages, nil
499}
500
501func (s *service) fromDBItem(item db.Message) (Message, error) {
502 parts, err := unmarshalParts([]byte(item.Parts))
503 if err != nil {
504 return Message{}, err
505 }
506 return Message{
507 ID: item.ID,
508 SessionID: item.SessionID,
509 Role: MessageRole(item.Role),
510 Parts: parts,
511 Model: item.Model.String,
512 Provider: item.Provider.String,
513 CreatedAt: item.CreatedAt,
514 UpdatedAt: item.UpdatedAt,
515 IsSummaryMessage: item.IsSummaryMessage != 0,
516 }, nil
517}
518
// partType is the discriminator written to the "type" field of each
// serialized content part. marshalParts and unmarshalParts map between
// these values and the concrete part types, and the encoded parts are
// stored in the database. Changing a value therefore breaks decoding
// of existing rows.
type partType string

const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
530
// partWrapper is the on-disk envelope for one content part:
// {"type": <partType>, "data": <part JSON>}. Field order fixes the JSON
// key order produced by marshalParts.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
535
536func marshalParts(parts []ContentPart) ([]byte, error) {
537 wrappedParts := make([]partWrapper, len(parts))
538
539 for i, part := range parts {
540 var typ partType
541
542 switch part.(type) {
543 case ReasoningContent:
544 typ = reasoningType
545 case TextContent:
546 typ = textType
547 case ImageURLContent:
548 typ = imageURLType
549 case BinaryContent:
550 typ = binaryType
551 case ToolCall:
552 typ = toolCallType
553 case ToolResult:
554 typ = toolResultType
555 case Finish:
556 typ = finishType
557 default:
558 return nil, fmt.Errorf("unknown part type: %T", part)
559 }
560
561 wrappedParts[i] = partWrapper{
562 Type: typ,
563 Data: part,
564 }
565 }
566 return json.Marshal(wrappedParts)
567}
568
569func unmarshalParts(data []byte) ([]ContentPart, error) {
570 temp := []json.RawMessage{}
571
572 if err := json.Unmarshal(data, &temp); err != nil {
573 return nil, err
574 }
575
576 parts := make([]ContentPart, 0)
577
578 for _, rawPart := range temp {
579 var wrapper struct {
580 Type partType `json:"type"`
581 Data json.RawMessage `json:"data"`
582 }
583
584 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
585 return nil, err
586 }
587
588 switch wrapper.Type {
589 case reasoningType:
590 part := ReasoningContent{}
591 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
592 return nil, err
593 }
594 parts = append(parts, part)
595 case textType:
596 part := TextContent{}
597 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
598 return nil, err
599 }
600 parts = append(parts, part)
601 case imageURLType:
602 part := ImageURLContent{}
603 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
604 return nil, err
605 }
606 parts = append(parts, part)
607 case binaryType:
608 part := BinaryContent{}
609 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
610 return nil, err
611 }
612 parts = append(parts, part)
613 case toolCallType:
614 part := ToolCall{}
615 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
616 return nil, err
617 }
618 parts = append(parts, part)
619 case toolResultType:
620 part := ToolResult{}
621 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
622 return nil, err
623 }
624 parts = append(parts, part)
625 case finishType:
626 part := Finish{}
627 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
628 return nil, err
629 }
630 parts = append(parts, part)
631 default:
632 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
633 }
634 }
635
636 return parts, nil
637}