message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
// CreateMessageParams carries the caller-supplied fields for a new message.
// The message ID is generated by Create, and the session ID is passed to
// Create as a separate argument.
type CreateMessageParams struct {
	Role             MessageRole
	// Parts is the message content. For any role other than Assistant,
	// Create appends a Finish part (reason "stop") before storing it.
	Parts            []ContentPart
	Model            string
	// Provider is stored as NULL when empty.
	Provider         string
	IsSummaryMessage bool
	// HookOutputs is stored JSON-encoded alongside the message.
	HookOutputs      []HookOutput
}
 23
// Service stores chat messages and lets callers subscribe to changes.
//
// The implementation returned by NewService publishes a pubsub event after
// each successful Create, Update and Delete; DeleteSessionMessages publishes
// one DeletedEvent per message it removes.
type Service interface {
	pubsub.Suscriber[Message]
	// Create stores a new message in sessionID and returns it as persisted.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update writes the message's parts, finish time and hook outputs.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages belonging to sessionID.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes a single message by ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message in sessionID.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
 33
// service is the database-backed Service implementation. The embedded broker
// supplies the subscriber methods required by Service as well as the Publish
// calls the mutating methods make; q performs the actual SQL queries.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
}
 38
 39func NewService(q db.Querier) Service {
 40	return &service{
 41		Broker: pubsub.NewBroker[Message](),
 42		q:      q,
 43	}
 44}
 45
 46func (s *service) Delete(ctx context.Context, id string) error {
 47	message, err := s.Get(ctx, id)
 48	if err != nil {
 49		return err
 50	}
 51	err = s.q.DeleteMessage(ctx, message.ID)
 52	if err != nil {
 53		return err
 54	}
 55	s.Publish(pubsub.DeletedEvent, message)
 56	return nil
 57}
 58
 59func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 60	if params.Role != Assistant {
 61		params.Parts = append(params.Parts, Finish{
 62			Reason: "stop",
 63		})
 64	}
 65	partsJSON, err := marshallParts(params.Parts)
 66	if err != nil {
 67		return Message{}, err
 68	}
 69	isSummary := int64(0)
 70	if params.IsSummaryMessage {
 71		isSummary = 1
 72	}
 73	hookOutputsJSON, err := json.Marshal(params.HookOutputs)
 74	if err != nil {
 75		return Message{}, err
 76	}
 77	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 78		ID:               uuid.New().String(),
 79		SessionID:        sessionID,
 80		Role:             string(params.Role),
 81		Parts:            string(partsJSON),
 82		Model:            sql.NullString{String: string(params.Model), Valid: true},
 83		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 84		IsSummaryMessage: isSummary,
 85		HookOutputs:      string(hookOutputsJSON),
 86	})
 87	if err != nil {
 88		return Message{}, err
 89	}
 90	message, err := s.fromDBItem(dbMessage)
 91	if err != nil {
 92		return Message{}, err
 93	}
 94	s.Publish(pubsub.CreatedEvent, message)
 95	return message, nil
 96}
 97
 98func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 99	messages, err := s.List(ctx, sessionID)
100	if err != nil {
101		return err
102	}
103	for _, message := range messages {
104		if message.SessionID == sessionID {
105			err = s.Delete(ctx, message.ID)
106			if err != nil {
107				return err
108			}
109		}
110	}
111	return nil
112}
113
114func (s *service) Update(ctx context.Context, message Message) error {
115	parts, err := marshallParts(message.Parts)
116	if err != nil {
117		return err
118	}
119	finishedAt := sql.NullInt64{}
120	if f := message.FinishPart(); f != nil {
121		finishedAt.Int64 = f.Time
122		finishedAt.Valid = true
123	}
124	hookOutputsJSON, err := json.Marshal(message.HookOutputs)
125	if err != nil {
126		return err
127	}
128	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
129		ID:          message.ID,
130		Parts:       string(parts),
131		FinishedAt:  finishedAt,
132		HookOutputs: string(hookOutputsJSON),
133	})
134	if err != nil {
135		return err
136	}
137	message.UpdatedAt = time.Now().Unix()
138	s.Publish(pubsub.UpdatedEvent, message)
139	return nil
140}
141
142func (s *service) Get(ctx context.Context, id string) (Message, error) {
143	dbMessage, err := s.q.GetMessage(ctx, id)
144	if err != nil {
145		return Message{}, err
146	}
147	return s.fromDBItem(dbMessage)
148}
149
150func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
151	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
152	if err != nil {
153		return nil, err
154	}
155	messages := make([]Message, len(dbMessages))
156	for i, dbMessage := range dbMessages {
157		messages[i], err = s.fromDBItem(dbMessage)
158		if err != nil {
159			return nil, err
160		}
161	}
162	return messages, nil
163}
164
165func (s *service) fromDBItem(item db.Message) (Message, error) {
166	parts, err := unmarshallParts([]byte(item.Parts))
167	if err != nil {
168		return Message{}, err
169	}
170	var hookOutputs []HookOutput
171	if item.HookOutputs != "" {
172		if err := json.Unmarshal([]byte(item.HookOutputs), &hookOutputs); err != nil {
173			return Message{}, err
174		}
175	}
176	return Message{
177		ID:               item.ID,
178		SessionID:        item.SessionID,
179		Role:             MessageRole(item.Role),
180		Parts:            parts,
181		Model:            item.Model.String,
182		Provider:         item.Provider.String,
183		CreatedAt:        item.CreatedAt,
184		UpdatedAt:        item.UpdatedAt,
185		IsSummaryMessage: item.IsSummaryMessage != 0,
186		HookOutputs:      hookOutputs,
187	}, nil
188}
189
// partType is the tag written next to each serialized ContentPart so that
// unmarshallParts knows which concrete type to decode. These strings end up in
// the stored parts JSON, so existing values must not be changed.
type partType string

// Tags for each supported ContentPart implementation.
const (
	textType       partType = "text"
	reasoningType  partType = "reasoning"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
201
// partWrapper is the serialized envelope for one ContentPart: Type names the
// concrete part kind so that unmarshallParts can decode Data back into the
// matching type.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
206
207func marshallParts(parts []ContentPart) ([]byte, error) {
208	wrappedParts := make([]partWrapper, len(parts))
209
210	for i, part := range parts {
211		var typ partType
212
213		switch part.(type) {
214		case ReasoningContent:
215			typ = reasoningType
216		case TextContent:
217			typ = textType
218		case ImageURLContent:
219			typ = imageURLType
220		case BinaryContent:
221			typ = binaryType
222		case ToolCall:
223			typ = toolCallType
224		case ToolResult:
225			typ = toolResultType
226		case Finish:
227			typ = finishType
228		default:
229			return nil, fmt.Errorf("unknown part type: %T", part)
230		}
231
232		wrappedParts[i] = partWrapper{
233			Type: typ,
234			Data: part,
235		}
236	}
237	return json.Marshal(wrappedParts)
238}
239
240func unmarshallParts(data []byte) ([]ContentPart, error) {
241	temp := []json.RawMessage{}
242
243	if err := json.Unmarshal(data, &temp); err != nil {
244		return nil, err
245	}
246
247	parts := make([]ContentPart, 0)
248
249	for _, rawPart := range temp {
250		var wrapper struct {
251			Type partType        `json:"type"`
252			Data json.RawMessage `json:"data"`
253		}
254
255		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
256			return nil, err
257		}
258
259		switch wrapper.Type {
260		case reasoningType:
261			part := ReasoningContent{}
262			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
263				return nil, err
264			}
265			parts = append(parts, part)
266		case textType:
267			part := TextContent{}
268			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
269				return nil, err
270			}
271			parts = append(parts, part)
272		case imageURLType:
273			part := ImageURLContent{}
274			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
275				return nil, err
276			}
277			parts = append(parts, part)
278		case binaryType:
279			part := BinaryContent{}
280			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
281				return nil, err
282			}
283			parts = append(parts, part)
284		case toolCallType:
285			part := ToolCall{}
286			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
287				return nil, err
288			}
289			parts = append(parts, part)
290		case toolResultType:
291			part := ToolResult{}
292			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
293				return nil, err
294			}
295			parts = append(parts, part)
296		case finishType:
297			part := Finish{}
298			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
299				return nil, err
300			}
301			parts = append(parts, part)
302		default:
303			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
304		}
305	}
306
307	return parts, nil
308}