content.go

  1package message
  2
  3import (
  4	"encoding/base64"
  5	"slices"
  6	"time"
  7
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9)
 10
 11type MessageRole string
 12
 13const (
 14	Assistant MessageRole = "assistant"
 15	User      MessageRole = "user"
 16	System    MessageRole = "system"
 17	Tool      MessageRole = "tool"
 18)
 19
 20func (r MessageRole) MarshalText() ([]byte, error) {
 21	return []byte(r), nil
 22}
 23
 24func (r *MessageRole) UnmarshalText(data []byte) error {
 25	*r = MessageRole(data)
 26	return nil
 27}
 28
 29type FinishReason string
 30
 31const (
 32	FinishReasonEndTurn          FinishReason = "end_turn"
 33	FinishReasonMaxTokens        FinishReason = "max_tokens"
 34	FinishReasonToolUse          FinishReason = "tool_use"
 35	FinishReasonCanceled         FinishReason = "canceled"
 36	FinishReasonError            FinishReason = "error"
 37	FinishReasonPermissionDenied FinishReason = "permission_denied"
 38
 39	// Should never happen
 40	FinishReasonUnknown FinishReason = "unknown"
 41)
 42
 43func (fr FinishReason) MarshalText() ([]byte, error) {
 44	return []byte(fr), nil
 45}
 46
 47func (fr *FinishReason) UnmarshalText(data []byte) error {
 48	*fr = FinishReason(data)
 49	return nil
 50}
 51
// ContentPart is one piece of a Message's content (text, reasoning, images,
// binary data, tool calls, tool results, or a finish marker).
//
// The unexported isPart method seals the interface: only types declared in
// this package can implement it.
type ContentPart interface {
	isPart()
}
 55
 56type ReasoningContent struct {
 57	Thinking   string `json:"thinking"`
 58	Signature  string `json:"signature"`
 59	StartedAt  int64  `json:"started_at,omitempty"`
 60	FinishedAt int64  `json:"finished_at,omitempty"`
 61}
 62
 63func (tc ReasoningContent) String() string {
 64	return tc.Thinking
 65}
 66func (ReasoningContent) isPart() {}
 67
 68type TextContent struct {
 69	Text string `json:"text"`
 70}
 71
 72func (tc TextContent) String() string {
 73	return tc.Text
 74}
 75
 76func (TextContent) isPart() {}
 77
 78type ImageURLContent struct {
 79	URL    string `json:"url"`
 80	Detail string `json:"detail,omitempty"`
 81}
 82
 83func (iuc ImageURLContent) String() string {
 84	return iuc.URL
 85}
 86
 87func (ImageURLContent) isPart() {}
 88
 89type BinaryContent struct {
 90	Path     string
 91	MIMEType string
 92	Data     []byte
 93}
 94
 95func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
 96	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 97	if p == catwalk.InferenceProviderOpenAI {
 98		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 99	}
100	return base64Encoded
101}
102
103func (BinaryContent) isPart() {}
104
// ToolCall is a content part describing a tool invocation.
//
// Input is built up incrementally by Message.AppendToolCallInput, and
// Finished is set by Message.FinishToolCall. Message.AddToolCall and
// Message.FinishToolCall match calls by ID.
//
// NOTE(review): Input presumably holds the serialized tool arguments (likely
// JSON), and the allowed values of Type are not visible here. Confirm against
// the provider adapters.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
114
// ToolResult is a content part carrying the output of a tool invocation.
// IsError flags a result that represents a failure.
//
// NOTE(review): ToolCallID presumably matches the originating ToolCall.ID.
// Metadata is an opaque string here; its format is defined by whatever
// produces it. Confirm before parsing.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
124
// Finish is a content part marking a message as finished.
//
// Reason says why the turn ended. Time is the Unix time in seconds at which
// Message.AddFinish recorded it. Message and Details are optional free-form
// text and are omitted from JSON when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

func (Finish) isPart() {}
133
// Message is a single conversation message made of ordered content parts.
//
// Parts holds heterogeneous ContentPart values. The accessor methods on
// Message return either the first part of a given concrete type or all of
// them.
//
// NOTE(review): encoding/json cannot decode into []ContentPart (a non-empty
// interface) without custom handling. Presumably Parts is (de)serialized
// elsewhere; verify. The units of CreatedAt/UpdatedAt are not established in
// this file (Unix seconds would match the other timestamps here). Confirm.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
144
145func (m *Message) Content() TextContent {
146	for _, part := range m.Parts {
147		if c, ok := part.(TextContent); ok {
148			return c
149		}
150	}
151	return TextContent{}
152}
153
154func (m *Message) ReasoningContent() ReasoningContent {
155	for _, part := range m.Parts {
156		if c, ok := part.(ReasoningContent); ok {
157			return c
158		}
159	}
160	return ReasoningContent{}
161}
162
163func (m *Message) ImageURLContent() []ImageURLContent {
164	imageURLContents := make([]ImageURLContent, 0)
165	for _, part := range m.Parts {
166		if c, ok := part.(ImageURLContent); ok {
167			imageURLContents = append(imageURLContents, c)
168		}
169	}
170	return imageURLContents
171}
172
173func (m *Message) BinaryContent() []BinaryContent {
174	binaryContents := make([]BinaryContent, 0)
175	for _, part := range m.Parts {
176		if c, ok := part.(BinaryContent); ok {
177			binaryContents = append(binaryContents, c)
178		}
179	}
180	return binaryContents
181}
182
183func (m *Message) ToolCalls() []ToolCall {
184	toolCalls := make([]ToolCall, 0)
185	for _, part := range m.Parts {
186		if c, ok := part.(ToolCall); ok {
187			toolCalls = append(toolCalls, c)
188		}
189	}
190	return toolCalls
191}
192
193func (m *Message) ToolResults() []ToolResult {
194	toolResults := make([]ToolResult, 0)
195	for _, part := range m.Parts {
196		if c, ok := part.(ToolResult); ok {
197			toolResults = append(toolResults, c)
198		}
199	}
200	return toolResults
201}
202
203func (m *Message) IsFinished() bool {
204	for _, part := range m.Parts {
205		if _, ok := part.(Finish); ok {
206			return true
207		}
208	}
209	return false
210}
211
212func (m *Message) FinishPart() *Finish {
213	for _, part := range m.Parts {
214		if c, ok := part.(Finish); ok {
215			return &c
216		}
217	}
218	return nil
219}
220
221func (m *Message) FinishReason() FinishReason {
222	for _, part := range m.Parts {
223		if c, ok := part.(Finish); ok {
224			return c.Reason
225		}
226	}
227	return ""
228}
229
230func (m *Message) IsThinking() bool {
231	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
232		return true
233	}
234	return false
235}
236
237func (m *Message) AppendContent(delta string) {
238	found := false
239	for i, part := range m.Parts {
240		if c, ok := part.(TextContent); ok {
241			m.Parts[i] = TextContent{Text: c.Text + delta}
242			found = true
243		}
244	}
245	if !found {
246		m.Parts = append(m.Parts, TextContent{Text: delta})
247	}
248}
249
250func (m *Message) AppendReasoningContent(delta string) {
251	found := false
252	for i, part := range m.Parts {
253		if c, ok := part.(ReasoningContent); ok {
254			m.Parts[i] = ReasoningContent{
255				Thinking:   c.Thinking + delta,
256				Signature:  c.Signature,
257				StartedAt:  c.StartedAt,
258				FinishedAt: c.FinishedAt,
259			}
260			found = true
261		}
262	}
263	if !found {
264		m.Parts = append(m.Parts, ReasoningContent{
265			Thinking:  delta,
266			StartedAt: time.Now().Unix(),
267		})
268	}
269}
270
271func (m *Message) AppendReasoningSignature(signature string) {
272	for i, part := range m.Parts {
273		if c, ok := part.(ReasoningContent); ok {
274			m.Parts[i] = ReasoningContent{
275				Thinking:   c.Thinking,
276				Signature:  c.Signature + signature,
277				StartedAt:  c.StartedAt,
278				FinishedAt: c.FinishedAt,
279			}
280			return
281		}
282	}
283	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
284}
285
286func (m *Message) FinishThinking() {
287	for i, part := range m.Parts {
288		if c, ok := part.(ReasoningContent); ok {
289			if c.FinishedAt == 0 {
290				m.Parts[i] = ReasoningContent{
291					Thinking:   c.Thinking,
292					Signature:  c.Signature,
293					StartedAt:  c.StartedAt,
294					FinishedAt: time.Now().Unix(),
295				}
296			}
297			return
298		}
299	}
300}
301
302func (m *Message) ThinkingDuration() time.Duration {
303	reasoning := m.ReasoningContent()
304	if reasoning.StartedAt == 0 {
305		return 0
306	}
307
308	endTime := reasoning.FinishedAt
309	if endTime == 0 {
310		endTime = time.Now().Unix()
311	}
312
313	return time.Duration(endTime-reasoning.StartedAt) * time.Second
314}
315
316func (m *Message) FinishToolCall(toolCallID string) {
317	for i, part := range m.Parts {
318		if c, ok := part.(ToolCall); ok {
319			if c.ID == toolCallID {
320				m.Parts[i] = ToolCall{
321					ID:       c.ID,
322					Name:     c.Name,
323					Input:    c.Input,
324					Type:     c.Type,
325					Finished: true,
326				}
327				return
328			}
329		}
330	}
331}
332
333func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
334	for i, part := range m.Parts {
335		if c, ok := part.(ToolCall); ok {
336			if c.ID == toolCallID {
337				m.Parts[i] = ToolCall{
338					ID:       c.ID,
339					Name:     c.Name,
340					Input:    c.Input + inputDelta,
341					Type:     c.Type,
342					Finished: c.Finished,
343				}
344				return
345			}
346		}
347	}
348}
349
350func (m *Message) AddToolCall(tc ToolCall) {
351	for i, part := range m.Parts {
352		if c, ok := part.(ToolCall); ok {
353			if c.ID == tc.ID {
354				m.Parts[i] = tc
355				return
356			}
357		}
358	}
359	m.Parts = append(m.Parts, tc)
360}
361
362func (m *Message) SetToolCalls(tc []ToolCall) {
363	// remove any existing tool call part it could have multiple
364	parts := make([]ContentPart, 0)
365	for _, part := range m.Parts {
366		if _, ok := part.(ToolCall); ok {
367			continue
368		}
369		parts = append(parts, part)
370	}
371	m.Parts = parts
372	for _, toolCall := range tc {
373		m.Parts = append(m.Parts, toolCall)
374	}
375}
376
377func (m *Message) AddToolResult(tr ToolResult) {
378	m.Parts = append(m.Parts, tr)
379}
380
381func (m *Message) SetToolResults(tr []ToolResult) {
382	for _, toolResult := range tr {
383		m.Parts = append(m.Parts, toolResult)
384	}
385}
386
387func (m *Message) AddFinish(reason FinishReason, message, details string) {
388	// remove any existing finish part
389	for i, part := range m.Parts {
390		if _, ok := part.(Finish); ok {
391			m.Parts = slices.Delete(m.Parts, i, i+1)
392			break
393		}
394	}
395	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
396}
397
398func (m *Message) AddImageURL(url, detail string) {
399	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
400}
401
402func (m *Message) AddBinary(mimeType string, data []byte) {
403	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
404}