message.go

  1package message
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7
  8	"github.com/google/uuid"
  9	"github.com/kujtimiihoxha/termai/internal/db"
 10	"github.com/kujtimiihoxha/termai/internal/pubsub"
 11)
 12
// CreateMessageParams holds the caller-supplied fields for Service.Create.
// It carries neither a message ID nor a session ID: Create generates the ID
// itself and receives the session ID as a separate argument.
type CreateMessageParams struct {
	// Role is stored as the message role. Any role other than Assistant
	// causes Create to append a Finish part before persisting.
	Role MessageRole
	// Parts is the ordered message content to serialize with marshallParts.
	Parts []ContentPart
}
 17
// Service manages persisted chat messages. The implementation in this file
// publishes a pubsub event after each successful create, update and delete.
type Service interface {
	// Message events are delivered through the pubsub subscriber interface
	// (the identifier is spelled "Suscriber" in the pubsub package).
	pubsub.Suscriber[Message]
	// Create stores a new message in sessionID and returns it as stored.
	// For any role other than Assistant, a Finish part with reason "stop"
	// is appended to the parts before they are saved.
	Create(sessionID string, params CreateMessageParams) (Message, error)
	// Update writes message.Parts back to the stored message with the same ID.
	Update(message Message) error
	// Get loads a single message by its ID.
	Get(id string) (Message, error)
	// List returns the messages stored for sessionID.
	List(sessionID string) ([]Message, error)
	// Delete removes a single message by its ID.
	Delete(id string) error
	// DeleteSessionMessages removes every message stored for sessionID.
	DeleteSessionMessages(sessionID string) error
}
 27
// service is the database-backed Service implementation returned by NewService.
// The embedded Broker provides Publish; since this type declares no subscribe
// methods of its own, the Suscriber half of Service is expected to come from
// the Broker as well.
type service struct {
	*pubsub.Broker[Message]
	// q executes the generated SQL queries.
	q db.Querier
	// ctx is captured once in NewService and passed to every query.
	// NOTE(review): storing a Context in a struct is discouraged in Go; a
	// per-call ctx parameter would allow cancellation, but would change Service.
	ctx context.Context
}
 33
 34func NewService(ctx context.Context, q db.Querier) Service {
 35	return &service{
 36		Broker: pubsub.NewBroker[Message](),
 37		q:      q,
 38		ctx:    ctx,
 39	}
 40}
 41
 42func (s *service) Delete(id string) error {
 43	message, err := s.Get(id)
 44	if err != nil {
 45		return err
 46	}
 47	err = s.q.DeleteMessage(s.ctx, message.ID)
 48	if err != nil {
 49		return err
 50	}
 51	s.Publish(pubsub.DeletedEvent, message)
 52	return nil
 53}
 54
 55func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
 56	if params.Role != Assistant {
 57		params.Parts = append(params.Parts, Finish{
 58			Reason: "stop",
 59		})
 60	}
 61	partsJSON, err := marshallParts(params.Parts)
 62	if err != nil {
 63		return Message{}, err
 64	}
 65
 66	dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
 67		ID:        uuid.New().String(),
 68		SessionID: sessionID,
 69		Role:      string(params.Role),
 70		Parts:     string(partsJSON),
 71	})
 72	if err != nil {
 73		return Message{}, err
 74	}
 75	message, err := s.fromDBItem(dbMessage)
 76	if err != nil {
 77		return Message{}, err
 78	}
 79	s.Publish(pubsub.CreatedEvent, message)
 80	return message, nil
 81}
 82
 83func (s *service) DeleteSessionMessages(sessionID string) error {
 84	messages, err := s.List(sessionID)
 85	if err != nil {
 86		return err
 87	}
 88	for _, message := range messages {
 89		if message.SessionID == sessionID {
 90			err = s.Delete(message.ID)
 91			if err != nil {
 92				return err
 93			}
 94		}
 95	}
 96	return nil
 97}
 98
 99func (s *service) Update(message Message) error {
100	parts, err := marshallParts(message.Parts)
101	if err != nil {
102		return err
103	}
104	err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
105		ID:    message.ID,
106		Parts: string(parts),
107	})
108	if err != nil {
109		return err
110	}
111	s.Publish(pubsub.UpdatedEvent, message)
112	return nil
113}
114
115func (s *service) Get(id string) (Message, error) {
116	dbMessage, err := s.q.GetMessage(s.ctx, id)
117	if err != nil {
118		return Message{}, err
119	}
120	return s.fromDBItem(dbMessage)
121}
122
123func (s *service) List(sessionID string) ([]Message, error) {
124	dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
125	if err != nil {
126		return nil, err
127	}
128	messages := make([]Message, len(dbMessages))
129	for i, dbMessage := range dbMessages {
130		messages[i], err = s.fromDBItem(dbMessage)
131		if err != nil {
132			return nil, err
133		}
134	}
135	return messages, nil
136}
137
138func (s *service) fromDBItem(item db.Message) (Message, error) {
139	parts, err := unmarshallParts([]byte(item.Parts))
140	if err != nil {
141		return Message{}, err
142	}
143	return Message{
144		ID:        item.ID,
145		SessionID: item.SessionID,
146		Role:      MessageRole(item.Role),
147		Parts:     parts,
148		CreatedAt: item.CreatedAt,
149		UpdatedAt: item.UpdatedAt,
150	}, nil
151}
152
// partType is the discriminator written to the "type" field of each
// serialized content part. marshallParts sets it and unmarshallParts
// switches on it.
type partType string

// Serialized type tags, one per concrete ContentPart type. These strings end
// up in the database's Parts column, so renaming one would make previously
// stored messages fail to decode ("unknown part type").
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
164
// partWrapper is the JSON envelope marshallParts emits for one content part:
// {"type": <partType>, "data": <part>}. It is only used for encoding.
// Decoding uses a separate struct with a json.RawMessage Data field, because
// an interface-typed field cannot tell encoding/json which concrete part type
// to build.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
169
170func marshallParts(parts []ContentPart) ([]byte, error) {
171	wrappedParts := make([]partWrapper, len(parts))
172
173	for i, part := range parts {
174		var typ partType
175
176		switch part.(type) {
177		case ReasoningContent:
178			typ = reasoningType
179		case TextContent:
180			typ = textType
181		case ImageURLContent:
182			typ = imageURLType
183		case BinaryContent:
184			typ = binaryType
185		case ToolCall:
186			typ = toolCallType
187		case ToolResult:
188			typ = toolResultType
189		case Finish:
190			typ = finishType
191		default:
192			return nil, fmt.Errorf("unknown part type: %T", part)
193		}
194
195		wrappedParts[i] = partWrapper{
196			Type: typ,
197			Data: part,
198		}
199	}
200	return json.Marshal(wrappedParts)
201}
202
203func unmarshallParts(data []byte) ([]ContentPart, error) {
204	temp := []json.RawMessage{}
205
206	if err := json.Unmarshal(data, &temp); err != nil {
207		return nil, err
208	}
209
210	parts := make([]ContentPart, 0)
211
212	for _, rawPart := range temp {
213		var wrapper struct {
214			Type partType        `json:"type"`
215			Data json.RawMessage `json:"data"`
216		}
217
218		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
219			return nil, err
220		}
221
222		switch wrapper.Type {
223		case reasoningType:
224			part := ReasoningContent{}
225			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
226				return nil, err
227			}
228			parts = append(parts, part)
229		case textType:
230			part := TextContent{}
231			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
232				return nil, err
233			}
234			parts = append(parts, part)
235		case imageURLType:
236			part := ImageURLContent{}
237			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
238				return nil, err
239			}
240		case binaryType:
241			part := BinaryContent{}
242			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
243				return nil, err
244			}
245			parts = append(parts, part)
246		case toolCallType:
247			part := ToolCall{}
248			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
249				return nil, err
250			}
251			parts = append(parts, part)
252		case toolResultType:
253			part := ToolResult{}
254			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
255				return nil, err
256			}
257			parts = append(parts, part)
258		case finishType:
259			part := Finish{}
260			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
261				return nil, err
262			}
263			parts = append(parts, part)
264		default:
265			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
266		}
267
268	}
269
270	return parts, nil
271}