message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/llm/models"
 12	"github.com/charmbracelet/crush/internal/pubsub"
 13	"github.com/google/uuid"
 14)
 15
// CreateMessageParams holds the caller-supplied fields for Service.Create.
type CreateMessageParams struct {
	// Role of the message author. Create appends a Finish part (reason
	// "stop") for every role other than Assistant.
	Role  MessageRole
	// Parts is the ordered message content; it is serialized with
	// marshallParts before being stored.
	Parts []ContentPart
	// Model is written to the message's model column (always as a non-NULL
	// value, even when empty).
	Model models.ModelID
}
 21
// Service manages chat messages stored per session and lets callers
// subscribe to message events. The implementation in this package publishes
// a created, updated or deleted event after each successful mutation.
type Service interface {
	// NOTE(review): "Suscriber" is the spelling used by the pubsub package.
	pubsub.Suscriber[Message]
	// Create stores a new message in sessionID and returns the stored message.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update rewrites the stored parts (and finish time, when the message has
	// a finish part) of an existing message identified by message.ID.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given id.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages belonging to sessionID.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes the message with the given id.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message in sessionID.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
 31
// service is the Querier-backed implementation of Service.
type service struct {
	// The embedded broker supplies Publish, which the mutating methods use to
	// emit message events; presumably it also provides the subscription side
	// required by pubsub.Suscriber[Message] — confirm in the pubsub package.
	*pubsub.Broker[Message]
	// q is the query layer used for all persistence.
	q db.Querier
}
 36
 37func NewService(q db.Querier) Service {
 38	return &service{
 39		Broker: pubsub.NewBroker[Message](),
 40		q:      q,
 41	}
 42}
 43
 44func (s *service) Delete(ctx context.Context, id string) error {
 45	message, err := s.Get(ctx, id)
 46	if err != nil {
 47		return err
 48	}
 49	err = s.q.DeleteMessage(ctx, message.ID)
 50	if err != nil {
 51		return err
 52	}
 53	s.Publish(pubsub.DeletedEvent, message)
 54	return nil
 55}
 56
 57func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 58	if params.Role != Assistant {
 59		params.Parts = append(params.Parts, Finish{
 60			Reason: "stop",
 61		})
 62	}
 63	partsJSON, err := marshallParts(params.Parts)
 64	if err != nil {
 65		return Message{}, err
 66	}
 67	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 68		ID:        uuid.New().String(),
 69		SessionID: sessionID,
 70		Role:      string(params.Role),
 71		Parts:     string(partsJSON),
 72		Model:     sql.NullString{String: string(params.Model), Valid: true},
 73	})
 74	if err != nil {
 75		return Message{}, err
 76	}
 77	message, err := s.fromDBItem(dbMessage)
 78	if err != nil {
 79		return Message{}, err
 80	}
 81	s.Publish(pubsub.CreatedEvent, message)
 82	return message, nil
 83}
 84
 85func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 86	messages, err := s.List(ctx, sessionID)
 87	if err != nil {
 88		return err
 89	}
 90	for _, message := range messages {
 91		if message.SessionID == sessionID {
 92			err = s.Delete(ctx, message.ID)
 93			if err != nil {
 94				return err
 95			}
 96		}
 97	}
 98	return nil
 99}
100
101func (s *service) Update(ctx context.Context, message Message) error {
102	parts, err := marshallParts(message.Parts)
103	if err != nil {
104		return err
105	}
106	finishedAt := sql.NullInt64{}
107	if f := message.FinishPart(); f != nil {
108		finishedAt.Int64 = f.Time
109		finishedAt.Valid = true
110	}
111	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
112		ID:         message.ID,
113		Parts:      string(parts),
114		FinishedAt: finishedAt,
115	})
116	if err != nil {
117		return err
118	}
119	message.UpdatedAt = time.Now().Unix()
120	s.Publish(pubsub.UpdatedEvent, message)
121	return nil
122}
123
124func (s *service) Get(ctx context.Context, id string) (Message, error) {
125	dbMessage, err := s.q.GetMessage(ctx, id)
126	if err != nil {
127		return Message{}, err
128	}
129	return s.fromDBItem(dbMessage)
130}
131
132func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
133	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
134	if err != nil {
135		return nil, err
136	}
137	messages := make([]Message, len(dbMessages))
138	for i, dbMessage := range dbMessages {
139		messages[i], err = s.fromDBItem(dbMessage)
140		if err != nil {
141			return nil, err
142		}
143	}
144	return messages, nil
145}
146
147func (s *service) fromDBItem(item db.Message) (Message, error) {
148	parts, err := unmarshallParts([]byte(item.Parts))
149	if err != nil {
150		return Message{}, err
151	}
152	return Message{
153		ID:        item.ID,
154		SessionID: item.SessionID,
155		Role:      MessageRole(item.Role),
156		Parts:     parts,
157		Model:     models.ModelID(item.Model.String),
158		CreatedAt: item.CreatedAt,
159		UpdatedAt: item.UpdatedAt,
160	}, nil
161}
162
// partType is the discriminator written next to each serialized ContentPart
// so that unmarshallParts can rebuild the matching concrete type.
type partType string

// Wire tags for each ContentPart kind. These strings are persisted inside the
// stored parts JSON, so renaming one would make previously stored parts fail
// to decode with an "unknown part type" error.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
174
// partWrapper is the JSON envelope, {"type": ..., "data": ...}, that
// marshallParts emits for each part. encoding/json serializes Data using the
// interface value's concrete dynamic type.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
179
180func marshallParts(parts []ContentPart) ([]byte, error) {
181	wrappedParts := make([]partWrapper, len(parts))
182
183	for i, part := range parts {
184		var typ partType
185
186		switch part.(type) {
187		case ReasoningContent:
188			typ = reasoningType
189		case TextContent:
190			typ = textType
191		case ImageURLContent:
192			typ = imageURLType
193		case BinaryContent:
194			typ = binaryType
195		case ToolCall:
196			typ = toolCallType
197		case ToolResult:
198			typ = toolResultType
199		case Finish:
200			typ = finishType
201		default:
202			return nil, fmt.Errorf("unknown part type: %T", part)
203		}
204
205		wrappedParts[i] = partWrapper{
206			Type: typ,
207			Data: part,
208		}
209	}
210	return json.Marshal(wrappedParts)
211}
212
213func unmarshallParts(data []byte) ([]ContentPart, error) {
214	temp := []json.RawMessage{}
215
216	if err := json.Unmarshal(data, &temp); err != nil {
217		return nil, err
218	}
219
220	parts := make([]ContentPart, 0)
221
222	for _, rawPart := range temp {
223		var wrapper struct {
224			Type partType        `json:"type"`
225			Data json.RawMessage `json:"data"`
226		}
227
228		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
229			return nil, err
230		}
231
232		switch wrapper.Type {
233		case reasoningType:
234			part := ReasoningContent{}
235			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
236				return nil, err
237			}
238			parts = append(parts, part)
239		case textType:
240			part := TextContent{}
241			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
242				return nil, err
243			}
244			parts = append(parts, part)
245		case imageURLType:
246			part := ImageURLContent{}
247			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
248				return nil, err
249			}
250		case binaryType:
251			part := BinaryContent{}
252			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
253				return nil, err
254			}
255			parts = append(parts, part)
256		case toolCallType:
257			part := ToolCall{}
258			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
259				return nil, err
260			}
261			parts = append(parts, part)
262		case toolResultType:
263			part := ToolResult{}
264			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
265				return nil, err
266			}
267			parts = append(parts, part)
268		case finishType:
269			part := Finish{}
270			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
271				return nil, err
272			}
273			parts = append(parts, part)
274		default:
275			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
276		}
277	}
278
279	return parts, nil
280}