1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8
9 "github.com/google/uuid"
10 "github.com/kujtimiihoxha/opencode/internal/db"
11 "github.com/kujtimiihoxha/opencode/internal/llm/models"
12 "github.com/kujtimiihoxha/opencode/internal/pubsub"
13)
14
// CreateMessageParams holds the caller-supplied fields used by Service.Create.
//
// Role is stored as the message role, Parts is serialized with marshallParts
// into the message's parts column, and Model is stored in the model column.
type CreateMessageParams struct {
	Role  MessageRole
	Parts []ContentPart
	Model models.ModelID
}
20
21type Service interface {
22 pubsub.Suscriber[Message]
23 Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
24 Update(ctx context.Context, message Message) error
25 Get(ctx context.Context, id string) (Message, error)
26 List(ctx context.Context, sessionID string) ([]Message, error)
27 Delete(ctx context.Context, id string) error
28 DeleteSessionMessages(ctx context.Context, sessionID string) error
29}
30
// service is the Service implementation returned by NewService.
type service struct {
	// Embedded broker; its Publish method is used to emit Created/Updated/
	// Deleted events. NOTE(review): presumably it also provides the
	// Suscriber methods required by Service — confirm in the pubsub package.
	*pubsub.Broker[Message]
	// q performs all database reads and writes for messages.
	q db.Querier
}
35
36func NewService(q db.Querier) Service {
37 return &service{
38 Broker: pubsub.NewBroker[Message](),
39 q: q,
40 }
41}
42
43func (s *service) Delete(ctx context.Context, id string) error {
44 message, err := s.Get(ctx, id)
45 if err != nil {
46 return err
47 }
48 err = s.q.DeleteMessage(ctx, message.ID)
49 if err != nil {
50 return err
51 }
52 s.Publish(pubsub.DeletedEvent, message)
53 return nil
54}
55
56func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
57 if params.Role != Assistant {
58 params.Parts = append(params.Parts, Finish{
59 Reason: "stop",
60 })
61 }
62 partsJSON, err := marshallParts(params.Parts)
63 if err != nil {
64 return Message{}, err
65 }
66
67 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
68 ID: uuid.New().String(),
69 SessionID: sessionID,
70 Role: string(params.Role),
71 Parts: string(partsJSON),
72 Model: sql.NullString{String: string(params.Model), Valid: true},
73 })
74 if err != nil {
75 return Message{}, err
76 }
77 message, err := s.fromDBItem(dbMessage)
78 if err != nil {
79 return Message{}, err
80 }
81 s.Publish(pubsub.CreatedEvent, message)
82 return message, nil
83}
84
85func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
86 messages, err := s.List(ctx, sessionID)
87 if err != nil {
88 return err
89 }
90 for _, message := range messages {
91 if message.SessionID == sessionID {
92 err = s.Delete(ctx, message.ID)
93 if err != nil {
94 return err
95 }
96 }
97 }
98 return nil
99}
100
101func (s *service) Update(ctx context.Context, message Message) error {
102 parts, err := marshallParts(message.Parts)
103 if err != nil {
104 return err
105 }
106 finishedAt := sql.NullInt64{}
107 if f := message.FinishPart(); f != nil {
108 finishedAt.Int64 = f.Time
109 finishedAt.Valid = true
110 }
111 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
112 ID: message.ID,
113 Parts: string(parts),
114 FinishedAt: finishedAt,
115 })
116 if err != nil {
117 return err
118 }
119 s.Publish(pubsub.UpdatedEvent, message)
120 return nil
121}
122
123func (s *service) Get(ctx context.Context, id string) (Message, error) {
124 dbMessage, err := s.q.GetMessage(ctx, id)
125 if err != nil {
126 return Message{}, err
127 }
128 return s.fromDBItem(dbMessage)
129}
130
131func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
132 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
133 if err != nil {
134 return nil, err
135 }
136 messages := make([]Message, len(dbMessages))
137 for i, dbMessage := range dbMessages {
138 messages[i], err = s.fromDBItem(dbMessage)
139 if err != nil {
140 return nil, err
141 }
142 }
143 return messages, nil
144}
145
146func (s *service) fromDBItem(item db.Message) (Message, error) {
147 parts, err := unmarshallParts([]byte(item.Parts))
148 if err != nil {
149 return Message{}, err
150 }
151 return Message{
152 ID: item.ID,
153 SessionID: item.SessionID,
154 Role: MessageRole(item.Role),
155 Parts: parts,
156 Model: models.ModelID(item.Model.String),
157 CreatedAt: item.CreatedAt,
158 UpdatedAt: item.UpdatedAt,
159 }, nil
160}
161
// partType is the tag stored next to each serialized content part so that
// unmarshallParts can recover the part's concrete Go type.
type partType string

// Tags written by marshallParts and read by unmarshallParts. The encoded parts
// are stored in the message's parts column (see Create/Update), so changing
// these strings would make existing rows undecodable.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
173
// partWrapper is the JSON envelope marshallParts emits for each part:
// {"type": <tag>, "data": <the part itself>}.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
178
179func marshallParts(parts []ContentPart) ([]byte, error) {
180 wrappedParts := make([]partWrapper, len(parts))
181
182 for i, part := range parts {
183 var typ partType
184
185 switch part.(type) {
186 case ReasoningContent:
187 typ = reasoningType
188 case TextContent:
189 typ = textType
190 case ImageURLContent:
191 typ = imageURLType
192 case BinaryContent:
193 typ = binaryType
194 case ToolCall:
195 typ = toolCallType
196 case ToolResult:
197 typ = toolResultType
198 case Finish:
199 typ = finishType
200 default:
201 return nil, fmt.Errorf("unknown part type: %T", part)
202 }
203
204 wrappedParts[i] = partWrapper{
205 Type: typ,
206 Data: part,
207 }
208 }
209 return json.Marshal(wrappedParts)
210}
211
212func unmarshallParts(data []byte) ([]ContentPart, error) {
213 temp := []json.RawMessage{}
214
215 if err := json.Unmarshal(data, &temp); err != nil {
216 return nil, err
217 }
218
219 parts := make([]ContentPart, 0)
220
221 for _, rawPart := range temp {
222 var wrapper struct {
223 Type partType `json:"type"`
224 Data json.RawMessage `json:"data"`
225 }
226
227 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
228 return nil, err
229 }
230
231 switch wrapper.Type {
232 case reasoningType:
233 part := ReasoningContent{}
234 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
235 return nil, err
236 }
237 parts = append(parts, part)
238 case textType:
239 part := TextContent{}
240 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
241 return nil, err
242 }
243 parts = append(parts, part)
244 case imageURLType:
245 part := ImageURLContent{}
246 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
247 return nil, err
248 }
249 case binaryType:
250 part := BinaryContent{}
251 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
252 return nil, err
253 }
254 parts = append(parts, part)
255 case toolCallType:
256 part := ToolCall{}
257 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
258 return nil, err
259 }
260 parts = append(parts, part)
261 case toolResultType:
262 part := ToolResult{}
263 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
264 return nil, err
265 }
266 parts = append(parts, part)
267 case finishType:
268 part := Finish{}
269 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
270 return nil, err
271 }
272 parts = append(parts, part)
273 default:
274 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
275 }
276
277 }
278
279 return parts, nil
280}