message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
// CreateMessageParams carries the caller-supplied fields for a new message
// passed to Service.Create.
type CreateMessageParams struct {
	// Role is the author role; Create stores it as its string form.
	Role             MessageRole
	// Parts is the message content; Create serializes it to JSON for storage.
	Parts            []ContentPart
	// Model is stored as a non-NULL value even when empty.
	Model            string
	// Provider is stored as NULL when empty.
	Provider         string
	// IsSummaryMessage is stored as 1 when true and 0 when false.
	IsSummaryMessage bool
}
 22
// Service is the message store API. It embeds pubsub.Suscriber[Message] and
// adds operations to create, read, update and delete messages. The
// implementation in this file publishes created, updated and deleted events
// through an embedded pubsub.Broker.
type Service interface {
	pubsub.Suscriber[Message]
	// Create stores a new message in the session identified by sessionID.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the parts and finish timestamp of an existing message.
	Update(ctx context.Context, message Message) error
	// Get loads a single message by its ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages belonging to the given session.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// FullList returns all messages, regardless of session.
	FullList(ctx context.Context) ([]Message, error)
	// Delete removes a single message by its ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message of the given session.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
 33
// service is the Service implementation returned by NewService.
type service struct {
	// Broker is embedded so its methods (including Publish, used to announce
	// message events) are promoted onto service.
	*pubsub.Broker[Message]
	// q runs the persistence queries.
	q db.Querier
}
 38
 39func NewService(q db.Querier) Service {
 40	return &service{
 41		Broker: pubsub.NewBroker[Message](),
 42		q:      q,
 43	}
 44}
 45
 46func (s *service) Delete(ctx context.Context, id string) error {
 47	message, err := s.Get(ctx, id)
 48	if err != nil {
 49		return err
 50	}
 51	err = s.q.DeleteMessage(ctx, message.ID)
 52	if err != nil {
 53		return err
 54	}
 55	s.Publish(pubsub.DeletedEvent, message)
 56	return nil
 57}
 58
 59func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 60	if params.Role != Assistant {
 61		params.Parts = append(params.Parts, Finish{
 62			Reason: "stop",
 63		})
 64	}
 65	partsJSON, err := marshallParts(params.Parts)
 66	if err != nil {
 67		return Message{}, err
 68	}
 69	isSummary := int64(0)
 70	if params.IsSummaryMessage {
 71		isSummary = 1
 72	}
 73	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 74		ID:               uuid.New().String(),
 75		SessionID:        sessionID,
 76		Role:             string(params.Role),
 77		Parts:            string(partsJSON),
 78		Model:            sql.NullString{String: string(params.Model), Valid: true},
 79		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 80		IsSummaryMessage: isSummary,
 81	})
 82	if err != nil {
 83		return Message{}, err
 84	}
 85	message, err := fromDBItem(dbMessage)
 86	if err != nil {
 87		return Message{}, err
 88	}
 89	s.Publish(pubsub.CreatedEvent, message)
 90	return message, nil
 91}
 92
 93func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 94	messages, err := s.List(ctx, sessionID)
 95	if err != nil {
 96		return err
 97	}
 98	for _, message := range messages {
 99		if message.SessionID == sessionID {
100			err = s.Delete(ctx, message.ID)
101			if err != nil {
102				return err
103			}
104		}
105	}
106	return nil
107}
108
109func (s *service) Update(ctx context.Context, message Message) error {
110	parts, err := marshallParts(message.Parts)
111	if err != nil {
112		return err
113	}
114	finishedAt := sql.NullInt64{}
115	if f := message.FinishPart(); f != nil {
116		finishedAt.Int64 = f.Time
117		finishedAt.Valid = true
118	}
119	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
120		ID:         message.ID,
121		Parts:      string(parts),
122		FinishedAt: finishedAt,
123	})
124	if err != nil {
125		return err
126	}
127	message.UpdatedAt = time.Now().Unix()
128	s.Publish(pubsub.UpdatedEvent, message)
129	return nil
130}
131
132func (s *service) Get(ctx context.Context, id string) (Message, error) {
133	dbMessage, err := s.q.GetMessage(ctx, id)
134	if err != nil {
135		return Message{}, err
136	}
137	return fromDBItem(dbMessage)
138}
139
140func convertDBMessagesToMessages(dbMessages []db.Message, err error) ([]Message, error) {
141	if err != nil {
142		return nil, err
143	}
144	messages := make([]Message, len(dbMessages))
145	for i, dbMessage := range dbMessages {
146		messages[i], err = fromDBItem(dbMessage)
147		if err != nil {
148			return nil, err
149		}
150	}
151	return messages, nil
152}
153
154func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
155	return convertDBMessagesToMessages(s.q.ListMessagesBySession(ctx, sessionID))
156}
157
158func (s *service) FullList(ctx context.Context) ([]Message, error) {
159	return convertDBMessagesToMessages(s.q.ListAllMessages(ctx))
160}
161
162func fromDBItem(item db.Message) (Message, error) {
163	parts, err := unmarshallParts([]byte(item.Parts))
164	if err != nil {
165		return Message{}, err
166	}
167	return Message{
168		ID:               item.ID,
169		SessionID:        item.SessionID,
170		Role:             MessageRole(item.Role),
171		Parts:            parts,
172		Model:            item.Model.String,
173		Provider:         item.Provider.String,
174		CreatedAt:        item.CreatedAt,
175		UpdatedAt:        item.UpdatedAt,
176		IsSummaryMessage: item.IsSummaryMessage != 0,
177	}, nil
178}
179
// partType is the discriminator written to the "type" field of each
// serialized content part and used to pick the concrete type when decoding.
type partType string
181
// Discriminator values, one per concrete ContentPart type handled by
// marshallParts and unmarshallParts. These strings are part of the stored
// JSON, so changing them would make existing rows undecodable.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
191
// partWrapper is the JSON envelope marshallParts writes for each part: the
// type tag alongside the part itself.
type partWrapper struct {
	// Type identifies which concrete part Data holds.
	Type partType    `json:"type"`
	// Data is the part value, encoded under the "data" key.
	Data ContentPart `json:"data"`
}
196
197func marshallParts(parts []ContentPart) ([]byte, error) {
198	wrappedParts := make([]partWrapper, len(parts))
199
200	for i, part := range parts {
201		var typ partType
202
203		switch part.(type) {
204		case ReasoningContent:
205			typ = reasoningType
206		case TextContent:
207			typ = textType
208		case ImageURLContent:
209			typ = imageURLType
210		case BinaryContent:
211			typ = binaryType
212		case ToolCall:
213			typ = toolCallType
214		case ToolResult:
215			typ = toolResultType
216		case Finish:
217			typ = finishType
218		default:
219			return nil, fmt.Errorf("unknown part type: %T", part)
220		}
221
222		wrappedParts[i] = partWrapper{
223			Type: typ,
224			Data: part,
225		}
226	}
227	return json.Marshal(wrappedParts)
228}
229
230func unmarshallParts(data []byte) ([]ContentPart, error) {
231	temp := []json.RawMessage{}
232
233	if err := json.Unmarshal(data, &temp); err != nil {
234		return nil, err
235	}
236
237	parts := make([]ContentPart, 0)
238
239	for _, rawPart := range temp {
240		var wrapper struct {
241			Type partType        `json:"type"`
242			Data json.RawMessage `json:"data"`
243		}
244
245		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
246			return nil, err
247		}
248
249		switch wrapper.Type {
250		case reasoningType:
251			part := ReasoningContent{}
252			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
253				return nil, err
254			}
255			parts = append(parts, part)
256		case textType:
257			part := TextContent{}
258			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
259				return nil, err
260			}
261			parts = append(parts, part)
262		case imageURLType:
263			part := ImageURLContent{}
264			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
265				return nil, err
266			}
267			parts = append(parts, part)
268		case binaryType:
269			part := BinaryContent{}
270			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
271				return nil, err
272			}
273			parts = append(parts, part)
274		case toolCallType:
275			part := ToolCall{}
276			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
277				return nil, err
278			}
279			parts = append(parts, part)
280		case toolResultType:
281			part := ToolResult{}
282			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
283				return nil, err
284			}
285			parts = append(parts, part)
286		case finishType:
287			part := Finish{}
288			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
289				return nil, err
290			}
291			parts = append(parts, part)
292		default:
293			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
294		}
295	}
296
297	return parts, nil
298}