message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
// CreateMessageParams holds the caller-supplied fields for Service.Create.
//
// The message ID and session ID are not part of it: Create generates the ID
// itself and takes the session ID as a separate argument. For any Role other
// than Assistant, Create appends a Finish part (reason "stop") to Parts
// before persisting. IsSummaryMessage is stored as an integer flag (1/0).
type CreateMessageParams struct {
	Role             MessageRole
	Parts            []ContentPart
	Model            string
	Provider         string
	IsSummaryMessage bool
}
 22
// Service stores chat messages and lets callers subscribe to changes.
//
// The implementation in this file (service) persists through db.Querier and,
// after a successful Create, Update or Delete, publishes a clone of the
// affected Message with the matching pubsub event type (CreatedEvent,
// UpdatedEvent, DeletedEvent).
type Service interface {
	pubsub.Subscriber[Message]
	// Create stores a new message in sessionID and returns it as decoded from
	// the row the database returned.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists message.Parts (and the finish time, when the message
	// has a Finish part) for the row identified by message.ID.
	Update(ctx context.Context, message Message) error
	// Get loads a single message by ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages stored for sessionID.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// ListUserMessages returns the user messages of sessionID.
	// NOTE(review): filtering happens in the ListUserMessagesBySession query,
	// presumably on role — confirm.
	ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
	// ListAllUserMessages returns user messages without a session filter
	// (delegated to the ListAllUserMessages query).
	ListAllUserMessages(ctx context.Context) ([]Message, error)
	// Delete removes a single message by ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message of sessionID, one at a time.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
 34
// service is the database-backed Service implementation.
//
// The embedded Broker supplies Publish (used after every successful write)
// and the subscription side of pubsub.Subscriber[Message]; q is the query
// layer all reads and writes go through.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
}
 39
 40func NewService(q db.Querier) Service {
 41	return &service{
 42		Broker: pubsub.NewBroker[Message](),
 43		q:      q,
 44	}
 45}
 46
 47func (s *service) Delete(ctx context.Context, id string) error {
 48	message, err := s.Get(ctx, id)
 49	if err != nil {
 50		return err
 51	}
 52	err = s.q.DeleteMessage(ctx, message.ID)
 53	if err != nil {
 54		return err
 55	}
 56	// Clone the message before publishing to avoid race conditions with
 57	// concurrent modifications to the Parts slice.
 58	s.Publish(pubsub.DeletedEvent, message.Clone())
 59	return nil
 60}
 61
 62func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 63	if params.Role != Assistant {
 64		params.Parts = append(params.Parts, Finish{
 65			Reason: "stop",
 66		})
 67	}
 68	partsJSON, err := marshalParts(params.Parts)
 69	if err != nil {
 70		return Message{}, err
 71	}
 72	isSummary := int64(0)
 73	if params.IsSummaryMessage {
 74		isSummary = 1
 75	}
 76	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 77		ID:               uuid.New().String(),
 78		SessionID:        sessionID,
 79		Role:             string(params.Role),
 80		Parts:            string(partsJSON),
 81		Model:            sql.NullString{String: string(params.Model), Valid: true},
 82		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 83		IsSummaryMessage: isSummary,
 84	})
 85	if err != nil {
 86		return Message{}, err
 87	}
 88	message, err := s.fromDBItem(dbMessage)
 89	if err != nil {
 90		return Message{}, err
 91	}
 92	// Clone the message before publishing to avoid race conditions with
 93	// concurrent modifications to the Parts slice.
 94	s.Publish(pubsub.CreatedEvent, message.Clone())
 95	return message, nil
 96}
 97
 98func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 99	messages, err := s.List(ctx, sessionID)
100	if err != nil {
101		return err
102	}
103	for _, message := range messages {
104		if message.SessionID == sessionID {
105			err = s.Delete(ctx, message.ID)
106			if err != nil {
107				return err
108			}
109		}
110	}
111	return nil
112}
113
114func (s *service) Update(ctx context.Context, message Message) error {
115	parts, err := marshalParts(message.Parts)
116	if err != nil {
117		return err
118	}
119	finishedAt := sql.NullInt64{}
120	if f := message.FinishPart(); f != nil {
121		finishedAt.Int64 = f.Time
122		finishedAt.Valid = true
123	}
124	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
125		ID:         message.ID,
126		Parts:      string(parts),
127		FinishedAt: finishedAt,
128	})
129	if err != nil {
130		return err
131	}
132	message.UpdatedAt = time.Now().Unix()
133	// Clone the message before publishing to avoid race conditions with
134	// concurrent modifications to the Parts slice.
135	s.Publish(pubsub.UpdatedEvent, message.Clone())
136	return nil
137}
138
139func (s *service) Get(ctx context.Context, id string) (Message, error) {
140	dbMessage, err := s.q.GetMessage(ctx, id)
141	if err != nil {
142		return Message{}, err
143	}
144	return s.fromDBItem(dbMessage)
145}
146
147func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
148	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
149	if err != nil {
150		return nil, err
151	}
152	messages := make([]Message, len(dbMessages))
153	for i, dbMessage := range dbMessages {
154		messages[i], err = s.fromDBItem(dbMessage)
155		if err != nil {
156			return nil, err
157		}
158	}
159	return messages, nil
160}
161
162func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
163	dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
164	if err != nil {
165		return nil, err
166	}
167	messages := make([]Message, len(dbMessages))
168	for i, dbMessage := range dbMessages {
169		messages[i], err = s.fromDBItem(dbMessage)
170		if err != nil {
171			return nil, err
172		}
173	}
174	return messages, nil
175}
176
177func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
178	dbMessages, err := s.q.ListAllUserMessages(ctx)
179	if err != nil {
180		return nil, err
181	}
182	messages := make([]Message, len(dbMessages))
183	for i, dbMessage := range dbMessages {
184		messages[i], err = s.fromDBItem(dbMessage)
185		if err != nil {
186			return nil, err
187		}
188	}
189	return messages, nil
190}
191
192func (s *service) fromDBItem(item db.Message) (Message, error) {
193	parts, err := unmarshalParts([]byte(item.Parts))
194	if err != nil {
195		return Message{}, err
196	}
197	return Message{
198		ID:               item.ID,
199		SessionID:        item.SessionID,
200		Role:             MessageRole(item.Role),
201		Parts:            parts,
202		Model:            item.Model.String,
203		Provider:         item.Provider.String,
204		CreatedAt:        item.CreatedAt,
205		UpdatedAt:        item.UpdatedAt,
206		IsSummaryMessage: item.IsSummaryMessage != 0,
207	}, nil
208}
209
// partType is the tag written into the "type" field of each serialized
// content part, so unmarshalParts knows which concrete type to decode into.
type partType string

// Tags for each concrete ContentPart type. marshalParts writes these strings
// into the JSON stored in the messages table, and unmarshalParts reads them
// back. Renaming any value would make existing stored rows undecodable.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
221
// partWrapper is the JSON envelope for one content part:
// {"type": <partType>, "data": <part>}. Data is encoded with the concrete
// part's own JSON encoding. Because Data is an interface, it cannot be
// decoded directly, so unmarshalParts uses a parallel struct with
// json.RawMessage and dispatches on Type.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
226
227func marshalParts(parts []ContentPart) ([]byte, error) {
228	wrappedParts := make([]partWrapper, len(parts))
229
230	for i, part := range parts {
231		var typ partType
232
233		switch part.(type) {
234		case ReasoningContent:
235			typ = reasoningType
236		case TextContent:
237			typ = textType
238		case ImageURLContent:
239			typ = imageURLType
240		case BinaryContent:
241			typ = binaryType
242		case ToolCall:
243			typ = toolCallType
244		case ToolResult:
245			typ = toolResultType
246		case Finish:
247			typ = finishType
248		default:
249			return nil, fmt.Errorf("unknown part type: %T", part)
250		}
251
252		wrappedParts[i] = partWrapper{
253			Type: typ,
254			Data: part,
255		}
256	}
257	return json.Marshal(wrappedParts)
258}
259
260func unmarshalParts(data []byte) ([]ContentPart, error) {
261	temp := []json.RawMessage{}
262
263	if err := json.Unmarshal(data, &temp); err != nil {
264		return nil, err
265	}
266
267	parts := make([]ContentPart, 0)
268
269	for _, rawPart := range temp {
270		var wrapper struct {
271			Type partType        `json:"type"`
272			Data json.RawMessage `json:"data"`
273		}
274
275		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
276			return nil, err
277		}
278
279		switch wrapper.Type {
280		case reasoningType:
281			part := ReasoningContent{}
282			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
283				return nil, err
284			}
285			parts = append(parts, part)
286		case textType:
287			part := TextContent{}
288			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
289				return nil, err
290			}
291			parts = append(parts, part)
292		case imageURLType:
293			part := ImageURLContent{}
294			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
295				return nil, err
296			}
297			parts = append(parts, part)
298		case binaryType:
299			part := BinaryContent{}
300			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
301				return nil, err
302			}
303			parts = append(parts, part)
304		case toolCallType:
305			part := ToolCall{}
306			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
307				return nil, err
308			}
309			parts = append(parts, part)
310		case toolResultType:
311			part := ToolResult{}
312			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
313				return nil, err
314			}
315			parts = append(parts, part)
316		case finishType:
317			part := Finish{}
318			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
319				return nil, err
320			}
321			parts = append(parts, part)
322		default:
323			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
324		}
325	}
326
327	return parts, nil
328}