message.go

  1package message
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8
  9	"github.com/google/uuid"
 10	"github.com/kujtimiihoxha/opencode/internal/db"
 11	"github.com/kujtimiihoxha/opencode/internal/llm/models"
 12	"github.com/kujtimiihoxha/opencode/internal/pubsub"
 13)
 14
 15type CreateMessageParams struct {
 16	Role  MessageRole
 17	Parts []ContentPart
 18	Model models.ModelID
 19}
 20
 21type Service interface {
 22	pubsub.Suscriber[Message]
 23	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
 24	Update(ctx context.Context, message Message) error
 25	Get(ctx context.Context, id string) (Message, error)
 26	List(ctx context.Context, sessionID string) ([]Message, error)
 27	Delete(ctx context.Context, id string) error
 28	DeleteSessionMessages(ctx context.Context, sessionID string) error
 29}
 30
 31type service struct {
 32	*pubsub.Broker[Message]
 33	q db.Querier
 34}
 35
 36func NewService(q db.Querier) Service {
 37	return &service{
 38		Broker: pubsub.NewBroker[Message](),
 39		q:      q,
 40	}
 41}
 42
 43func (s *service) Delete(ctx context.Context, id string) error {
 44	message, err := s.Get(ctx, id)
 45	if err != nil {
 46		return err
 47	}
 48	err = s.q.DeleteMessage(ctx, message.ID)
 49	if err != nil {
 50		return err
 51	}
 52	s.Publish(pubsub.DeletedEvent, message)
 53	return nil
 54}
 55
 56func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
 57	if params.Role != Assistant {
 58		params.Parts = append(params.Parts, Finish{
 59			Reason: "stop",
 60		})
 61	}
 62	partsJSON, err := marshallParts(params.Parts)
 63	if err != nil {
 64		return Message{}, err
 65	}
 66
 67	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 68		ID:        uuid.New().String(),
 69		SessionID: sessionID,
 70		Role:      string(params.Role),
 71		Parts:     string(partsJSON),
 72		Model:     sql.NullString{String: string(params.Model), Valid: true},
 73	})
 74	if err != nil {
 75		return Message{}, err
 76	}
 77	message, err := s.fromDBItem(dbMessage)
 78	if err != nil {
 79		return Message{}, err
 80	}
 81	s.Publish(pubsub.CreatedEvent, message)
 82	return message, nil
 83}
 84
 85func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 86	messages, err := s.List(ctx, sessionID)
 87	if err != nil {
 88		return err
 89	}
 90	for _, message := range messages {
 91		if message.SessionID == sessionID {
 92			err = s.Delete(ctx, message.ID)
 93			if err != nil {
 94				return err
 95			}
 96		}
 97	}
 98	return nil
 99}
100
101func (s *service) Update(ctx context.Context, message Message) error {
102	parts, err := marshallParts(message.Parts)
103	if err != nil {
104		return err
105	}
106	finishedAt := sql.NullInt64{}
107	if f := message.FinishPart(); f != nil {
108		finishedAt.Int64 = f.Time
109		finishedAt.Valid = true
110	}
111	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
112		ID:         message.ID,
113		Parts:      string(parts),
114		FinishedAt: finishedAt,
115	})
116	if err != nil {
117		return err
118	}
119	s.Publish(pubsub.UpdatedEvent, message)
120	return nil
121}
122
123func (s *service) Get(ctx context.Context, id string) (Message, error) {
124	dbMessage, err := s.q.GetMessage(ctx, id)
125	if err != nil {
126		return Message{}, err
127	}
128	return s.fromDBItem(dbMessage)
129}
130
131func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
132	dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
133	if err != nil {
134		return nil, err
135	}
136	messages := make([]Message, len(dbMessages))
137	for i, dbMessage := range dbMessages {
138		messages[i], err = s.fromDBItem(dbMessage)
139		if err != nil {
140			return nil, err
141		}
142	}
143	return messages, nil
144}
145
146func (s *service) fromDBItem(item db.Message) (Message, error) {
147	parts, err := unmarshallParts([]byte(item.Parts))
148	if err != nil {
149		return Message{}, err
150	}
151	return Message{
152		ID:        item.ID,
153		SessionID: item.SessionID,
154		Role:      MessageRole(item.Role),
155		Parts:     parts,
156		Model:     models.ModelID(item.Model.String),
157		CreatedAt: item.CreatedAt,
158		UpdatedAt: item.UpdatedAt,
159	}, nil
160}
161
162type partType string
163
164const (
165	reasoningType  partType = "reasoning"
166	textType       partType = "text"
167	imageURLType   partType = "image_url"
168	binaryType     partType = "binary"
169	toolCallType   partType = "tool_call"
170	toolResultType partType = "tool_result"
171	finishType     partType = "finish"
172)
173
174type partWrapper struct {
175	Type partType    `json:"type"`
176	Data ContentPart `json:"data"`
177}
178
179func marshallParts(parts []ContentPart) ([]byte, error) {
180	wrappedParts := make([]partWrapper, len(parts))
181
182	for i, part := range parts {
183		var typ partType
184
185		switch part.(type) {
186		case ReasoningContent:
187			typ = reasoningType
188		case TextContent:
189			typ = textType
190		case ImageURLContent:
191			typ = imageURLType
192		case BinaryContent:
193			typ = binaryType
194		case ToolCall:
195			typ = toolCallType
196		case ToolResult:
197			typ = toolResultType
198		case Finish:
199			typ = finishType
200		default:
201			return nil, fmt.Errorf("unknown part type: %T", part)
202		}
203
204		wrappedParts[i] = partWrapper{
205			Type: typ,
206			Data: part,
207		}
208	}
209	return json.Marshal(wrappedParts)
210}
211
212func unmarshallParts(data []byte) ([]ContentPart, error) {
213	temp := []json.RawMessage{}
214
215	if err := json.Unmarshal(data, &temp); err != nil {
216		return nil, err
217	}
218
219	parts := make([]ContentPart, 0)
220
221	for _, rawPart := range temp {
222		var wrapper struct {
223			Type partType        `json:"type"`
224			Data json.RawMessage `json:"data"`
225		}
226
227		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
228			return nil, err
229		}
230
231		switch wrapper.Type {
232		case reasoningType:
233			part := ReasoningContent{}
234			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
235				return nil, err
236			}
237			parts = append(parts, part)
238		case textType:
239			part := TextContent{}
240			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
241				return nil, err
242			}
243			parts = append(parts, part)
244		case imageURLType:
245			part := ImageURLContent{}
246			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
247				return nil, err
248			}
249		case binaryType:
250			part := BinaryContent{}
251			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
252				return nil, err
253			}
254			parts = append(parts, part)
255		case toolCallType:
256			part := ToolCall{}
257			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
258				return nil, err
259			}
260			parts = append(parts, part)
261		case toolResultType:
262			part := ToolResult{}
263			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
264				return nil, err
265			}
266			parts = append(parts, part)
267		case finishType:
268			part := Finish{}
269			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
270				return nil, err
271			}
272			parts = append(parts, part)
273		default:
274			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
275		}
276
277	}
278
279	return parts, nil
280}