message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8
  9	"github.com/google/uuid"
 10	"github.com/kujtimiihoxha/termai/internal/db"
 11	"github.com/kujtimiihoxha/termai/internal/llm/models"
 12	"github.com/kujtimiihoxha/termai/internal/pubsub"
 13)
 14
 15type CreateMessageParams struct {
 16	Role  MessageRole
 17	Parts []ContentPart
 18	Model models.ModelID
 19}
 20
 21type Service interface {
 22	pubsub.Suscriber[Message]
 23	Create(sessionID string, params CreateMessageParams) (Message, error)
 24	Update(message Message) error
 25	Get(id string) (Message, error)
 26	List(sessionID string) ([]Message, error)
 27	Delete(id string) error
 28	DeleteSessionMessages(sessionID string) error
 29}
 30
 31type service struct {
 32	*pubsub.Broker[Message]
 33	q   db.Querier
 34	ctx context.Context
 35}
 36
 37func NewService(ctx context.Context, q db.Querier) Service {
 38	return &service{
 39		Broker: pubsub.NewBroker[Message](),
 40		q:      q,
 41		ctx:    ctx,
 42	}
 43}
 44
 45func (s *service) Delete(id string) error {
 46	message, err := s.Get(id)
 47	if err != nil {
 48		return err
 49	}
 50	err = s.q.DeleteMessage(s.ctx, message.ID)
 51	if err != nil {
 52		return err
 53	}
 54	s.Publish(pubsub.DeletedEvent, message)
 55	return nil
 56}
 57
 58func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
 59	if params.Role != Assistant {
 60		params.Parts = append(params.Parts, Finish{
 61			Reason: "stop",
 62		})
 63	}
 64	partsJSON, err := marshallParts(params.Parts)
 65	if err != nil {
 66		return Message{}, err
 67	}
 68
 69	dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
 70		ID:        uuid.New().String(),
 71		SessionID: sessionID,
 72		Role:      string(params.Role),
 73		Parts:     string(partsJSON),
 74		Model:     sql.NullString{String: string(params.Model), Valid: true},
 75	})
 76	if err != nil {
 77		return Message{}, err
 78	}
 79	message, err := s.fromDBItem(dbMessage)
 80	if err != nil {
 81		return Message{}, err
 82	}
 83	s.Publish(pubsub.CreatedEvent, message)
 84	return message, nil
 85}
 86
 87func (s *service) DeleteSessionMessages(sessionID string) error {
 88	messages, err := s.List(sessionID)
 89	if err != nil {
 90		return err
 91	}
 92	for _, message := range messages {
 93		if message.SessionID == sessionID {
 94			err = s.Delete(message.ID)
 95			if err != nil {
 96				return err
 97			}
 98		}
 99	}
100	return nil
101}
102
103func (s *service) Update(message Message) error {
104	parts, err := marshallParts(message.Parts)
105	if err != nil {
106		return err
107	}
108	finishedAt := sql.NullInt64{}
109	if f := message.FinishPart(); f != nil {
110		finishedAt.Int64 = f.Time
111		finishedAt.Valid = true
112	}
113	err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
114		ID:         message.ID,
115		Parts:      string(parts),
116		FinishedAt: finishedAt,
117	})
118	if err != nil {
119		return err
120	}
121	s.Publish(pubsub.UpdatedEvent, message)
122	return nil
123}
124
125func (s *service) Get(id string) (Message, error) {
126	dbMessage, err := s.q.GetMessage(s.ctx, id)
127	if err != nil {
128		return Message{}, err
129	}
130	return s.fromDBItem(dbMessage)
131}
132
133func (s *service) List(sessionID string) ([]Message, error) {
134	dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
135	if err != nil {
136		return nil, err
137	}
138	messages := make([]Message, len(dbMessages))
139	for i, dbMessage := range dbMessages {
140		messages[i], err = s.fromDBItem(dbMessage)
141		if err != nil {
142			return nil, err
143		}
144	}
145	return messages, nil
146}
147
148func (s *service) fromDBItem(item db.Message) (Message, error) {
149	parts, err := unmarshallParts([]byte(item.Parts))
150	if err != nil {
151		return Message{}, err
152	}
153	return Message{
154		ID:        item.ID,
155		SessionID: item.SessionID,
156		Role:      MessageRole(item.Role),
157		Parts:     parts,
158		Model:     models.ModelID(item.Model.String),
159		CreatedAt: item.CreatedAt,
160		UpdatedAt: item.UpdatedAt,
161	}, nil
162}
163
// partType is the discriminator written next to each serialized content part.
// It tells unmarshallParts which concrete type to decode the payload into.
type partType string

// Discriminator values stored in the "type" field of each serialized part.
// These strings end up in the persisted parts column, and unmarshallParts
// rejects any value it does not recognize, so renaming one breaks decoding
// of existing rows.
const (
	binaryType     partType = "binary"
	finishType     partType = "finish"
	imageURLType   partType = "image_url"
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
)
175
176type partWrapper struct {
177	Type partType    `json:"type"`
178	Data ContentPart `json:"data"`
179}
180
181func marshallParts(parts []ContentPart) ([]byte, error) {
182	wrappedParts := make([]partWrapper, len(parts))
183
184	for i, part := range parts {
185		var typ partType
186
187		switch part.(type) {
188		case ReasoningContent:
189			typ = reasoningType
190		case TextContent:
191			typ = textType
192		case ImageURLContent:
193			typ = imageURLType
194		case BinaryContent:
195			typ = binaryType
196		case ToolCall:
197			typ = toolCallType
198		case ToolResult:
199			typ = toolResultType
200		case Finish:
201			typ = finishType
202		default:
203			return nil, fmt.Errorf("unknown part type: %T", part)
204		}
205
206		wrappedParts[i] = partWrapper{
207			Type: typ,
208			Data: part,
209		}
210	}
211	return json.Marshal(wrappedParts)
212}
213
214func unmarshallParts(data []byte) ([]ContentPart, error) {
215	temp := []json.RawMessage{}
216
217	if err := json.Unmarshal(data, &temp); err != nil {
218		return nil, err
219	}
220
221	parts := make([]ContentPart, 0)
222
223	for _, rawPart := range temp {
224		var wrapper struct {
225			Type partType        `json:"type"`
226			Data json.RawMessage `json:"data"`
227		}
228
229		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
230			return nil, err
231		}
232
233		switch wrapper.Type {
234		case reasoningType:
235			part := ReasoningContent{}
236			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
237				return nil, err
238			}
239			parts = append(parts, part)
240		case textType:
241			part := TextContent{}
242			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
243				return nil, err
244			}
245			parts = append(parts, part)
246		case imageURLType:
247			part := ImageURLContent{}
248			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
249				return nil, err
250			}
251		case binaryType:
252			part := BinaryContent{}
253			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
254				return nil, err
255			}
256			parts = append(parts, part)
257		case toolCallType:
258			part := ToolCall{}
259			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
260				return nil, err
261			}
262			parts = append(parts, part)
263		case toolResultType:
264			part := ToolResult{}
265			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
266				return nil, err
267			}
268			parts = append(parts, part)
269		case finishType:
270			part := Finish{}
271			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
272				return nil, err
273			}
274			parts = append(parts, part)
275		default:
276			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
277		}
278
279	}
280
281	return parts, nil
282}