1package message
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7
8 "github.com/google/uuid"
9 "github.com/kujtimiihoxha/termai/internal/db"
10 "github.com/kujtimiihoxha/termai/internal/pubsub"
11)
12
// MessageRole names the author of a message. It is a plain string type, so
// values convert to and from string without loss.
type MessageRole string

// Role values defined by this package.
const (
	Assistant = MessageRole("assistant")
	User      = MessageRole("user")
	System    = MessageRole("system")
	Tool      = MessageRole("tool")
)
21
// ToolResult is the outcome of one tool invocation.
//
// ToolCallID identifies the call this result belongs to (presumably the ID of
// a ToolCall — confirm against callers), Content is the textual output, and
// IsError flags a failed invocation.
type ToolResult struct {
	ToolCallID, Content string
	IsError             bool
	// TODO: support for images
}
28
// ToolCall describes a single tool invocation attached to a message.
//
// All four fields are plain strings: the call's ID, the tool's Name, the raw
// Input passed to it, and a Type tag whose allowed values are not defined in
// this file.
type ToolCall struct {
	ID, Name, Input, Type string
}
35
36type Message struct {
37 ID string
38 SessionID string
39
40 // NEW
41 Role MessageRole
42 Content string
43 Thinking string
44
45 Finished bool
46
47 ToolResults []ToolResult
48 ToolCalls []ToolCall
49 CreatedAt int64
50 UpdatedAt int64
51}
52
53type CreateMessageParams struct {
54 Role MessageRole
55 Content string
56 ToolCalls []ToolCall
57 ToolResults []ToolResult
58}
59
60type Service interface {
61 pubsub.Suscriber[Message]
62 Create(sessionID string, params CreateMessageParams) (Message, error)
63 Update(message Message) error
64 Get(id string) (Message, error)
65 List(sessionID string) ([]Message, error)
66 Delete(id string) error
67 DeleteSessionMessages(sessionID string) error
68}
69
// service is the database-backed implementation of Service.
type service struct {
	// Broker is embedded so its methods (including Publish, used by the
	// mutating methods) are promoted onto service.
	*pubsub.Broker[Message]
	// q runs the message queries.
	q db.Querier
	// ctx is captured once at construction and passed to every query.
	// NOTE(review): Go convention is to pass a context per call rather than
	// store it; changing that would alter the Service interface.
	ctx context.Context
}
75
76func (s *service) Delete(id string) error {
77 message, err := s.Get(id)
78 if err != nil {
79 return err
80 }
81 err = s.q.DeleteMessage(s.ctx, message.ID)
82 if err != nil {
83 return err
84 }
85 s.Publish(pubsub.DeletedEvent, message)
86 return nil
87}
88
89func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
90 toolCallsStr, err := json.Marshal(params.ToolCalls)
91 if err != nil {
92 return Message{}, err
93 }
94 toolResultsStr, err := json.Marshal(params.ToolResults)
95 if err != nil {
96 return Message{}, err
97 }
98 dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
99 ID: uuid.New().String(),
100 SessionID: sessionID,
101 Role: string(params.Role),
102 Finished: params.Role != Assistant,
103 Content: params.Content,
104 ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
105 ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
106 })
107 if err != nil {
108 return Message{}, err
109 }
110 message, err := s.fromDBItem(dbMessage)
111 if err != nil {
112 return Message{}, err
113 }
114 s.Publish(pubsub.CreatedEvent, message)
115 return message, nil
116}
117
118func (s *service) DeleteSessionMessages(sessionID string) error {
119 messages, err := s.List(sessionID)
120 if err != nil {
121 return err
122 }
123 for _, message := range messages {
124 if message.SessionID == sessionID {
125 err = s.Delete(message.ID)
126 if err != nil {
127 return err
128 }
129 }
130 }
131 return nil
132}
133
134func (s *service) Update(message Message) error {
135 toolCallsStr, err := json.Marshal(message.ToolCalls)
136 if err != nil {
137 return err
138 }
139 toolResultsStr, err := json.Marshal(message.ToolResults)
140 if err != nil {
141 return err
142 }
143 err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
144 ID: message.ID,
145 Content: message.Content,
146 Thinking: message.Thinking,
147 Finished: message.Finished,
148 ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
149 ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
150 })
151 if err != nil {
152 return err
153 }
154 s.Publish(pubsub.UpdatedEvent, message)
155 return nil
156}
157
158func (s *service) Get(id string) (Message, error) {
159 dbMessage, err := s.q.GetMessage(s.ctx, id)
160 if err != nil {
161 return Message{}, err
162 }
163 return s.fromDBItem(dbMessage)
164}
165
166func (s *service) List(sessionID string) ([]Message, error) {
167 dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
168 if err != nil {
169 return nil, err
170 }
171 messages := make([]Message, len(dbMessages))
172 for i, dbMessage := range dbMessages {
173 messages[i], err = s.fromDBItem(dbMessage)
174 if err != nil {
175 return nil, err
176 }
177 }
178 return messages, nil
179}
180
181func (s *service) fromDBItem(item db.Message) (Message, error) {
182 toolCalls := make([]ToolCall, 0)
183 if item.ToolCalls.Valid {
184 err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls)
185 if err != nil {
186 return Message{}, err
187 }
188 }
189
190 toolResults := make([]ToolResult, 0)
191 if item.ToolResults.Valid {
192 err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults)
193 if err != nil {
194 return Message{}, err
195 }
196 }
197
198 return Message{
199 ID: item.ID,
200 SessionID: item.SessionID,
201 Role: MessageRole(item.Role),
202 Content: item.Content,
203 Thinking: item.Thinking,
204 Finished: item.Finished,
205 ToolCalls: toolCalls,
206 ToolResults: toolResults,
207 CreatedAt: item.CreatedAt,
208 UpdatedAt: item.UpdatedAt,
209 }, nil
210}
211
212func NewService(ctx context.Context, q db.Querier) Service {
213 return &service{
214 Broker: pubsub.NewBroker[Message](),
215 q: q,
216 ctx: ctx,
217 }
218}