1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8
9 "github.com/google/uuid"
10 "github.com/kujtimiihoxha/termai/internal/db"
11 "github.com/kujtimiihoxha/termai/internal/llm/models"
12 "github.com/kujtimiihoxha/termai/internal/pubsub"
13)
14
// CreateMessageParams holds the caller-supplied fields for a new message.
// The message ID and session are not part of it: Service.Create assigns a
// fresh ID and takes the session ID as a separate argument.
type CreateMessageParams struct {
	// Role of the message author.
	Role MessageRole
	// Parts is the ordered list of content parts to store with the message.
	Parts []ContentPart
	// Model identifies the model associated with the message.
	Model models.ModelID
}
20
// Service manages persisted chat messages. The implementation in this file
// publishes a pubsub event after each successful create, update and delete.
type Service interface {
	// Embedded subscriber side of the message event stream (the identifier's
	// spelling comes from the pubsub package).
	pubsub.Suscriber[Message]
	// Create stores a new message in the given session and returns it.
	Create(sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the message's parts and finish time.
	Update(message Message) error
	// Get fetches a single message by ID.
	Get(id string) (Message, error)
	// List returns the messages belonging to a session.
	List(sessionID string) ([]Message, error)
	// Delete removes a single message by ID.
	Delete(id string) error
	// DeleteSessionMessages removes every message belonging to a session.
	DeleteSessionMessages(sessionID string) error
}
30
// service is the db.Querier-backed implementation of Service.
type service struct {
	// Embedded broker; provides Publish and, presumably, the subscriber
	// methods Service requires — TODO confirm against the pubsub package.
	*pubsub.Broker[Message]
	// q executes all message queries.
	q db.Querier
	// ctx is captured once at construction and passed to every query.
	// NOTE(review): idiomatic Go passes a context per call instead of storing
	// it in a struct; fixing that requires adding ctx to the Service methods.
	ctx context.Context
}
36
37func NewService(ctx context.Context, q db.Querier) Service {
38 return &service{
39 Broker: pubsub.NewBroker[Message](),
40 q: q,
41 ctx: ctx,
42 }
43}
44
45func (s *service) Delete(id string) error {
46 message, err := s.Get(id)
47 if err != nil {
48 return err
49 }
50 err = s.q.DeleteMessage(s.ctx, message.ID)
51 if err != nil {
52 return err
53 }
54 s.Publish(pubsub.DeletedEvent, message)
55 return nil
56}
57
58func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
59 if params.Role != Assistant {
60 params.Parts = append(params.Parts, Finish{
61 Reason: "stop",
62 })
63 }
64 partsJSON, err := marshallParts(params.Parts)
65 if err != nil {
66 return Message{}, err
67 }
68
69 dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
70 ID: uuid.New().String(),
71 SessionID: sessionID,
72 Role: string(params.Role),
73 Parts: string(partsJSON),
74 Model: sql.NullString{String: string(params.Model), Valid: true},
75 })
76 if err != nil {
77 return Message{}, err
78 }
79 message, err := s.fromDBItem(dbMessage)
80 if err != nil {
81 return Message{}, err
82 }
83 s.Publish(pubsub.CreatedEvent, message)
84 return message, nil
85}
86
87func (s *service) DeleteSessionMessages(sessionID string) error {
88 messages, err := s.List(sessionID)
89 if err != nil {
90 return err
91 }
92 for _, message := range messages {
93 if message.SessionID == sessionID {
94 err = s.Delete(message.ID)
95 if err != nil {
96 return err
97 }
98 }
99 }
100 return nil
101}
102
103func (s *service) Update(message Message) error {
104 parts, err := marshallParts(message.Parts)
105 if err != nil {
106 return err
107 }
108 finishedAt := sql.NullInt64{}
109 if f := message.FinishPart(); f != nil {
110 finishedAt.Int64 = f.Time
111 finishedAt.Valid = true
112 }
113 err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
114 ID: message.ID,
115 Parts: string(parts),
116 FinishedAt: finishedAt,
117 })
118 if err != nil {
119 return err
120 }
121 s.Publish(pubsub.UpdatedEvent, message)
122 return nil
123}
124
125func (s *service) Get(id string) (Message, error) {
126 dbMessage, err := s.q.GetMessage(s.ctx, id)
127 if err != nil {
128 return Message{}, err
129 }
130 return s.fromDBItem(dbMessage)
131}
132
133func (s *service) List(sessionID string) ([]Message, error) {
134 dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
135 if err != nil {
136 return nil, err
137 }
138 messages := make([]Message, len(dbMessages))
139 for i, dbMessage := range dbMessages {
140 messages[i], err = s.fromDBItem(dbMessage)
141 if err != nil {
142 return nil, err
143 }
144 }
145 return messages, nil
146}
147
148func (s *service) fromDBItem(item db.Message) (Message, error) {
149 parts, err := unmarshallParts([]byte(item.Parts))
150 if err != nil {
151 return Message{}, err
152 }
153 return Message{
154 ID: item.ID,
155 SessionID: item.SessionID,
156 Role: MessageRole(item.Role),
157 Parts: parts,
158 Model: models.ModelID(item.Model.String),
159 CreatedAt: item.CreatedAt,
160 UpdatedAt: item.UpdatedAt,
161 }, nil
162}
163
// partType is the discriminator written next to each serialized content part,
// so that unmarshallParts can decode the part back into its concrete type.
type partType string

// Discriminator values written into the stored JSON. They are persisted data:
// renaming one would make previously stored messages fail to decode.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
175
// partWrapper is the JSON envelope stored for one content part, shaped as
// {"type": <partType>, "data": <part>}.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
180
181func marshallParts(parts []ContentPart) ([]byte, error) {
182 wrappedParts := make([]partWrapper, len(parts))
183
184 for i, part := range parts {
185 var typ partType
186
187 switch part.(type) {
188 case ReasoningContent:
189 typ = reasoningType
190 case TextContent:
191 typ = textType
192 case ImageURLContent:
193 typ = imageURLType
194 case BinaryContent:
195 typ = binaryType
196 case ToolCall:
197 typ = toolCallType
198 case ToolResult:
199 typ = toolResultType
200 case Finish:
201 typ = finishType
202 default:
203 return nil, fmt.Errorf("unknown part type: %T", part)
204 }
205
206 wrappedParts[i] = partWrapper{
207 Type: typ,
208 Data: part,
209 }
210 }
211 return json.Marshal(wrappedParts)
212}
213
214func unmarshallParts(data []byte) ([]ContentPart, error) {
215 temp := []json.RawMessage{}
216
217 if err := json.Unmarshal(data, &temp); err != nil {
218 return nil, err
219 }
220
221 parts := make([]ContentPart, 0)
222
223 for _, rawPart := range temp {
224 var wrapper struct {
225 Type partType `json:"type"`
226 Data json.RawMessage `json:"data"`
227 }
228
229 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
230 return nil, err
231 }
232
233 switch wrapper.Type {
234 case reasoningType:
235 part := ReasoningContent{}
236 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
237 return nil, err
238 }
239 parts = append(parts, part)
240 case textType:
241 part := TextContent{}
242 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
243 return nil, err
244 }
245 parts = append(parts, part)
246 case imageURLType:
247 part := ImageURLContent{}
248 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
249 return nil, err
250 }
251 case binaryType:
252 part := BinaryContent{}
253 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
254 return nil, err
255 }
256 parts = append(parts, part)
257 case toolCallType:
258 part := ToolCall{}
259 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
260 return nil, err
261 }
262 parts = append(parts, part)
263 case toolResultType:
264 part := ToolResult{}
265 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
266 return nil, err
267 }
268 parts = append(parts, part)
269 case finishType:
270 part := Finish{}
271 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
272 return nil, err
273 }
274 parts = append(parts, part)
275 default:
276 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
277 }
278
279 }
280
281 return parts, nil
282}