1package message
  2
  3import (
  4	"encoding/base64"
  5	"slices"
  6	"time"
  7
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9)
 10
 11type MessageRole string
 12
 13const (
 14	Assistant MessageRole = "assistant"
 15	User      MessageRole = "user"
 16	System    MessageRole = "system"
 17	Tool      MessageRole = "tool"
 18)
 19
 20type FinishReason string
 21
 22const (
 23	FinishReasonEndTurn          FinishReason = "end_turn"
 24	FinishReasonMaxTokens        FinishReason = "max_tokens"
 25	FinishReasonToolUse          FinishReason = "tool_use"
 26	FinishReasonCanceled         FinishReason = "canceled"
 27	FinishReasonError            FinishReason = "error"
 28	FinishReasonPermissionDenied FinishReason = "permission_denied"
 29
 30	// Should never happen
 31	FinishReasonUnknown FinishReason = "unknown"
 32)
 33
 34type ContentPart interface {
 35	isPart()
 36}
 37
 38type ReasoningContent struct {
 39	Thinking   string `json:"thinking"`
 40	Signature  string `json:"signature"`
 41	StartedAt  int64  `json:"started_at,omitempty"`
 42	FinishedAt int64  `json:"finished_at,omitempty"`
 43}
 44
 45func (tc ReasoningContent) String() string {
 46	return tc.Thinking
 47}
 48func (ReasoningContent) isPart() {}
 49
 50type TextContent struct {
 51	Text string `json:"text"`
 52}
 53
 54func (tc TextContent) String() string {
 55	return tc.Text
 56}
 57
 58func (TextContent) isPart() {}
 59
 60type ImageURLContent struct {
 61	URL    string `json:"url"`
 62	Detail string `json:"detail,omitempty"`
 63}
 64
 65func (iuc ImageURLContent) String() string {
 66	return iuc.URL
 67}
 68
 69func (ImageURLContent) isPart() {}
 70
 71type BinaryContent struct {
 72	Path     string
 73	MIMEType string
 74	Data     []byte
 75}
 76
 77func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
 78	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 79	if p == catwalk.InferenceProviderOpenAI {
 80		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 81	}
 82	return base64Encoded
 83}
 84
 85func (BinaryContent) isPart() {}
 86
// ToolCall is a Message part describing a tool invocation. Input grows
// incrementally via Message.AppendToolCallInput, and Finished is set to true
// by Message.FinishToolCall. Message.AddToolCall replaces a part with the
// same ID.
//
// NOTE(review): Input presumably holds the tool's raw (JSON) arguments and
// Type the call kind reported by the provider — confirm with callers.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
 96
// ToolResult is a Message part carrying the output of a tool invocation.
//
// NOTE(review): ToolCallID presumably links this result to the originating
// ToolCall.ID, and IsError presumably marks Content as an error report; the
// format of Metadata is not visible here — confirm with callers.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
106
// Finish is the Message part recording why and when the message finished.
// Message.AddFinish removes an existing Finish part before appending a new
// one and stamps Time with time.Now().Unix() (seconds since the Unix epoch).
// Message and Details are optional and omitted from JSON when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
115
// Message is a single message within a session. Its content is an ordered
// list of typed parts; the methods on *Message find parts by their concrete
// type (for example, Content returns the first TextContent part).
//
// NOTE(review): Model and Provider presumably identify the LLM that produced
// the message, and CreatedAt/UpdatedAt are presumably Unix timestamps — the
// units are not visible in this file; confirm against the persistence layer.
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	Parts     []ContentPart
	Model     string
	Provider  string
	CreatedAt int64
	UpdatedAt int64
}
126
127func (m *Message) Content() TextContent {
128	for _, part := range m.Parts {
129		if c, ok := part.(TextContent); ok {
130			return c
131		}
132	}
133	return TextContent{}
134}
135
136func (m *Message) ReasoningContent() ReasoningContent {
137	for _, part := range m.Parts {
138		if c, ok := part.(ReasoningContent); ok {
139			return c
140		}
141	}
142	return ReasoningContent{}
143}
144
145func (m *Message) ImageURLContent() []ImageURLContent {
146	imageURLContents := make([]ImageURLContent, 0)
147	for _, part := range m.Parts {
148		if c, ok := part.(ImageURLContent); ok {
149			imageURLContents = append(imageURLContents, c)
150		}
151	}
152	return imageURLContents
153}
154
155func (m *Message) BinaryContent() []BinaryContent {
156	binaryContents := make([]BinaryContent, 0)
157	for _, part := range m.Parts {
158		if c, ok := part.(BinaryContent); ok {
159			binaryContents = append(binaryContents, c)
160		}
161	}
162	return binaryContents
163}
164
165func (m *Message) ToolCalls() []ToolCall {
166	toolCalls := make([]ToolCall, 0)
167	for _, part := range m.Parts {
168		if c, ok := part.(ToolCall); ok {
169			toolCalls = append(toolCalls, c)
170		}
171	}
172	return toolCalls
173}
174
175func (m *Message) ToolResults() []ToolResult {
176	toolResults := make([]ToolResult, 0)
177	for _, part := range m.Parts {
178		if c, ok := part.(ToolResult); ok {
179			toolResults = append(toolResults, c)
180		}
181	}
182	return toolResults
183}
184
185func (m *Message) IsFinished() bool {
186	for _, part := range m.Parts {
187		if _, ok := part.(Finish); ok {
188			return true
189		}
190	}
191	return false
192}
193
194func (m *Message) FinishPart() *Finish {
195	for _, part := range m.Parts {
196		if c, ok := part.(Finish); ok {
197			return &c
198		}
199	}
200	return nil
201}
202
203func (m *Message) FinishReason() FinishReason {
204	for _, part := range m.Parts {
205		if c, ok := part.(Finish); ok {
206			return c.Reason
207		}
208	}
209	return ""
210}
211
212func (m *Message) IsThinking() bool {
213	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
214		return true
215	}
216	return false
217}
218
219func (m *Message) AppendContent(delta string) {
220	found := false
221	for i, part := range m.Parts {
222		if c, ok := part.(TextContent); ok {
223			m.Parts[i] = TextContent{Text: c.Text + delta}
224			found = true
225		}
226	}
227	if !found {
228		m.Parts = append(m.Parts, TextContent{Text: delta})
229	}
230}
231
232func (m *Message) AppendReasoningContent(delta string) {
233	found := false
234	for i, part := range m.Parts {
235		if c, ok := part.(ReasoningContent); ok {
236			m.Parts[i] = ReasoningContent{
237				Thinking:   c.Thinking + delta,
238				Signature:  c.Signature,
239				StartedAt:  c.StartedAt,
240				FinishedAt: c.FinishedAt,
241			}
242			found = true
243		}
244	}
245	if !found {
246		m.Parts = append(m.Parts, ReasoningContent{
247			Thinking:  delta,
248			StartedAt: time.Now().Unix(),
249		})
250	}
251}
252
253func (m *Message) AppendReasoningSignature(signature string) {
254	for i, part := range m.Parts {
255		if c, ok := part.(ReasoningContent); ok {
256			m.Parts[i] = ReasoningContent{
257				Thinking:   c.Thinking,
258				Signature:  c.Signature + signature,
259				StartedAt:  c.StartedAt,
260				FinishedAt: c.FinishedAt,
261			}
262			return
263		}
264	}
265	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
266}
267
268func (m *Message) FinishThinking() {
269	for i, part := range m.Parts {
270		if c, ok := part.(ReasoningContent); ok {
271			if c.FinishedAt == 0 {
272				m.Parts[i] = ReasoningContent{
273					Thinking:   c.Thinking,
274					Signature:  c.Signature,
275					StartedAt:  c.StartedAt,
276					FinishedAt: time.Now().Unix(),
277				}
278			}
279			return
280		}
281	}
282}
283
284func (m *Message) ThinkingDuration() time.Duration {
285	reasoning := m.ReasoningContent()
286	if reasoning.StartedAt == 0 {
287		return 0
288	}
289
290	endTime := reasoning.FinishedAt
291	if endTime == 0 {
292		endTime = time.Now().Unix()
293	}
294
295	return time.Duration(endTime-reasoning.StartedAt) * time.Second
296}
297
298func (m *Message) FinishToolCall(toolCallID string) {
299	for i, part := range m.Parts {
300		if c, ok := part.(ToolCall); ok {
301			if c.ID == toolCallID {
302				m.Parts[i] = ToolCall{
303					ID:       c.ID,
304					Name:     c.Name,
305					Input:    c.Input,
306					Type:     c.Type,
307					Finished: true,
308				}
309				return
310			}
311		}
312	}
313}
314
315func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
316	for i, part := range m.Parts {
317		if c, ok := part.(ToolCall); ok {
318			if c.ID == toolCallID {
319				m.Parts[i] = ToolCall{
320					ID:       c.ID,
321					Name:     c.Name,
322					Input:    c.Input + inputDelta,
323					Type:     c.Type,
324					Finished: c.Finished,
325				}
326				return
327			}
328		}
329	}
330}
331
332func (m *Message) AddToolCall(tc ToolCall) {
333	for i, part := range m.Parts {
334		if c, ok := part.(ToolCall); ok {
335			if c.ID == tc.ID {
336				m.Parts[i] = tc
337				return
338			}
339		}
340	}
341	m.Parts = append(m.Parts, tc)
342}
343
344func (m *Message) SetToolCalls(tc []ToolCall) {
345	// remove any existing tool call part it could have multiple
346	parts := make([]ContentPart, 0)
347	for _, part := range m.Parts {
348		if _, ok := part.(ToolCall); ok {
349			continue
350		}
351		parts = append(parts, part)
352	}
353	m.Parts = parts
354	for _, toolCall := range tc {
355		m.Parts = append(m.Parts, toolCall)
356	}
357}
358
359func (m *Message) AddToolResult(tr ToolResult) {
360	m.Parts = append(m.Parts, tr)
361}
362
363func (m *Message) SetToolResults(tr []ToolResult) {
364	for _, toolResult := range tr {
365		m.Parts = append(m.Parts, toolResult)
366	}
367}
368
369func (m *Message) AddFinish(reason FinishReason, message, details string) {
370	// remove any existing finish part
371	for i, part := range m.Parts {
372		if _, ok := part.(Finish); ok {
373			m.Parts = slices.Delete(m.Parts, i, i+1)
374			break
375		}
376	}
377	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
378}
379
380func (m *Message) AddImageURL(url, detail string) {
381	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
382}
383
384func (m *Message) AddBinary(mimeType string, data []byte) {
385	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
386}