1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "time"
9
10 "github.com/google/uuid"
11 "github.com/opencode-ai/opencode/internal/db"
12 "github.com/opencode-ai/opencode/internal/llm/models"
13 "github.com/opencode-ai/opencode/internal/pubsub"
14)
15
// CreateMessageParams holds the caller-supplied fields for a new message.
// The message ID is generated by Service.Create, and the remaining fields of
// the stored Message (such as timestamps) come back from the database layer.
type CreateMessageParams struct {
	// Role is the author of the message. Create appends a "stop" Finish part
	// to every message whose Role is not Assistant.
	Role MessageRole
	// Parts is the ordered message content; it is serialized to JSON on create.
	Parts []ContentPart
	// Model identifies the model associated with the message.
	Model models.ModelID
}
21
// Service stores chat messages and lets callers subscribe to Message change
// events through the embedded pubsub.Suscriber.
type Service interface {
	pubsub.Suscriber[Message]
	// Create inserts a new message into the session identified by sessionID.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the parts (and finish time) of an existing message.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given id.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages belonging to sessionID.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes the message with the given id.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message belonging to sessionID.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
31
// service is the db.Querier-backed implementation of Service.
type service struct {
	// Broker is used to publish Created/Updated/Deleted message events; being
	// embedded, its methods (presumably including the pubsub.Suscriber ones)
	// are promoted onto service.
	*pubsub.Broker[Message]
	// q runs the database queries for messages.
	q db.Querier
}
36
37func NewService(q db.Querier) Service {
38 return &service{
39 Broker: pubsub.NewBroker[Message](),
40 q: q,
41 }
42}
43
44func (s *service) Delete(ctx context.Context, id string) error {
45 message, err := s.Get(ctx, id)
46 if err != nil {
47 return err
48 }
49 err = s.q.DeleteMessage(ctx, message.ID)
50 if err != nil {
51 return err
52 }
53 s.Publish(pubsub.DeletedEvent, message)
54 return nil
55}
56
57func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
58 if params.Role != Assistant {
59 params.Parts = append(params.Parts, Finish{
60 Reason: "stop",
61 })
62 }
63 partsJSON, err := marshallParts(params.Parts)
64 if err != nil {
65 return Message{}, err
66 }
67
68 dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
69 ID: uuid.New().String(),
70 SessionID: sessionID,
71 Role: string(params.Role),
72 Parts: string(partsJSON),
73 Model: sql.NullString{String: string(params.Model), Valid: true},
74 })
75 if err != nil {
76 return Message{}, err
77 }
78 message, err := s.fromDBItem(dbMessage)
79 if err != nil {
80 return Message{}, err
81 }
82 s.Publish(pubsub.CreatedEvent, message)
83 return message, nil
84}
85
86func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
87 messages, err := s.List(ctx, sessionID)
88 if err != nil {
89 return err
90 }
91 for _, message := range messages {
92 if message.SessionID == sessionID {
93 err = s.Delete(ctx, message.ID)
94 if err != nil {
95 return err
96 }
97 }
98 }
99 return nil
100}
101
102func (s *service) Update(ctx context.Context, message Message) error {
103 parts, err := marshallParts(message.Parts)
104 if err != nil {
105 return err
106 }
107 finishedAt := sql.NullInt64{}
108 if f := message.FinishPart(); f != nil {
109 finishedAt.Int64 = f.Time
110 finishedAt.Valid = true
111 }
112 err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
113 ID: message.ID,
114 Parts: string(parts),
115 FinishedAt: finishedAt,
116 })
117 if err != nil {
118 return err
119 }
120 message.UpdatedAt = time.Now().Unix()
121 s.Publish(pubsub.UpdatedEvent, message)
122 return nil
123}
124
125func (s *service) Get(ctx context.Context, id string) (Message, error) {
126 dbMessage, err := s.q.GetMessage(ctx, id)
127 if err != nil {
128 return Message{}, err
129 }
130 return s.fromDBItem(dbMessage)
131}
132
133func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
134 dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
135 if err != nil {
136 return nil, err
137 }
138 messages := make([]Message, len(dbMessages))
139 for i, dbMessage := range dbMessages {
140 messages[i], err = s.fromDBItem(dbMessage)
141 if err != nil {
142 return nil, err
143 }
144 }
145 return messages, nil
146}
147
148func (s *service) fromDBItem(item db.Message) (Message, error) {
149 parts, err := unmarshallParts([]byte(item.Parts))
150 if err != nil {
151 return Message{}, err
152 }
153 return Message{
154 ID: item.ID,
155 SessionID: item.SessionID,
156 Role: MessageRole(item.Role),
157 Parts: parts,
158 Model: models.ModelID(item.Model.String),
159 CreatedAt: item.CreatedAt,
160 UpdatedAt: item.UpdatedAt,
161 }, nil
162}
163
// partType is the discriminator written to the "type" field of each
// serialized content part; on decode it selects the concrete ContentPart
// type to unmarshal the "data" field into.
type partType string

// Discriminator values, one per supported ContentPart implementation. These
// strings end up in the stored parts JSON, and unknown values fail to decode,
// so renaming one breaks reading previously stored messages.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
175
// partWrapper is the JSON envelope used to serialize one ContentPart:
// {"type": <partType>, "data": <the part itself>}.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
180
181func marshallParts(parts []ContentPart) ([]byte, error) {
182 wrappedParts := make([]partWrapper, len(parts))
183
184 for i, part := range parts {
185 var typ partType
186
187 switch part.(type) {
188 case ReasoningContent:
189 typ = reasoningType
190 case TextContent:
191 typ = textType
192 case ImageURLContent:
193 typ = imageURLType
194 case BinaryContent:
195 typ = binaryType
196 case ToolCall:
197 typ = toolCallType
198 case ToolResult:
199 typ = toolResultType
200 case Finish:
201 typ = finishType
202 default:
203 return nil, fmt.Errorf("unknown part type: %T", part)
204 }
205
206 wrappedParts[i] = partWrapper{
207 Type: typ,
208 Data: part,
209 }
210 }
211 return json.Marshal(wrappedParts)
212}
213
214func unmarshallParts(data []byte) ([]ContentPart, error) {
215 temp := []json.RawMessage{}
216
217 if err := json.Unmarshal(data, &temp); err != nil {
218 return nil, err
219 }
220
221 parts := make([]ContentPart, 0)
222
223 for _, rawPart := range temp {
224 var wrapper struct {
225 Type partType `json:"type"`
226 Data json.RawMessage `json:"data"`
227 }
228
229 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
230 return nil, err
231 }
232
233 switch wrapper.Type {
234 case reasoningType:
235 part := ReasoningContent{}
236 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
237 return nil, err
238 }
239 parts = append(parts, part)
240 case textType:
241 part := TextContent{}
242 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
243 return nil, err
244 }
245 parts = append(parts, part)
246 case imageURLType:
247 part := ImageURLContent{}
248 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
249 return nil, err
250 }
251 case binaryType:
252 part := BinaryContent{}
253 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
254 return nil, err
255 }
256 parts = append(parts, part)
257 case toolCallType:
258 part := ToolCall{}
259 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
260 return nil, err
261 }
262 parts = append(parts, part)
263 case toolResultType:
264 part := ToolResult{}
265 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
266 return nil, err
267 }
268 parts = append(parts, part)
269 case finishType:
270 part := Finish{}
271 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
272 return nil, err
273 }
274 parts = append(parts, part)
275 default:
276 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
277 }
278
279 }
280
281 return parts, nil
282}