1package message
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"slices"
  7	"time"
  8
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10)
 11
 12type MessageRole string
 13
 14const (
 15	Assistant MessageRole = "assistant"
 16	User      MessageRole = "user"
 17	System    MessageRole = "system"
 18	Tool      MessageRole = "tool"
 19)
 20
 21func (r MessageRole) MarshalText() ([]byte, error) {
 22	return []byte(r), nil
 23}
 24
 25func (r *MessageRole) UnmarshalText(data []byte) error {
 26	*r = MessageRole(data)
 27	return nil
 28}
 29
 30type FinishReason string
 31
 32const (
 33	FinishReasonEndTurn          FinishReason = "end_turn"
 34	FinishReasonMaxTokens        FinishReason = "max_tokens"
 35	FinishReasonToolUse          FinishReason = "tool_use"
 36	FinishReasonCanceled         FinishReason = "canceled"
 37	FinishReasonError            FinishReason = "error"
 38	FinishReasonPermissionDenied FinishReason = "permission_denied"
 39
 40	// Should never happen
 41	FinishReasonUnknown FinishReason = "unknown"
 42)
 43
 44func (fr FinishReason) MarshalText() ([]byte, error) {
 45	return []byte(fr), nil
 46}
 47
 48func (fr *FinishReason) UnmarshalText(data []byte) error {
 49	*fr = FinishReason(data)
 50	return nil
 51}
 52
// ContentPart is a single piece of a Message's content. Its only method,
// isPart, is an unexported marker, so only types declared in this package
// can implement the interface.
type ContentPart interface {
	isPart()
}
 56
 57type ReasoningContent struct {
 58	Thinking   string `json:"thinking"`
 59	Signature  string `json:"signature"`
 60	StartedAt  int64  `json:"started_at,omitempty"`
 61	FinishedAt int64  `json:"finished_at,omitempty"`
 62}
 63
 64func (tc ReasoningContent) String() string {
 65	return tc.Thinking
 66}
 67func (ReasoningContent) isPart() {}
 68
 69type TextContent struct {
 70	Text string `json:"text"`
 71}
 72
 73func (tc TextContent) String() string {
 74	return tc.Text
 75}
 76
 77func (TextContent) isPart() {}
 78
 79type ImageURLContent struct {
 80	URL    string `json:"url"`
 81	Detail string `json:"detail,omitempty"`
 82}
 83
 84func (iuc ImageURLContent) String() string {
 85	return iuc.URL
 86}
 87
 88func (ImageURLContent) isPart() {}
 89
 90type BinaryContent struct {
 91	Path     string
 92	MIMEType string
 93	Data     []byte
 94}
 95
 96func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
 97	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 98	if p == catwalk.InferenceProviderOpenAI {
 99		return "data:" + bc.MIMEType + ";base64," + base64Encoded
100	}
101	return base64Encoded
102}
103
104func (BinaryContent) isPart() {}
105
// ToolCall is a tool invocation stored as a message part, identified by ID.
// AppendToolCallInput appends input deltas to Input (presumably a JSON
// argument string — TODO confirm), and FinishToolCall sets Finished to true.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
115
// ToolResult is the outcome of a tool invocation stored as a message part.
// ToolCallID presumably matches the ID of the originating ToolCall — TODO
// confirm. IsError flags a failed invocation.
// NOTE(review): the format of Content and Metadata is not visible here.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
125
// Finish marks a message as complete and records why. AddFinish fills Time
// with the current Unix time in seconds. Message and Details are optional
// and are omitted from JSON when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

func (Finish) isPart() {}
134
// Message is one message in a session: who authored it (Role), the model
// and provider names, and an ordered list of content parts.
//
// Parts holds ContentPart interface values, which encoding/json cannot
// decode on its own, so Message implements MarshalJSON and UnmarshalJSON.
// Those methods encode the remaining fields through an alias of this struct,
// so field order here determines key order in the JSON output.
// NOTE(review): CreatedAt/UpdatedAt look like Unix timestamps — confirm the
// unit with the code that sets them.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
145
146// MarshalJSON implements the [json.Marshaler] interface.
147func (m Message) MarshalJSON() ([]byte, error) {
148	// We need to handle the Parts specially since they're ContentPart interfaces
149	// which can't be directly marshaled by the standard JSON package.
150	parts, err := marshallParts(m.Parts)
151	if err != nil {
152		return nil, err
153	}
154
155	// Create an alias to avoid infinite recursion
156	type Alias Message
157	return json.Marshal(&struct {
158		Parts json.RawMessage `json:"parts"`
159		*Alias
160	}{
161		Parts: json.RawMessage(parts),
162		Alias: (*Alias)(&m),
163	})
164}
165
166// UnmarshalJSON implements the [json.Unmarshaler] interface.
167func (m *Message) UnmarshalJSON(data []byte) error {
168	// Create an alias to avoid infinite recursion
169	type Alias Message
170	aux := &struct {
171		Parts json.RawMessage `json:"parts"`
172		*Alias
173	}{
174		Alias: (*Alias)(m),
175	}
176
177	if err := json.Unmarshal(data, &aux); err != nil {
178		return err
179	}
180
181	// Unmarshal the parts using our custom function
182	parts, err := unmarshallParts([]byte(aux.Parts))
183	if err != nil {
184		return err
185	}
186
187	m.Parts = parts
188	return nil
189}
190
191func (m *Message) Content() TextContent {
192	for _, part := range m.Parts {
193		if c, ok := part.(TextContent); ok {
194			return c
195		}
196	}
197	return TextContent{}
198}
199
200func (m *Message) ReasoningContent() ReasoningContent {
201	for _, part := range m.Parts {
202		if c, ok := part.(ReasoningContent); ok {
203			return c
204		}
205	}
206	return ReasoningContent{}
207}
208
209func (m *Message) ImageURLContent() []ImageURLContent {
210	imageURLContents := make([]ImageURLContent, 0)
211	for _, part := range m.Parts {
212		if c, ok := part.(ImageURLContent); ok {
213			imageURLContents = append(imageURLContents, c)
214		}
215	}
216	return imageURLContents
217}
218
219func (m *Message) BinaryContent() []BinaryContent {
220	binaryContents := make([]BinaryContent, 0)
221	for _, part := range m.Parts {
222		if c, ok := part.(BinaryContent); ok {
223			binaryContents = append(binaryContents, c)
224		}
225	}
226	return binaryContents
227}
228
229func (m *Message) ToolCalls() []ToolCall {
230	toolCalls := make([]ToolCall, 0)
231	for _, part := range m.Parts {
232		if c, ok := part.(ToolCall); ok {
233			toolCalls = append(toolCalls, c)
234		}
235	}
236	return toolCalls
237}
238
239func (m *Message) ToolResults() []ToolResult {
240	toolResults := make([]ToolResult, 0)
241	for _, part := range m.Parts {
242		if c, ok := part.(ToolResult); ok {
243			toolResults = append(toolResults, c)
244		}
245	}
246	return toolResults
247}
248
249func (m *Message) IsFinished() bool {
250	for _, part := range m.Parts {
251		if _, ok := part.(Finish); ok {
252			return true
253		}
254	}
255	return false
256}
257
258func (m *Message) FinishPart() *Finish {
259	for _, part := range m.Parts {
260		if c, ok := part.(Finish); ok {
261			return &c
262		}
263	}
264	return nil
265}
266
267func (m *Message) FinishReason() FinishReason {
268	for _, part := range m.Parts {
269		if c, ok := part.(Finish); ok {
270			return c.Reason
271		}
272	}
273	return ""
274}
275
276func (m *Message) IsThinking() bool {
277	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
278		return true
279	}
280	return false
281}
282
283func (m *Message) AppendContent(delta string) {
284	found := false
285	for i, part := range m.Parts {
286		if c, ok := part.(TextContent); ok {
287			m.Parts[i] = TextContent{Text: c.Text + delta}
288			found = true
289		}
290	}
291	if !found {
292		m.Parts = append(m.Parts, TextContent{Text: delta})
293	}
294}
295
296func (m *Message) AppendReasoningContent(delta string) {
297	found := false
298	for i, part := range m.Parts {
299		if c, ok := part.(ReasoningContent); ok {
300			m.Parts[i] = ReasoningContent{
301				Thinking:   c.Thinking + delta,
302				Signature:  c.Signature,
303				StartedAt:  c.StartedAt,
304				FinishedAt: c.FinishedAt,
305			}
306			found = true
307		}
308	}
309	if !found {
310		m.Parts = append(m.Parts, ReasoningContent{
311			Thinking:  delta,
312			StartedAt: time.Now().Unix(),
313		})
314	}
315}
316
317func (m *Message) AppendReasoningSignature(signature string) {
318	for i, part := range m.Parts {
319		if c, ok := part.(ReasoningContent); ok {
320			m.Parts[i] = ReasoningContent{
321				Thinking:   c.Thinking,
322				Signature:  c.Signature + signature,
323				StartedAt:  c.StartedAt,
324				FinishedAt: c.FinishedAt,
325			}
326			return
327		}
328	}
329	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
330}
331
332func (m *Message) FinishThinking() {
333	for i, part := range m.Parts {
334		if c, ok := part.(ReasoningContent); ok {
335			if c.FinishedAt == 0 {
336				m.Parts[i] = ReasoningContent{
337					Thinking:   c.Thinking,
338					Signature:  c.Signature,
339					StartedAt:  c.StartedAt,
340					FinishedAt: time.Now().Unix(),
341				}
342			}
343			return
344		}
345	}
346}
347
348func (m *Message) ThinkingDuration() time.Duration {
349	reasoning := m.ReasoningContent()
350	if reasoning.StartedAt == 0 {
351		return 0
352	}
353
354	endTime := reasoning.FinishedAt
355	if endTime == 0 {
356		endTime = time.Now().Unix()
357	}
358
359	return time.Duration(endTime-reasoning.StartedAt) * time.Second
360}
361
362func (m *Message) FinishToolCall(toolCallID string) {
363	for i, part := range m.Parts {
364		if c, ok := part.(ToolCall); ok {
365			if c.ID == toolCallID {
366				m.Parts[i] = ToolCall{
367					ID:       c.ID,
368					Name:     c.Name,
369					Input:    c.Input,
370					Type:     c.Type,
371					Finished: true,
372				}
373				return
374			}
375		}
376	}
377}
378
379func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
380	for i, part := range m.Parts {
381		if c, ok := part.(ToolCall); ok {
382			if c.ID == toolCallID {
383				m.Parts[i] = ToolCall{
384					ID:       c.ID,
385					Name:     c.Name,
386					Input:    c.Input + inputDelta,
387					Type:     c.Type,
388					Finished: c.Finished,
389				}
390				return
391			}
392		}
393	}
394}
395
396func (m *Message) AddToolCall(tc ToolCall) {
397	for i, part := range m.Parts {
398		if c, ok := part.(ToolCall); ok {
399			if c.ID == tc.ID {
400				m.Parts[i] = tc
401				return
402			}
403		}
404	}
405	m.Parts = append(m.Parts, tc)
406}
407
408func (m *Message) SetToolCalls(tc []ToolCall) {
409	// remove any existing tool call part it could have multiple
410	parts := make([]ContentPart, 0)
411	for _, part := range m.Parts {
412		if _, ok := part.(ToolCall); ok {
413			continue
414		}
415		parts = append(parts, part)
416	}
417	m.Parts = parts
418	for _, toolCall := range tc {
419		m.Parts = append(m.Parts, toolCall)
420	}
421}
422
423func (m *Message) AddToolResult(tr ToolResult) {
424	m.Parts = append(m.Parts, tr)
425}
426
427func (m *Message) SetToolResults(tr []ToolResult) {
428	for _, toolResult := range tr {
429		m.Parts = append(m.Parts, toolResult)
430	}
431}
432
433func (m *Message) AddFinish(reason FinishReason, message, details string) {
434	// remove any existing finish part
435	for i, part := range m.Parts {
436		if _, ok := part.(Finish); ok {
437			m.Parts = slices.Delete(m.Parts, i, i+1)
438			break
439		}
440	}
441	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
442}
443
444func (m *Message) AddImageURL(url, detail string) {
445	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
446}
447
448func (m *Message) AddBinary(mimeType string, data []byte) {
449	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
450}