1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/db"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
// CreateMessageParams holds the caller-supplied fields for Service.Create.
//
// The message ID is not part of it (Create generates one), and the session ID
// is passed to Create as a separate argument. When Role is not Assistant,
// Create appends a Finish part with reason "stop" to Parts before storing.
// An empty Provider is stored as NULL; Model is always stored as a non-NULL
// string. IsSummaryMessage is stored as the integer 1 or 0.
type CreateMessageParams struct {
	Role             MessageRole
	Parts            []ContentPart
	Model            string
	Provider         string
	IsSummaryMessage bool
}
 22
// Service manages chat messages. It persists them through the database layer
// and, via the embedded pubsub.Suscriber, lets callers subscribe to the
// created/updated/deleted events the implementation publishes.
type Service interface {
	pubsub.Suscriber[Message]
	// Create stores a new message in the given session and returns it as
	// read back from the database.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the message's parts and finish time.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages stored for a session.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes the message with the given ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message of a session.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
 32
// service is the database-backed Service implementation returned by
// NewService.
type service struct {
	// Broker is used to Publish message events; presumably it also supplies
	// the pubsub.Suscriber half of Service (no Subscribe method is visible
	// here) — confirm against the pubsub package.
	*pubsub.Broker[Message]
	// q is the query layer used for every read and write.
	q db.Querier
}
 37
 38func NewService(q db.Querier) Service {
 39	return &service{
 40		Broker: pubsub.NewBroker[Message](),
 41		q:      q,
 42	}
 43}
 44
 45func (s *service) Delete(ctx context.Context, id string) error {
 46	message, err := s.Get(ctx, id)
 47	if err != nil {
 48		return err
 49	}
 50	err = s.q.DeleteMessage(ctx, message.ID)
 51	if err != nil {
 52		return err
 53	}
 54	s.Publish(pubsub.DeletedEvent, message)
 55	return nil
 56}
 57
 58func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 59	if params.Role != Assistant {
 60		params.Parts = append(params.Parts, Finish{
 61			Reason: "stop",
 62		})
 63	}
 64	partsJSON, err := marshallParts(params.Parts)
 65	if err != nil {
 66		return Message{}, err
 67	}
 68	isSummary := int64(0)
 69	if params.IsSummaryMessage {
 70		isSummary = 1
 71	}
 72	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 73		ID:               uuid.New().String(),
 74		SessionID:        sessionID,
 75		Role:             string(params.Role),
 76		Parts:            string(partsJSON),
 77		Model:            sql.NullString{String: string(params.Model), Valid: true},
 78		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 79		IsSummaryMessage: isSummary,
 80	})
 81	if err != nil {
 82		return Message{}, err
 83	}
 84	message, err := s.fromDBItem(dbMessage)
 85	if err != nil {
 86		return Message{}, err
 87	}
 88	s.Publish(pubsub.CreatedEvent, message)
 89	return message, nil
 90}
 91
 92func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 93	messages, err := s.List(ctx, sessionID)
 94	if err != nil {
 95		return err
 96	}
 97	for _, message := range messages {
 98		if message.SessionID == sessionID {
 99			err = s.Delete(ctx, message.ID)
100			if err != nil {
101				return err
102			}
103		}
104	}
105	return nil
106}
107
108func (s *service) Update(ctx context.Context, message Message) error {
109	parts, err := marshallParts(message.Parts)
110	if err != nil {
111		return err
112	}
113	finishedAt := sql.NullInt64{}
114	if f := message.FinishPart(); f != nil {
115		finishedAt.Int64 = f.Time
116		finishedAt.Valid = true
117	}
118	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
119		ID:         message.ID,
120		Parts:      string(parts),
121		FinishedAt: finishedAt,
122	})
123	if err != nil {
124		return err
125	}
126	message.UpdatedAt = time.Now().Unix()
127	s.Publish(pubsub.UpdatedEvent, message)
128	return nil
129}
130
131func (s *service) Get(ctx context.Context, id string) (Message, error) {
132	dbMessage, err := s.q.GetMessage(ctx, id)
133	if err != nil {
134		return Message{}, err
135	}
136	return s.fromDBItem(dbMessage)
137}
138
139func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
140	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
141	if err != nil {
142		return nil, err
143	}
144	messages := make([]Message, len(dbMessages))
145	for i, dbMessage := range dbMessages {
146		messages[i], err = s.fromDBItem(dbMessage)
147		if err != nil {
148			return nil, err
149		}
150	}
151	return messages, nil
152}
153
154func (s *service) fromDBItem(item db.Message) (Message, error) {
155	parts, err := unmarshallParts([]byte(item.Parts))
156	if err != nil {
157		return Message{}, err
158	}
159	return Message{
160		ID:               item.ID,
161		SessionID:        item.SessionID,
162		Role:             MessageRole(item.Role),
163		Parts:            parts,
164		Model:            item.Model.String,
165		Provider:         item.Provider.String,
166		CreatedAt:        item.CreatedAt,
167		UpdatedAt:        item.UpdatedAt,
168		IsSummaryMessage: item.IsSummaryMessage != 0,
169	}, nil
170}
171
// partType is the tag stored next to each serialized content part. It tells
// unmarshallParts which concrete Go type to decode the part's data into.
type partType string

// Tag values for each content part type. marshallParts writes them into the
// parts JSON, which Create and Update store in the database. Renaming any of
// them would make existing rows fail to decode.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
183
// partWrapper is the JSON envelope that marshallParts writes for one content
// part: {"type": <partType>, "data": <part>}. Decoding does not use it;
// unmarshallParts reads the envelope into a local struct whose data field is
// a json.RawMessage.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
188
189func marshallParts(parts []ContentPart) ([]byte, error) {
190	wrappedParts := make([]partWrapper, len(parts))
191
192	for i, part := range parts {
193		var typ partType
194
195		switch part.(type) {
196		case ReasoningContent:
197			typ = reasoningType
198		case TextContent:
199			typ = textType
200		case ImageURLContent:
201			typ = imageURLType
202		case BinaryContent:
203			typ = binaryType
204		case ToolCall:
205			typ = toolCallType
206		case ToolResult:
207			typ = toolResultType
208		case Finish:
209			typ = finishType
210		default:
211			return nil, fmt.Errorf("unknown part type: %T", part)
212		}
213
214		wrappedParts[i] = partWrapper{
215			Type: typ,
216			Data: part,
217		}
218	}
219	return json.Marshal(wrappedParts)
220}
221
222func unmarshallParts(data []byte) ([]ContentPart, error) {
223	temp := []json.RawMessage{}
224
225	if err := json.Unmarshal(data, &temp); err != nil {
226		return nil, err
227	}
228
229	parts := make([]ContentPart, 0)
230
231	for _, rawPart := range temp {
232		var wrapper struct {
233			Type partType        `json:"type"`
234			Data json.RawMessage `json:"data"`
235		}
236
237		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
238			return nil, err
239		}
240
241		switch wrapper.Type {
242		case reasoningType:
243			part := ReasoningContent{}
244			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
245				return nil, err
246			}
247			parts = append(parts, part)
248		case textType:
249			part := TextContent{}
250			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
251				return nil, err
252			}
253			parts = append(parts, part)
254		case imageURLType:
255			part := ImageURLContent{}
256			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
257				return nil, err
258			}
259		case binaryType:
260			part := BinaryContent{}
261			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
262				return nil, err
263			}
264			parts = append(parts, part)
265		case toolCallType:
266			part := ToolCall{}
267			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
268				return nil, err
269			}
270			parts = append(parts, part)
271		case toolResultType:
272			part := ToolResult{}
273			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
274				return nil, err
275			}
276			parts = append(parts, part)
277		case finishType:
278			part := Finish{}
279			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
280				return nil, err
281			}
282			parts = append(parts, part)
283		default:
284			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
285		}
286	}
287
288	return parts, nil
289}