message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
 15type CreateMessageParams struct {
 16	Role             MessageRole
 17	Parts            []ContentPart
 18	Model            string
 19	Provider         string
 20	IsSummaryMessage bool
 21}
 22
 23type Service interface {
 24	pubsub.Subscriber[Message]
 25	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
 26	Update(ctx context.Context, message Message) error
 27	Get(ctx context.Context, id string) (Message, error)
 28	List(ctx context.Context, sessionID string) ([]Message, error)
 29	Delete(ctx context.Context, id string) error
 30	DeleteSessionMessages(ctx context.Context, sessionID string) error
 31}
 32
// service is the default Service implementation. Messages are persisted
// through q, and change events are sent via the embedded broker, which
// provides Publish (used by the mutating methods) and presumably the
// pubsub.Subscriber[Message] methods required by Service.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
}
 37
 38func NewService(q db.Querier) Service {
 39	return &service{
 40		Broker: pubsub.NewBroker[Message](),
 41		q:      q,
 42	}
 43}
 44
 45func (s *service) Delete(ctx context.Context, id string) error {
 46	message, err := s.Get(ctx, id)
 47	if err != nil {
 48		return err
 49	}
 50	err = s.q.DeleteMessage(ctx, message.ID)
 51	if err != nil {
 52		return err
 53	}
 54	// Clone the message before publishing to avoid race conditions with
 55	// concurrent modifications to the Parts slice.
 56	s.Publish(pubsub.DeletedEvent, message.Clone())
 57	return nil
 58}
 59
 60func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 61	if params.Role != Assistant {
 62		params.Parts = append(params.Parts, Finish{
 63			Reason: "stop",
 64		})
 65	}
 66	partsJSON, err := marshallParts(params.Parts)
 67	if err != nil {
 68		return Message{}, err
 69	}
 70	isSummary := int64(0)
 71	if params.IsSummaryMessage {
 72		isSummary = 1
 73	}
 74	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 75		ID:               uuid.New().String(),
 76		SessionID:        sessionID,
 77		Role:             string(params.Role),
 78		Parts:            string(partsJSON),
 79		Model:            sql.NullString{String: string(params.Model), Valid: true},
 80		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 81		IsSummaryMessage: isSummary,
 82	})
 83	if err != nil {
 84		return Message{}, err
 85	}
 86	message, err := s.fromDBItem(dbMessage)
 87	if err != nil {
 88		return Message{}, err
 89	}
 90	// Clone the message before publishing to avoid race conditions with
 91	// concurrent modifications to the Parts slice.
 92	s.Publish(pubsub.CreatedEvent, message.Clone())
 93	return message, nil
 94}
 95
 96func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 97	messages, err := s.List(ctx, sessionID)
 98	if err != nil {
 99		return err
100	}
101	for _, message := range messages {
102		if message.SessionID == sessionID {
103			err = s.Delete(ctx, message.ID)
104			if err != nil {
105				return err
106			}
107		}
108	}
109	return nil
110}
111
112func (s *service) Update(ctx context.Context, message Message) error {
113	parts, err := marshallParts(message.Parts)
114	if err != nil {
115		return err
116	}
117	finishedAt := sql.NullInt64{}
118	if f := message.FinishPart(); f != nil {
119		finishedAt.Int64 = f.Time
120		finishedAt.Valid = true
121	}
122	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
123		ID:         message.ID,
124		Parts:      string(parts),
125		FinishedAt: finishedAt,
126	})
127	if err != nil {
128		return err
129	}
130	message.UpdatedAt = time.Now().Unix()
131	// Clone the message before publishing to avoid race conditions with
132	// concurrent modifications to the Parts slice.
133	s.Publish(pubsub.UpdatedEvent, message.Clone())
134	return nil
135}
136
137func (s *service) Get(ctx context.Context, id string) (Message, error) {
138	dbMessage, err := s.q.GetMessage(ctx, id)
139	if err != nil {
140		return Message{}, err
141	}
142	return s.fromDBItem(dbMessage)
143}
144
145func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
146	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
147	if err != nil {
148		return nil, err
149	}
150	messages := make([]Message, len(dbMessages))
151	for i, dbMessage := range dbMessages {
152		messages[i], err = s.fromDBItem(dbMessage)
153		if err != nil {
154			return nil, err
155		}
156	}
157	return messages, nil
158}
159
160func (s *service) fromDBItem(item db.Message) (Message, error) {
161	parts, err := unmarshallParts([]byte(item.Parts))
162	if err != nil {
163		return Message{}, err
164	}
165	return Message{
166		ID:               item.ID,
167		SessionID:        item.SessionID,
168		Role:             MessageRole(item.Role),
169		Parts:            parts,
170		Model:            item.Model.String,
171		Provider:         item.Provider.String,
172		CreatedAt:        item.CreatedAt,
173		UpdatedAt:        item.UpdatedAt,
174		IsSummaryMessage: item.IsSummaryMessage != 0,
175	}, nil
176}
177
// partType is the tag stored next to each serialized ContentPart so that
// unmarshallParts knows which concrete type to decode the payload into.
type partType string

// Tag values for each supported ContentPart implementation. They are written
// into the stored parts JSON and read back when decoding, so existing values
// must stay stable.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
189
// partWrapper is the serialized envelope for one ContentPart: a type tag plus
// the part itself. encoding/json emits fields in declaration order, so the
// stored object lists "type" before "data".
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
194
195func marshallParts(parts []ContentPart) ([]byte, error) {
196	wrappedParts := make([]partWrapper, len(parts))
197
198	for i, part := range parts {
199		var typ partType
200
201		switch part.(type) {
202		case ReasoningContent:
203			typ = reasoningType
204		case TextContent:
205			typ = textType
206		case ImageURLContent:
207			typ = imageURLType
208		case BinaryContent:
209			typ = binaryType
210		case ToolCall:
211			typ = toolCallType
212		case ToolResult:
213			typ = toolResultType
214		case Finish:
215			typ = finishType
216		default:
217			return nil, fmt.Errorf("unknown part type: %T", part)
218		}
219
220		wrappedParts[i] = partWrapper{
221			Type: typ,
222			Data: part,
223		}
224	}
225	return json.Marshal(wrappedParts)
226}
227
228func unmarshallParts(data []byte) ([]ContentPart, error) {
229	temp := []json.RawMessage{}
230
231	if err := json.Unmarshal(data, &temp); err != nil {
232		return nil, err
233	}
234
235	parts := make([]ContentPart, 0)
236
237	for _, rawPart := range temp {
238		var wrapper struct {
239			Type partType        `json:"type"`
240			Data json.RawMessage `json:"data"`
241		}
242
243		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
244			return nil, err
245		}
246
247		switch wrapper.Type {
248		case reasoningType:
249			part := ReasoningContent{}
250			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
251				return nil, err
252			}
253			parts = append(parts, part)
254		case textType:
255			part := TextContent{}
256			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
257				return nil, err
258			}
259			parts = append(parts, part)
260		case imageURLType:
261			part := ImageURLContent{}
262			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
263				return nil, err
264			}
265			parts = append(parts, part)
266		case binaryType:
267			part := BinaryContent{}
268			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
269				return nil, err
270			}
271			parts = append(parts, part)
272		case toolCallType:
273			part := ToolCall{}
274			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
275				return nil, err
276			}
277			parts = append(parts, part)
278		case toolResultType:
279			part := ToolResult{}
280			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
281				return nil, err
282			}
283			parts = append(parts, part)
284		case finishType:
285			part := Finish{}
286			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
287				return nil, err
288			}
289			parts = append(parts, part)
290		default:
291			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
292		}
293	}
294
295	return parts, nil
296}