1package message
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7
8 "github.com/google/uuid"
9 "github.com/kujtimiihoxha/termai/internal/db"
10 "github.com/kujtimiihoxha/termai/internal/pubsub"
11)
12
// CreateMessageParams carries the caller-supplied fields for Service.Create.
type CreateMessageParams struct {
	// Role is stored on the new message. Create in this file appends a
	// Finish part for every role other than Assistant.
	Role MessageRole
	// Parts is the message content; it is serialized with marshallParts
	// before being written to the database.
	Parts []ContentPart
}
17
// Service is the message store API. The descriptions below reflect the
// implementation in this file (the unexported service type).
type Service interface {
	// The embedded subscriber exposes the Message events that the mutating
	// methods publish.
	pubsub.Suscriber[Message]
	// Create stores a new message in the given session and publishes a
	// CreatedEvent carrying the stored message.
	Create(sessionID string, params CreateMessageParams) (Message, error)
	// Update writes the message's parts back to storage and publishes an
	// UpdatedEvent.
	Update(message Message) error
	// Get loads a single message by ID.
	Get(id string) (Message, error)
	// List loads the messages belonging to a session.
	List(sessionID string) ([]Message, error)
	// Delete removes a single message and publishes a DeletedEvent.
	Delete(id string) error
	// DeleteSessionMessages deletes the messages of a session one at a time
	// through Delete.
	DeleteSessionMessages(sessionID string) error
}
27
// service is the Service implementation backed by a db.Querier.
type service struct {
	// The embedded broker supplies Publish, which the mutating methods call
	// after a successful database write.
	*pubsub.Broker[Message]
	// q executes every database query issued by this type.
	q db.Querier
	// ctx is captured once in NewService and passed to every query.
	// NOTE(review): storing a context in a struct means individual calls
	// cannot carry their own deadline or cancellation — confirm this is the
	// intended lifetime.
	ctx context.Context
}
33
34func NewService(ctx context.Context, q db.Querier) Service {
35 return &service{
36 Broker: pubsub.NewBroker[Message](),
37 q: q,
38 ctx: ctx,
39 }
40}
41
42func (s *service) Delete(id string) error {
43 message, err := s.Get(id)
44 if err != nil {
45 return err
46 }
47 err = s.q.DeleteMessage(s.ctx, message.ID)
48 if err != nil {
49 return err
50 }
51 s.Publish(pubsub.DeletedEvent, message)
52 return nil
53}
54
55func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
56 if params.Role != Assistant {
57 params.Parts = append(params.Parts, Finish{
58 Reason: "stop",
59 })
60 }
61 partsJSON, err := marshallParts(params.Parts)
62 if err != nil {
63 return Message{}, err
64 }
65
66 dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
67 ID: uuid.New().String(),
68 SessionID: sessionID,
69 Role: string(params.Role),
70 Parts: string(partsJSON),
71 })
72 if err != nil {
73 return Message{}, err
74 }
75 message, err := s.fromDBItem(dbMessage)
76 if err != nil {
77 return Message{}, err
78 }
79 s.Publish(pubsub.CreatedEvent, message)
80 return message, nil
81}
82
83func (s *service) DeleteSessionMessages(sessionID string) error {
84 messages, err := s.List(sessionID)
85 if err != nil {
86 return err
87 }
88 for _, message := range messages {
89 if message.SessionID == sessionID {
90 err = s.Delete(message.ID)
91 if err != nil {
92 return err
93 }
94 }
95 }
96 return nil
97}
98
99func (s *service) Update(message Message) error {
100 parts, err := marshallParts(message.Parts)
101 if err != nil {
102 return err
103 }
104 err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
105 ID: message.ID,
106 Parts: string(parts),
107 })
108 if err != nil {
109 return err
110 }
111 s.Publish(pubsub.UpdatedEvent, message)
112 return nil
113}
114
115func (s *service) Get(id string) (Message, error) {
116 dbMessage, err := s.q.GetMessage(s.ctx, id)
117 if err != nil {
118 return Message{}, err
119 }
120 return s.fromDBItem(dbMessage)
121}
122
123func (s *service) List(sessionID string) ([]Message, error) {
124 dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
125 if err != nil {
126 return nil, err
127 }
128 messages := make([]Message, len(dbMessages))
129 for i, dbMessage := range dbMessages {
130 messages[i], err = s.fromDBItem(dbMessage)
131 if err != nil {
132 return nil, err
133 }
134 }
135 return messages, nil
136}
137
138func (s *service) fromDBItem(item db.Message) (Message, error) {
139 parts, err := unmarshallParts([]byte(item.Parts))
140 if err != nil {
141 return Message{}, err
142 }
143 return Message{
144 ID: item.ID,
145 SessionID: item.SessionID,
146 Role: MessageRole(item.Role),
147 Parts: parts,
148 CreatedAt: item.CreatedAt,
149 UpdatedAt: item.UpdatedAt,
150 }, nil
151}
152
// partType is the discriminator written to the "type" field of each
// serialized content part.
type partType string

// One tag per concrete ContentPart handled by marshallParts and
// unmarshallParts. The string values are part of the stored JSON.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
164
// partWrapper is the JSON envelope marshallParts writes for each content
// part: a type tag next to the part's own encoding.
type partWrapper struct {
	// Type identifies the concrete type stored in Data.
	Type partType `json:"type"`
	// Data is the part itself, encoded under the "data" key.
	Data ContentPart `json:"data"`
}
169
170func marshallParts(parts []ContentPart) ([]byte, error) {
171 wrappedParts := make([]partWrapper, len(parts))
172
173 for i, part := range parts {
174 var typ partType
175
176 switch part.(type) {
177 case ReasoningContent:
178 typ = reasoningType
179 case TextContent:
180 typ = textType
181 case ImageURLContent:
182 typ = imageURLType
183 case BinaryContent:
184 typ = binaryType
185 case ToolCall:
186 typ = toolCallType
187 case ToolResult:
188 typ = toolResultType
189 case Finish:
190 typ = finishType
191 default:
192 return nil, fmt.Errorf("unknown part type: %T", part)
193 }
194
195 wrappedParts[i] = partWrapper{
196 Type: typ,
197 Data: part,
198 }
199 }
200 return json.Marshal(wrappedParts)
201}
202
203func unmarshallParts(data []byte) ([]ContentPart, error) {
204 temp := []json.RawMessage{}
205
206 if err := json.Unmarshal(data, &temp); err != nil {
207 return nil, err
208 }
209
210 parts := make([]ContentPart, 0)
211
212 for _, rawPart := range temp {
213 var wrapper struct {
214 Type partType `json:"type"`
215 Data json.RawMessage `json:"data"`
216 }
217
218 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
219 return nil, err
220 }
221
222 switch wrapper.Type {
223 case reasoningType:
224 part := ReasoningContent{}
225 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
226 return nil, err
227 }
228 parts = append(parts, part)
229 case textType:
230 part := TextContent{}
231 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
232 return nil, err
233 }
234 parts = append(parts, part)
235 case imageURLType:
236 part := ImageURLContent{}
237 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
238 return nil, err
239 }
240 case binaryType:
241 part := BinaryContent{}
242 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
243 return nil, err
244 }
245 parts = append(parts, part)
246 case toolCallType:
247 part := ToolCall{}
248 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
249 return nil, err
250 }
251 parts = append(parts, part)
252 case toolResultType:
253 part := ToolResult{}
254 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
255 return nil, err
256 }
257 parts = append(parts, part)
258 case finishType:
259 part := Finish{}
260 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
261 return nil, err
262 }
263 parts = append(parts, part)
264 default:
265 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
266 }
267
268 }
269
270 return parts, nil
271}