message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
 15type CreateMessageParams struct {
 16	Role             MessageRole
 17	Parts            []ContentPart
 18	Model            string
 19	Provider         string
 20	IsSummaryMessage bool
 21}
 22
 23type Service interface {
 24	pubsub.Subscriber[Message]
 25	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
 26	Update(ctx context.Context, message Message) error
 27	Get(ctx context.Context, id string) (Message, error)
 28	List(ctx context.Context, sessionID string) ([]Message, error)
 29	ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
 30	ListAllUserMessages(ctx context.Context) ([]Message, error)
 31	Copy(ctx context.Context, sessionID string, message Message) (Message, error)
 32	Delete(ctx context.Context, id string) error
 33	DeleteSessionMessages(ctx context.Context, sessionID string) error
 34	DeleteMessagesFrom(ctx context.Context, sessionID, messageID string) error
 35}
 36
 37type service struct {
 38	*pubsub.Broker[Message]
 39	q db.Querier
 40}
 41
 42func NewService(q db.Querier) Service {
 43	return &service{
 44		Broker: pubsub.NewBroker[Message](),
 45		q:      q,
 46	}
 47}
 48
 49func (s *service) Delete(ctx context.Context, id string) error {
 50	message, err := s.Get(ctx, id)
 51	if err != nil {
 52		return err
 53	}
 54	err = s.q.DeleteMessage(ctx, message.ID)
 55	if err != nil {
 56		return err
 57	}
 58	// Clone the message before publishing to avoid race conditions with
 59	// concurrent modifications to the Parts slice.
 60	s.Publish(pubsub.DeletedEvent, message.Clone())
 61	return nil
 62}
 63
 64func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 65	if params.Role != Assistant {
 66		params.Parts = append(params.Parts, Finish{
 67			Reason: "stop",
 68		})
 69	}
 70	partsJSON, err := marshalParts(params.Parts)
 71	if err != nil {
 72		return Message{}, err
 73	}
 74	isSummary := int64(0)
 75	if params.IsSummaryMessage {
 76		isSummary = 1
 77	}
 78	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 79		ID:               uuid.New().String(),
 80		SessionID:        sessionID,
 81		Role:             string(params.Role),
 82		Parts:            string(partsJSON),
 83		Model:            sql.NullString{String: string(params.Model), Valid: true},
 84		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 85		IsSummaryMessage: isSummary,
 86	})
 87	if err != nil {
 88		return Message{}, err
 89	}
 90	message, err := s.fromDBItem(dbMessage)
 91	if err != nil {
 92		return Message{}, err
 93	}
 94	// Clone the message before publishing to avoid race conditions with
 95	// concurrent modifications to the Parts slice.
 96	s.Publish(pubsub.CreatedEvent, message.Clone())
 97	return message, nil
 98}
 99
100func (s *service) Copy(ctx context.Context, sessionID string, message Message) (Message, error) {
101	params := CreateMessageParams{
102		Role:             message.Role,
103		Parts:            message.Parts,
104		Model:            message.Model,
105		Provider:         message.Provider,
106		IsSummaryMessage: message.IsSummaryMessage,
107	}
108	return s.Create(ctx, sessionID, params)
109}
110
111func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
112	messages, err := s.List(ctx, sessionID)
113	if err != nil {
114		return err
115	}
116	for _, message := range messages {
117		if message.SessionID == sessionID {
118			err = s.Delete(ctx, message.ID)
119			if err != nil {
120				return err
121			}
122		}
123	}
124	return nil
125}
126
127func (s *service) DeleteMessagesFrom(ctx context.Context, sessionID, messageID string) error {
128	allMessages, err := s.List(ctx, sessionID)
129	if err != nil {
130		return err
131	}
132	targetIndex := -1
133	for i, msg := range allMessages {
134		if msg.ID == messageID {
135			targetIndex = i
136			break
137		}
138	}
139	if targetIndex == -1 {
140		return fmt.Errorf("message not found: %s", messageID)
141	}
142	for i := len(allMessages) - 1; i >= targetIndex; i-- {
143		if err := s.Delete(ctx, allMessages[i].ID); err != nil {
144			return err
145		}
146	}
147	return nil
148}
149
150func (s *service) Update(ctx context.Context, message Message) error {
151	parts, err := marshalParts(message.Parts)
152	if err != nil {
153		return err
154	}
155	finishedAt := sql.NullInt64{}
156	if f := message.FinishPart(); f != nil {
157		finishedAt.Int64 = f.Time
158		finishedAt.Valid = true
159	}
160	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
161		ID:         message.ID,
162		Parts:      string(parts),
163		FinishedAt: finishedAt,
164	})
165	if err != nil {
166		return err
167	}
168	message.UpdatedAt = time.Now().Unix()
169	// Clone the message before publishing to avoid race conditions with
170	// concurrent modifications to the Parts slice.
171	s.Publish(pubsub.UpdatedEvent, message.Clone())
172	return nil
173}
174
175func (s *service) Get(ctx context.Context, id string) (Message, error) {
176	dbMessage, err := s.q.GetMessage(ctx, id)
177	if err != nil {
178		return Message{}, err
179	}
180	return s.fromDBItem(dbMessage)
181}
182
183func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
184	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
185	if err != nil {
186		return nil, err
187	}
188	messages := make([]Message, len(dbMessages))
189	for i, dbMessage := range dbMessages {
190		messages[i], err = s.fromDBItem(dbMessage)
191		if err != nil {
192			return nil, err
193		}
194	}
195	return messages, nil
196}
197
198func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
199	dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
200	if err != nil {
201		return nil, err
202	}
203	messages := make([]Message, len(dbMessages))
204	for i, dbMessage := range dbMessages {
205		messages[i], err = s.fromDBItem(dbMessage)
206		if err != nil {
207			return nil, err
208		}
209	}
210	return messages, nil
211}
212
213func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
214	dbMessages, err := s.q.ListAllUserMessages(ctx)
215	if err != nil {
216		return nil, err
217	}
218	messages := make([]Message, len(dbMessages))
219	for i, dbMessage := range dbMessages {
220		messages[i], err = s.fromDBItem(dbMessage)
221		if err != nil {
222			return nil, err
223		}
224	}
225	return messages, nil
226}
227
228func (s *service) fromDBItem(item db.Message) (Message, error) {
229	parts, err := unmarshalParts([]byte(item.Parts))
230	if err != nil {
231		return Message{}, err
232	}
233	return Message{
234		ID:               item.ID,
235		SessionID:        item.SessionID,
236		Role:             MessageRole(item.Role),
237		Parts:            parts,
238		Model:            item.Model.String,
239		Provider:         item.Provider.String,
240		CreatedAt:        item.CreatedAt,
241		UpdatedAt:        item.UpdatedAt,
242		IsSummaryMessage: item.IsSummaryMessage != 0,
243	}, nil
244}
245
// partType is the discriminator stored in the "type" field of each
// serialized content part. The values end up in the database's parts
// column, and unmarshalParts rejects unknown values. Renaming one would
// therefore make existing rows undecodable.
type partType string

// Discriminators, grouped by the kind of content they tag. The trailing
// comment names the ContentPart type each one is paired with.
const (
	// Plain content.
	textType      partType = "text"      // TextContent
	reasoningType partType = "reasoning" // ReasoningContent

	// Media.
	imageURLType partType = "image_url" // ImageURLContent
	binaryType   partType = "binary"    // BinaryContent

	// Tool traffic.
	toolCallType   partType = "tool_call"   // ToolCall
	toolResultType partType = "tool_result" // ToolResult

	// End of a message.
	finishType partType = "finish" // Finish
)
257
258type partWrapper struct {
259	Type partType    `json:"type"`
260	Data ContentPart `json:"data"`
261}
262
263func marshalParts(parts []ContentPart) ([]byte, error) {
264	wrappedParts := make([]partWrapper, len(parts))
265
266	for i, part := range parts {
267		var typ partType
268
269		switch part.(type) {
270		case ReasoningContent:
271			typ = reasoningType
272		case TextContent:
273			typ = textType
274		case ImageURLContent:
275			typ = imageURLType
276		case BinaryContent:
277			typ = binaryType
278		case ToolCall:
279			typ = toolCallType
280		case ToolResult:
281			typ = toolResultType
282		case Finish:
283			typ = finishType
284		default:
285			return nil, fmt.Errorf("unknown part type: %T", part)
286		}
287
288		wrappedParts[i] = partWrapper{
289			Type: typ,
290			Data: part,
291		}
292	}
293	return json.Marshal(wrappedParts)
294}
295
296func unmarshalParts(data []byte) ([]ContentPart, error) {
297	temp := []json.RawMessage{}
298
299	if err := json.Unmarshal(data, &temp); err != nil {
300		return nil, err
301	}
302
303	parts := make([]ContentPart, 0)
304
305	for _, rawPart := range temp {
306		var wrapper struct {
307			Type partType        `json:"type"`
308			Data json.RawMessage `json:"data"`
309		}
310
311		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
312			return nil, err
313		}
314
315		switch wrapper.Type {
316		case reasoningType:
317			part := ReasoningContent{}
318			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
319				return nil, err
320			}
321			parts = append(parts, part)
322		case textType:
323			part := TextContent{}
324			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
325				return nil, err
326			}
327			parts = append(parts, part)
328		case imageURLType:
329			part := ImageURLContent{}
330			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
331				return nil, err
332			}
333			parts = append(parts, part)
334		case binaryType:
335			part := BinaryContent{}
336			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
337				return nil, err
338			}
339			parts = append(parts, part)
340		case toolCallType:
341			part := ToolCall{}
342			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
343				return nil, err
344			}
345			parts = append(parts, part)
346		case toolResultType:
347			part := ToolResult{}
348			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
349				return nil, err
350			}
351			parts = append(parts, part)
352		case finishType:
353			part := Finish{}
354			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
355				return nil, err
356			}
357			parts = append(parts, part)
358		default:
359			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
360		}
361	}
362
363	return parts, nil
364}