message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
 15type CreateMessageParams struct {
 16	Role     MessageRole
 17	Parts    []ContentPart
 18	Model    string
 19	Provider string
 20}
 21
 22type Service interface {
 23	pubsub.Suscriber[Message]
 24	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
 25	Update(ctx context.Context, message Message) error
 26	Get(ctx context.Context, id string) (Message, error)
 27	List(ctx context.Context, sessionID string) ([]Message, error)
 28	Delete(ctx context.Context, id string) error
 29	DeleteSessionMessages(ctx context.Context, sessionID string) error
 30}
 31
 32type service struct {
 33	*pubsub.Broker[Message]
 34	q db.Querier
 35}
 36
 37func NewService(q db.Querier) Service {
 38	return &service{
 39		Broker: pubsub.NewBroker[Message](),
 40		q:      q,
 41	}
 42}
 43
 44func (s *service) Delete(ctx context.Context, id string) error {
 45	message, err := s.Get(ctx, id)
 46	if err != nil {
 47		return err
 48	}
 49	err = s.q.DeleteMessage(ctx, message.ID)
 50	if err != nil {
 51		return err
 52	}
 53	s.Publish(pubsub.DeletedEvent, message)
 54	return nil
 55}
 56
 57func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 58	if params.Role != Assistant {
 59		params.Parts = append(params.Parts, Finish{
 60			Reason: "stop",
 61		})
 62	}
 63	partsJSON, err := marshallParts(params.Parts)
 64	if err != nil {
 65		return Message{}, err
 66	}
 67	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 68		ID:        uuid.New().String(),
 69		SessionID: sessionID,
 70		Role:      string(params.Role),
 71		Parts:     string(partsJSON),
 72		Model:     sql.NullString{String: string(params.Model), Valid: true},
 73		Provider:  sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 74	})
 75	if err != nil {
 76		return Message{}, err
 77	}
 78	message, err := s.fromDBItem(dbMessage)
 79	if err != nil {
 80		return Message{}, err
 81	}
 82	s.Publish(pubsub.CreatedEvent, message)
 83	return message, nil
 84}
 85
 86func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 87	messages, err := s.List(ctx, sessionID)
 88	if err != nil {
 89		return err
 90	}
 91	for _, message := range messages {
 92		if message.SessionID == sessionID {
 93			err = s.Delete(ctx, message.ID)
 94			if err != nil {
 95				return err
 96			}
 97		}
 98	}
 99	return nil
100}
101
102func (s *service) Update(ctx context.Context, message Message) error {
103	parts, err := marshallParts(message.Parts)
104	if err != nil {
105		return err
106	}
107	finishedAt := sql.NullInt64{}
108	if f := message.FinishPart(); f != nil {
109		finishedAt.Int64 = f.Time
110		finishedAt.Valid = true
111	}
112	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
113		ID:         message.ID,
114		Parts:      string(parts),
115		FinishedAt: finishedAt,
116	})
117	if err != nil {
118		return err
119	}
120	message.UpdatedAt = time.Now().Unix()
121	s.Publish(pubsub.UpdatedEvent, message)
122	return nil
123}
124
125func (s *service) Get(ctx context.Context, id string) (Message, error) {
126	dbMessage, err := s.q.GetMessage(ctx, id)
127	if err != nil {
128		return Message{}, err
129	}
130	return s.fromDBItem(dbMessage)
131}
132
133func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
134	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
135	if err != nil {
136		return nil, err
137	}
138	messages := make([]Message, len(dbMessages))
139	for i, dbMessage := range dbMessages {
140		messages[i], err = s.fromDBItem(dbMessage)
141		if err != nil {
142			return nil, err
143		}
144	}
145	return messages, nil
146}
147
148func (s *service) fromDBItem(item db.Message) (Message, error) {
149	parts, err := unmarshallParts([]byte(item.Parts))
150	if err != nil {
151		return Message{}, err
152	}
153	return Message{
154		ID:        item.ID,
155		SessionID: item.SessionID,
156		Role:      MessageRole(item.Role),
157		Parts:     parts,
158		Model:     item.Model.String,
159		Provider:  item.Provider.String,
160		CreatedAt: item.CreatedAt,
161		UpdatedAt: item.UpdatedAt,
162	}, nil
163}
164
// partType is the discriminator string written next to each serialized
// content part so that unmarshallParts can choose the concrete type to decode
// into.
type partType string

// Discriminator values, one per concrete content part handled by
// marshallParts and unmarshallParts. They end up in the stored JSON, and
// unmarshallParts rejects any value not listed here.
const (
	reasoningType  = partType("reasoning")
	textType       = partType("text")
	imageURLType   = partType("image_url")
	binaryType     = partType("binary")
	toolCallType   = partType("tool_call")
	toolResultType = partType("tool_result")
	finishType     = partType("finish")
	retryType      = partType("retry")
)
177
178type partWrapper struct {
179	Type partType    `json:"type"`
180	Data ContentPart `json:"data"`
181}
182
183func marshallParts(parts []ContentPart) ([]byte, error) {
184	wrappedParts := make([]partWrapper, len(parts))
185
186	for i, part := range parts {
187		var typ partType
188
189		switch part.(type) {
190		case ReasoningContent:
191			typ = reasoningType
192		case TextContent:
193			typ = textType
194		case ImageURLContent:
195			typ = imageURLType
196		case BinaryContent:
197			typ = binaryType
198		case ToolCall:
199			typ = toolCallType
200		case ToolResult:
201			typ = toolResultType
202		case Finish:
203			typ = finishType
204		case RetryContent:
205			typ = retryType
206		default:
207			return nil, fmt.Errorf("unknown part type: %T", part)
208		}
209
210		wrappedParts[i] = partWrapper{
211			Type: typ,
212			Data: part,
213		}
214	}
215	return json.Marshal(wrappedParts)
216}
217
218func unmarshallParts(data []byte) ([]ContentPart, error) {
219	temp := []json.RawMessage{}
220
221	if err := json.Unmarshal(data, &temp); err != nil {
222		return nil, err
223	}
224
225	parts := make([]ContentPart, 0)
226
227	for _, rawPart := range temp {
228		var wrapper struct {
229			Type partType        `json:"type"`
230			Data json.RawMessage `json:"data"`
231		}
232
233		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
234			return nil, err
235		}
236
237		switch wrapper.Type {
238		case reasoningType:
239			part := ReasoningContent{}
240			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
241				return nil, err
242			}
243			parts = append(parts, part)
244		case textType:
245			part := TextContent{}
246			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
247				return nil, err
248			}
249			parts = append(parts, part)
250		case imageURLType:
251			part := ImageURLContent{}
252			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
253				return nil, err
254			}
255		case binaryType:
256			part := BinaryContent{}
257			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
258				return nil, err
259			}
260			parts = append(parts, part)
261		case toolCallType:
262			part := ToolCall{}
263			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
264				return nil, err
265			}
266			parts = append(parts, part)
267		case toolResultType:
268			part := ToolResult{}
269			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
270				return nil, err
271			}
272			parts = append(parts, part)
273		case finishType:
274			part := Finish{}
275			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
276				return nil, err
277			}
278			parts = append(parts, part)
279		case retryType:
280			part := RetryContent{}
281			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
282				return nil, err
283			}
284			parts = append(parts, part)
285		default:
286			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
287		}
288	}
289
290	return parts, nil
291}