message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"sync"
  9	"time"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/pubsub"
 13	"github.com/google/uuid"
 14)
 15
// defaultUpdateDebounce is the default debounce window for [Service.Update].
// Streaming deltas that arrive within the window are coalesced into a
// single SQL write and a single pubsub event. Terminal updates (the
// message finished, the tool-call set changed, a tool call finished,
// or reasoning ended — see shouldFlushNow) bypass the debounce and
// flush synchronously.
//
// 33ms caps debounced writes at roughly 30 per second per message.
const defaultUpdateDebounce = 33 * time.Millisecond
 22
// CreateMessageParams carries the caller-supplied fields for
// [Service.Create]. The message ID is generated by Create itself and
// the session ID is passed separately.
type CreateMessageParams struct {
	// Role is the author role. Create appends a Finish part with
	// reason "stop" to every message whose role is not Assistant.
	Role MessageRole
	// Parts is the initial content, serialized by marshalParts.
	Parts []ContentPart
	// Model is always stored as a non-NULL column value, even when empty.
	Model string
	// Provider is stored as NULL when empty.
	Provider string
	// IsSummaryMessage is persisted as 1 when true and 0 when false.
	IsSummaryMessage bool
}
 30
// Service is the public interface to the message store.
//
// [Service.Update] is eventually consistent: it accepts new state into
// an in-memory buffer and writes it to SQLite plus publishes a
// [pubsub.UpdatedEvent] on the next debounce tick (default
// [defaultUpdateDebounce]) or on the next terminal-state update,
// whichever comes first. Terminal-state updates — those that finish
// the message, add or finish a tool call, or end a reasoning section —
// flush synchronously before [Service.Update] returns.
//
// Callers that need stronger ordering (e.g. tests, shutdown,
// session-switch reads) must use [Service.Flush] or [Service.FlushAll]
// before reading via [Service.Get] / [Service.List]. Without an
// explicit flush, a read can race the debounce timer and miss the
// most recent in-memory state.
type Service interface {
	pubsub.Subscriber[Message]

	// Create stores a new message in the given session under a freshly
	// generated ID and publishes a [pubsub.CreatedEvent]. Messages
	// whose role is not Assistant are stored with a Finish part
	// already appended.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)

	// Update accepts new state for an existing message. Depending on
	// the state it either flushes before returning or buffers until
	// the debounce timer fires; see the type comment for the contract.
	Update(ctx context.Context, message Message) error

	// Get reads one message by ID directly from the store. It does
	// not consult the debounce buffer.
	Get(ctx context.Context, id string) (Message, error)

	// List returns the stored messages of a session. Like Get, it
	// reads the store only.
	List(ctx context.Context, sessionID string) ([]Message, error)

	// ListUserMessages returns the user messages of a session, as
	// selected by the underlying store query.
	ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)

	// ListAllUserMessages returns user messages across all sessions,
	// as selected by the underlying store query.
	ListAllUserMessages(ctx context.Context) ([]Message, error)

	// Delete removes a message from the store, discards any buffered
	// update for it, and publishes a [pubsub.DeletedEvent].
	Delete(ctx context.Context, id string) error

	// DeleteSessionMessages deletes every message of a session, one
	// at a time, stopping at the first error.
	DeleteSessionMessages(ctx context.Context, sessionID string) error

	// Flush synchronously drains any pending debounced state for the
	// given message ID, performs the SQL write, and publishes the
	// resulting [pubsub.UpdatedEvent]. Idempotent; cheap no-op if no
	// updates are pending. Use this before any read that must observe
	// the latest [Service.Update].
	Flush(ctx context.Context, id string) error

	// FlushAll synchronously drains pending debounced state for every
	// message known to the service. Intended for shutdown and
	// session-switch paths. It attempts every message even after a
	// failure and returns the first error encountered.
	FlushAll(ctx context.Context) error
}
 69
// pendingState holds the in-memory coalescing buffer for a single
// message ID. All fields are guarded by service.mu. The flushing flag
// serializes concurrent flushers for the same ID so SQL writes never
// reorder.
type pendingState struct {
	// latest is the most recent [Message] passed to [Service.Update].
	// It is replaced wholesale (never mutated in place) on every
	// Update, and it only holds unwritten state while dirty is true.
	latest Message

	// dirty is true when latest contains state that has not been
	// written to SQL since the last successful flush. A failed write
	// sets it back to true so a later flush retries.
	dirty bool

	// flushing is true while a goroutine is performing the SQL write
	// for this ID. New updates are still accepted (and re-mark dirty)
	// but other flushers must back off.
	flushing bool

	// timer is the active debounce timer, or nil if no flush is
	// scheduled. It is stopped and cleared (set to nil) whenever a
	// flush begins or a terminal update preempts the debounce window.
	timer *time.Timer

	// lastFlushed is the snapshot most recently written to SQL. Used
	// as the baseline for terminal-state detection.
	lastFlushed Message

	// hasFlushed is false until the first successful write for this
	// ID; until then lastFlushed is the zero value and must not be
	// treated as a real prior state.
	hasFlushed bool
}
102
// service is the [Service] implementation returned by [NewService].
// It pairs the SQL querier with a per-message coalescing buffer and
// an embedded broker used to publish message events.
type service struct {
	// Broker supplies Publish / PublishMustDeliver and the
	// subscription side of the [Service] interface.
	*pubsub.Broker[Message]
	// q is the store all reads and writes go through.
	q db.Querier
	// debounce is the coalescing window for Update; a value <= 0
	// makes every Update flush synchronously.
	debounce time.Duration

	// mu guards pending and every field of the pendingState values
	// it points to. It is released around SQL writes and publishing.
	mu sync.Mutex
	// pending maps message ID to its buffered update state. Entries
	// are created by Update and removed by Delete.
	pending map[string]*pendingState
}
111
// ServiceOption configures a [Service] at construction. [NewService]
// applies options in the order given, after the defaults are set, so
// a later option overrides an earlier one.
type ServiceOption func(*service)
114
115// WithDebounce overrides the debounce window for [Service.Update]. A
116// zero or negative value disables debouncing entirely (every update
117// flushes synchronously). Intended primarily for tests.
118func WithDebounce(d time.Duration) ServiceOption {
119	return func(s *service) {
120		s.debounce = d
121	}
122}
123
124func NewService(q db.Querier, opts ...ServiceOption) Service {
125	s := &service{
126		Broker:   pubsub.NewBroker[Message](),
127		q:        q,
128		debounce: defaultUpdateDebounce,
129		pending:  make(map[string]*pendingState),
130	}
131	for _, opt := range opts {
132		opt(s)
133	}
134	return s
135}
136
137func (s *service) Delete(ctx context.Context, id string) error {
138	message, err := s.Get(ctx, id)
139	if err != nil {
140		return err
141	}
142	err = s.q.DeleteMessage(ctx, message.ID)
143	if err != nil {
144		return err
145	}
146	// Drop any pending coalesced state for this ID. We never want to
147	// flush back over a deleted row.
148	s.mu.Lock()
149	if p, ok := s.pending[id]; ok {
150		if p.timer != nil {
151			p.timer.Stop()
152		}
153		delete(s.pending, id)
154	}
155	s.mu.Unlock()
156	// Clone the message before publishing to avoid race conditions with
157	// concurrent modifications to the Parts slice.
158	s.Publish(pubsub.DeletedEvent, message.Clone())
159	return nil
160}
161
162func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
163	if params.Role != Assistant {
164		params.Parts = append(params.Parts, Finish{
165			Reason: "stop",
166		})
167	}
168	partsJSON, err := marshalParts(params.Parts)
169	if err != nil {
170		return Message{}, err
171	}
172	isSummary := int64(0)
173	if params.IsSummaryMessage {
174		isSummary = 1
175	}
176	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
177		ID:               uuid.New().String(),
178		SessionID:        sessionID,
179		Role:             string(params.Role),
180		Parts:            string(partsJSON),
181		Model:            sql.NullString{String: string(params.Model), Valid: true},
182		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
183		IsSummaryMessage: isSummary,
184	})
185	if err != nil {
186		return Message{}, err
187	}
188	message, err := s.fromDBItem(dbMessage)
189	if err != nil {
190		return Message{}, err
191	}
192	// Clone the message before publishing to avoid race conditions with
193	// concurrent modifications to the Parts slice.
194	s.Publish(pubsub.CreatedEvent, message.Clone())
195	return message, nil
196}
197
198func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
199	messages, err := s.List(ctx, sessionID)
200	if err != nil {
201		return err
202	}
203	for _, message := range messages {
204		if message.SessionID == sessionID {
205			err = s.Delete(ctx, message.ID)
206			if err != nil {
207				return err
208			}
209		}
210	}
211	return nil
212}
213
214// Update accepts a new state for a message and either flushes
215// synchronously (terminal updates, debounce <= 0) or buffers it until
216// the next debounce tick. See [Service] for the contract.
217func (s *service) Update(ctx context.Context, msg Message) error {
218	cloned := msg.Clone()
219
220	// Zero or negative debounce: flush every update synchronously. This
221	// preserves the pre-coalescing behaviour for tests and any caller
222	// that explicitly opted out via [WithDebounce].
223	if s.debounce <= 0 {
224		s.mu.Lock()
225		p, ok := s.pending[msg.ID]
226		if !ok {
227			p = &pendingState{}
228			s.pending[msg.ID] = p
229		}
230		p.latest = cloned
231		p.dirty = true
232		s.mu.Unlock()
233		return s.flushOne(ctx, msg.ID, true)
234	}
235
236	s.mu.Lock()
237	p, ok := s.pending[msg.ID]
238	if !ok {
239		p = &pendingState{}
240		s.pending[msg.ID] = p
241	}
242	p.latest = cloned
243	p.dirty = true
244
245	var prev *Message
246	if p.hasFlushed {
247		prev = &p.lastFlushed
248	}
249	terminal := shouldFlushNow(prev, &cloned)
250
251	if terminal {
252		if p.timer != nil {
253			p.timer.Stop()
254			p.timer = nil
255		}
256		s.mu.Unlock()
257		return s.flushOne(ctx, msg.ID, true)
258	}
259
260	// Debounce: schedule a single flush per pending state. If a flush
261	// is already running we let it finish; the trailing dirty bit will
262	// be picked up by the next Update or by Flush.
263	if p.timer == nil && !p.flushing {
264		id := msg.ID
265		p.timer = time.AfterFunc(s.debounce, func() {
266			// Detached from caller ctx so a cancelled stream context
267			// does not strand the buffered write.
268			_ = s.flushOne(context.Background(), id, false)
269		})
270	}
271	s.mu.Unlock()
272	return nil
273}
274
// Flush implements [Service.Flush]. It delegates to flushOne as a
// synchronous caller: it waits out any write already in flight for id
// and returns nil without writing when id has no dirty buffered state.
func (s *service) Flush(ctx context.Context, id string) error {
	return s.flushOne(ctx, id, true)
}
279
280// FlushAll implements [Service.FlushAll]. It snapshots every ID with
281// outstanding work — either dirty buffered state or a flush already in
282// flight — then drains each one. Picking up in-flight IDs ensures
283// FlushAll cannot return while a timer-fired write is still mid-SQL,
284// which is what shutdown and session-switch callers rely on.
285func (s *service) FlushAll(ctx context.Context) error {
286	s.mu.Lock()
287	ids := make([]string, 0, len(s.pending))
288	for id, p := range s.pending {
289		if p.dirty || p.flushing {
290			ids = append(ids, id)
291		}
292	}
293	s.mu.Unlock()
294	var firstErr error
295	for _, id := range ids {
296		if err := s.flushOne(ctx, id, true); err != nil && firstErr == nil {
297			firstErr = err
298		}
299	}
300	return firstErr
301}
302
303// flushOne drains a single message ID. When syncCaller is true the
304// caller is willing to wait through a concurrent in-flight flush so
305// that, on return, lastFlushed equals latest at the moment of return.
306// When false (timer-fired path) we bail if another flusher is already
307// running; that flusher will pick up the trailing dirty bit.
308//
309// Order matters: a sync caller must wait for any in-flight flush to
310// drain even when the buffer is currently clean — that in-flight
311// write has not yet updated the SQL row, so returning early would
312// violate the contract that on success lastFlushed reflects the most
313// recent state.
314func (s *service) flushOne(ctx context.Context, id string, syncCaller bool) error {
315	for {
316		s.mu.Lock()
317		p, ok := s.pending[id]
318		if !ok {
319			s.mu.Unlock()
320			return nil
321		}
322		if p.flushing {
323			if !syncCaller {
324				s.mu.Unlock()
325				return nil
326			}
327			s.mu.Unlock()
328			// Brief yield; in-flight write should land in <1ms typical.
329			time.Sleep(time.Millisecond)
330			continue
331		}
332		if !p.dirty {
333			s.mu.Unlock()
334			return nil
335		}
336
337		if p.timer != nil {
338			p.timer.Stop()
339			p.timer = nil
340		}
341		snap := p.latest
342		// Decide whether this snapshot represents a terminal event
343		// against the prior baseline. We must do this before resetting
344		// dirty/flushing because shouldFlushNow looks at p.lastFlushed
345		// (which is what was on disk before this write).
346		var prev *Message
347		if p.hasFlushed {
348			prev = &p.lastFlushed
349		}
350		isTerminal := shouldFlushNow(prev, &snap)
351		p.flushing = true
352		p.dirty = false
353		s.mu.Unlock()
354
355		err := s.write(ctx, snap)
356
357		s.mu.Lock()
358		p.flushing = false
359		if err == nil {
360			p.lastFlushed = snap
361			p.hasFlushed = true
362		} else {
363			// Restore dirty so the next caller retries.
364			p.dirty = true
365		}
366		// If a delta arrived during the SQL write and we are a sync
367		// caller, the user expects that delta to land too.
368		wasDirty := p.dirty
369		s.mu.Unlock()
370
371		if err != nil {
372			return err
373		}
374
375		// Terminal events — message finished, tool call added or
376		// finished, reasoning ended — use the bounded must-deliver
377		// path so they never get dropped under channel contention.
378		if isTerminal {
379			s.PublishMustDeliver(ctx, pubsub.UpdatedEvent, snap)
380		} else {
381			s.Publish(pubsub.UpdatedEvent, snap)
382		}
383
384		if wasDirty && syncCaller {
385			continue
386		}
387		return nil
388	}
389}
390
391// write performs the unguarded SQL write + UpdatedAt stamp. Caller
392// owns publishing.
393func (s *service) write(ctx context.Context, msg Message) error {
394	parts, err := marshalParts(msg.Parts)
395	if err != nil {
396		return err
397	}
398	finishedAt := sql.NullInt64{}
399	if f := msg.FinishPart(); f != nil {
400		finishedAt.Int64 = f.Time
401		finishedAt.Valid = true
402	}
403	if err := s.q.UpdateMessage(ctx, db.UpdateMessageParams{
404		ID:         msg.ID,
405		Parts:      string(parts),
406		FinishedAt: finishedAt,
407	}); err != nil {
408		return err
409	}
410	return nil
411}
412
413// shouldFlushNow returns true when next represents a structural
414// change that must not be silently coalesced: the message just
415// finished, the tool-call set grew, a tool call transitioned to
416// finished, or reasoning just finished. prev is the last-flushed
417// snapshot (or nil if no write has landed yet).
418func shouldFlushNow(prev, next *Message) bool {
419	if next.IsFinished() {
420		return true
421	}
422
423	var prevCalls []ToolCall
424	var prevReasoningFinishedAt int64
425	if prev != nil {
426		prevCalls = prev.ToolCalls()
427		prevReasoningFinishedAt = prev.ReasoningContent().FinishedAt
428	}
429	nextCalls := next.ToolCalls()
430	if len(nextCalls) != len(prevCalls) {
431		return true
432	}
433	for i := range nextCalls {
434		// Bounds-safe: lengths are equal here.
435		if nextCalls[i].Finished != prevCalls[i].Finished {
436			return true
437		}
438		// A tool call's input only matters once it has landed (Finished
439		// flips true). Earlier deltas to Input are debounced with the
440		// rest of the streaming state.
441	}
442	if next.ReasoningContent().FinishedAt > 0 && prevReasoningFinishedAt == 0 {
443		return true
444	}
445	return false
446}
447
448func (s *service) Get(ctx context.Context, id string) (Message, error) {
449	dbMessage, err := s.q.GetMessage(ctx, id)
450	if err != nil {
451		return Message{}, err
452	}
453	return s.fromDBItem(dbMessage)
454}
455
456func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
457	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
458	if err != nil {
459		return nil, err
460	}
461	messages := make([]Message, len(dbMessages))
462	for i, dbMessage := range dbMessages {
463		messages[i], err = s.fromDBItem(dbMessage)
464		if err != nil {
465			return nil, err
466		}
467	}
468	return messages, nil
469}
470
471func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
472	dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
473	if err != nil {
474		return nil, err
475	}
476	messages := make([]Message, len(dbMessages))
477	for i, dbMessage := range dbMessages {
478		messages[i], err = s.fromDBItem(dbMessage)
479		if err != nil {
480			return nil, err
481		}
482	}
483	return messages, nil
484}
485
486func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
487	dbMessages, err := s.q.ListAllUserMessages(ctx)
488	if err != nil {
489		return nil, err
490	}
491	messages := make([]Message, len(dbMessages))
492	for i, dbMessage := range dbMessages {
493		messages[i], err = s.fromDBItem(dbMessage)
494		if err != nil {
495			return nil, err
496		}
497	}
498	return messages, nil
499}
500
501func (s *service) fromDBItem(item db.Message) (Message, error) {
502	parts, err := unmarshalParts([]byte(item.Parts))
503	if err != nil {
504		return Message{}, err
505	}
506	return Message{
507		ID:               item.ID,
508		SessionID:        item.SessionID,
509		Role:             MessageRole(item.Role),
510		Parts:            parts,
511		Model:            item.Model.String,
512		Provider:         item.Provider.String,
513		CreatedAt:        item.CreatedAt,
514		UpdatedAt:        item.UpdatedAt,
515		IsSummaryMessage: item.IsSummaryMessage != 0,
516	}, nil
517}
518
// partType is the discriminator written to the "type" field of every
// serialized content part. It tells unmarshalParts which concrete
// [ContentPart] to decode the "data" field into.
type partType string

// Tags emitted by marshalParts, one per concrete [ContentPart] type.
// unmarshalParts rejects any tag that is not listed here.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
530
// partWrapper is the JSON envelope marshalParts writes for each part:
// a type tag next to the part's own encoding.
type partWrapper struct {
	// Type identifies the concrete type held in Data.
	Type partType `json:"type"`
	// Data is the part itself, encoded by encoding/json.
	Data ContentPart `json:"data"`
}
535
536func marshalParts(parts []ContentPart) ([]byte, error) {
537	wrappedParts := make([]partWrapper, len(parts))
538
539	for i, part := range parts {
540		var typ partType
541
542		switch part.(type) {
543		case ReasoningContent:
544			typ = reasoningType
545		case TextContent:
546			typ = textType
547		case ImageURLContent:
548			typ = imageURLType
549		case BinaryContent:
550			typ = binaryType
551		case ToolCall:
552			typ = toolCallType
553		case ToolResult:
554			typ = toolResultType
555		case Finish:
556			typ = finishType
557		default:
558			return nil, fmt.Errorf("unknown part type: %T", part)
559		}
560
561		wrappedParts[i] = partWrapper{
562			Type: typ,
563			Data: part,
564		}
565	}
566	return json.Marshal(wrappedParts)
567}
568
569func unmarshalParts(data []byte) ([]ContentPart, error) {
570	temp := []json.RawMessage{}
571
572	if err := json.Unmarshal(data, &temp); err != nil {
573		return nil, err
574	}
575
576	parts := make([]ContentPart, 0)
577
578	for _, rawPart := range temp {
579		var wrapper struct {
580			Type partType        `json:"type"`
581			Data json.RawMessage `json:"data"`
582		}
583
584		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
585			return nil, err
586		}
587
588		switch wrapper.Type {
589		case reasoningType:
590			part := ReasoningContent{}
591			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
592				return nil, err
593			}
594			parts = append(parts, part)
595		case textType:
596			part := TextContent{}
597			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
598				return nil, err
599			}
600			parts = append(parts, part)
601		case imageURLType:
602			part := ImageURLContent{}
603			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
604				return nil, err
605			}
606			parts = append(parts, part)
607		case binaryType:
608			part := BinaryContent{}
609			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
610				return nil, err
611			}
612			parts = append(parts, part)
613		case toolCallType:
614			part := ToolCall{}
615			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
616				return nil, err
617			}
618			parts = append(parts, part)
619		case toolResultType:
620			part := ToolResult{}
621			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
622				return nil, err
623			}
624			parts = append(parts, part)
625		case finishType:
626			part := Finish{}
627			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
628				return nil, err
629			}
630			parts = append(parts, part)
631		default:
632			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
633		}
634	}
635
636	return parts, nil
637}