message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7
  8	"github.com/google/uuid"
  9	"github.com/kujtimiihoxha/termai/internal/db"
 10	"github.com/kujtimiihoxha/termai/internal/pubsub"
 11)
 12
 13type MessageRole string
 14
 15const (
 16	Assistant MessageRole = "assistant"
 17	User      MessageRole = "user"
 18	System    MessageRole = "system"
 19	Tool      MessageRole = "tool"
 20)
 21
 22type ToolResult struct {
 23	ToolCallID string
 24	Content    string
 25	IsError    bool
 26	// TODO: support for images
 27}
 28
 29type ToolCall struct {
 30	ID    string
 31	Name  string
 32	Input string
 33	Type  string
 34}
 35
 36type Message struct {
 37	ID        string
 38	SessionID string
 39
 40	// NEW
 41	Role     MessageRole
 42	Content  string
 43	Thinking string
 44
 45	Finished bool
 46
 47	ToolResults []ToolResult
 48	ToolCalls   []ToolCall
 49	CreatedAt   int64
 50	UpdatedAt   int64
 51}
 52
 53type CreateMessageParams struct {
 54	Role        MessageRole
 55	Content     string
 56	ToolCalls   []ToolCall
 57	ToolResults []ToolResult
 58}
 59
 60type Service interface {
 61	pubsub.Suscriber[Message]
 62	Create(sessionID string, params CreateMessageParams) (Message, error)
 63	Update(message Message) error
 64	Get(id string) (Message, error)
 65	List(sessionID string) ([]Message, error)
 66	Delete(id string) error
 67	DeleteSessionMessages(sessionID string) error
 68}
 69
 70type service struct {
 71	*pubsub.Broker[Message]
 72	q   db.Querier
 73	ctx context.Context
 74}
 75
 76func (s *service) Delete(id string) error {
 77	message, err := s.Get(id)
 78	if err != nil {
 79		return err
 80	}
 81	err = s.q.DeleteMessage(s.ctx, message.ID)
 82	if err != nil {
 83		return err
 84	}
 85	s.Publish(pubsub.DeletedEvent, message)
 86	return nil
 87}
 88
 89func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
 90	toolCallsStr, err := json.Marshal(params.ToolCalls)
 91	if err != nil {
 92		return Message{}, err
 93	}
 94	toolResultsStr, err := json.Marshal(params.ToolResults)
 95	if err != nil {
 96		return Message{}, err
 97	}
 98	dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
 99		ID:          uuid.New().String(),
100		SessionID:   sessionID,
101		Role:        string(params.Role),
102		Finished:    params.Role != Assistant,
103		Content:     params.Content,
104		ToolCalls:   sql.NullString{String: string(toolCallsStr), Valid: true},
105		ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
106	})
107	if err != nil {
108		return Message{}, err
109	}
110	message, err := s.fromDBItem(dbMessage)
111	if err != nil {
112		return Message{}, err
113	}
114	s.Publish(pubsub.CreatedEvent, message)
115	return message, nil
116}
117
118func (s *service) DeleteSessionMessages(sessionID string) error {
119	messages, err := s.List(sessionID)
120	if err != nil {
121		return err
122	}
123	for _, message := range messages {
124		if message.SessionID == sessionID {
125			err = s.Delete(message.ID)
126			if err != nil {
127				return err
128			}
129		}
130	}
131	return nil
132}
133
134func (s *service) Update(message Message) error {
135	toolCallsStr, err := json.Marshal(message.ToolCalls)
136	if err != nil {
137		return err
138	}
139	toolResultsStr, err := json.Marshal(message.ToolResults)
140	if err != nil {
141		return err
142	}
143	err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
144		ID:          message.ID,
145		Content:     message.Content,
146		Thinking:    message.Thinking,
147		Finished:    message.Finished,
148		ToolCalls:   sql.NullString{String: string(toolCallsStr), Valid: true},
149		ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
150	})
151	if err != nil {
152		return err
153	}
154	s.Publish(pubsub.UpdatedEvent, message)
155	return nil
156}
157
158func (s *service) Get(id string) (Message, error) {
159	dbMessage, err := s.q.GetMessage(s.ctx, id)
160	if err != nil {
161		return Message{}, err
162	}
163	return s.fromDBItem(dbMessage)
164}
165
166func (s *service) List(sessionID string) ([]Message, error) {
167	dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
168	if err != nil {
169		return nil, err
170	}
171	messages := make([]Message, len(dbMessages))
172	for i, dbMessage := range dbMessages {
173		messages[i], err = s.fromDBItem(dbMessage)
174		if err != nil {
175			return nil, err
176		}
177	}
178	return messages, nil
179}
180
181func (s *service) fromDBItem(item db.Message) (Message, error) {
182	toolCalls := make([]ToolCall, 0)
183	if item.ToolCalls.Valid {
184		err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls)
185		if err != nil {
186			return Message{}, err
187		}
188	}
189
190	toolResults := make([]ToolResult, 0)
191	if item.ToolResults.Valid {
192		err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults)
193		if err != nil {
194			return Message{}, err
195		}
196	}
197
198	return Message{
199		ID:          item.ID,
200		SessionID:   item.SessionID,
201		Role:        MessageRole(item.Role),
202		Content:     item.Content,
203		Thinking:    item.Thinking,
204		Finished:    item.Finished,
205		ToolCalls:   toolCalls,
206		ToolResults: toolResults,
207		CreatedAt:   item.CreatedAt,
208		UpdatedAt:   item.UpdatedAt,
209	}, nil
210}
211
212func NewService(ctx context.Context, q db.Querier) Service {
213	return &service{
214		Broker: pubsub.NewBroker[Message](),
215		q:      q,
216		ctx:    ctx,
217	}
218}