content.go

  1package message
  2
  3import (
  4	"encoding/base64"
  5	"slices"
  6	"time"
  7
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9)
 10
 // MessageRole identifies who authored a Message (see Message.Role).
type MessageRole string

// The roles a Message can carry.
const (
	System    MessageRole = "system"
	User      MessageRole = "user"
	Assistant MessageRole = "assistant"
	Tool      MessageRole = "tool"
)
 19
 // FinishReason says why a message finished; it is stored in a Finish part.
type FinishReason string

// Known finish reasons.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"
)

// FinishReasonUnknown is a fallback value that is not expected to occur in
// practice.
const FinishReasonUnknown FinishReason = "unknown"
 33
 // ContentPart is one piece of a Message's content. The unexported isPart
// marker restricts implementations to this package's part types (or types
// that embed them).
type ContentPart interface {
	// isPart is a marker method with no behavior.
	isPart()
}
 37
 // ReasoningContent holds reasoning ("thinking") text together with its
// signature and start/finish timestamps.
type ReasoningContent struct {
	Thinking  string `json:"thinking"`
	Signature string `json:"signature"`
	// StartedAt is a Unix timestamp in seconds; 0 means it was never stamped.
	StartedAt int64 `json:"started_at,omitempty"`
	// FinishedAt is a Unix timestamp in seconds; 0 means reasoning is ongoing.
	FinishedAt int64 `json:"finished_at,omitempty"`
}

// String returns the reasoning text.
func (rc ReasoningContent) String() string { return rc.Thinking }

func (ReasoningContent) isPart() {}
 49
 // TextContent is a plain-text part of a message.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text unchanged.
func (c TextContent) String() string { return c.Text }

func (TextContent) isPart() {}
 59
 // ImageURLContent references an image by URL. Detail is an optional setting
// that is omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL.
func (c ImageURLContent) String() string { return c.URL }

func (ImageURLContent) isPart() {}
 70
 71type BinaryContent struct {
 72	Path     string
 73	MIMEType string
 74	Data     []byte
 75}
 76
 77func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
 78	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 79	if p == catwalk.InferenceProviderOpenAI {
 80		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 81	}
 82	return base64Encoded
 83}
 84
 85func (BinaryContent) isPart() {}
 86
 // ToolCall is a tool invocation carried in a message. Input holds the call's
// argument text, which AppendToolCallInput can extend incrementally, and
// Finished becomes true once FinishToolCall marks the call complete.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
 96
 // ToolResult is the output of a tool call, linked back to the call through
// ToolCallID. IsError flags an error result. Metadata is an opaque string
// whose format is not defined here.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
106
107type Finish struct {
108	Reason  FinishReason `json:"reason"`
109	Time    int64        `json:"time"`
110	Message string       `json:"message,omitempty"`
111	Details string       `json:"details,omitempty"`
112}
113
114func (Finish) isPart() {}
115
// Retry records one retry attempt: the error text, the RetryAfter value as
// supplied by the caller (units not defined here), and a Unix-seconds
// Timestamp.
type Retry struct {
	Error      string `json:"error"`
	RetryAfter int64  `json:"retry_after"`
	Timestamp  int64  `json:"timestamp"`
}

// RetryContent is the message part that accumulates Retry entries and tracks
// whether a retry is currently in progress.
type RetryContent struct {
	Retries  []Retry `json:"retries"`
	Retrying bool    `json:"retrying"`
}

func (RetryContent) isPart() {}
128
129type Message struct {
130	ID        string
131	Role      MessageRole
132	SessionID string
133	Parts     []ContentPart
134	Model     string
135	Provider  string
136	CreatedAt int64
137	UpdatedAt int64
138}
139
140func (m *Message) Content() TextContent {
141	for _, part := range m.Parts {
142		if c, ok := part.(TextContent); ok {
143			return c
144		}
145	}
146	return TextContent{}
147}
148
149func (m *Message) ReasoningContent() ReasoningContent {
150	for _, part := range m.Parts {
151		if c, ok := part.(ReasoningContent); ok {
152			return c
153		}
154	}
155	return ReasoningContent{}
156}
157
158func (m *Message) ImageURLContent() []ImageURLContent {
159	imageURLContents := make([]ImageURLContent, 0)
160	for _, part := range m.Parts {
161		if c, ok := part.(ImageURLContent); ok {
162			imageURLContents = append(imageURLContents, c)
163		}
164	}
165	return imageURLContents
166}
167
168func (m *Message) BinaryContent() []BinaryContent {
169	binaryContents := make([]BinaryContent, 0)
170	for _, part := range m.Parts {
171		if c, ok := part.(BinaryContent); ok {
172			binaryContents = append(binaryContents, c)
173		}
174	}
175	return binaryContents
176}
177
178func (m *Message) ToolCalls() []ToolCall {
179	toolCalls := make([]ToolCall, 0)
180	for _, part := range m.Parts {
181		if c, ok := part.(ToolCall); ok {
182			toolCalls = append(toolCalls, c)
183		}
184	}
185	return toolCalls
186}
187
188func (m *Message) ToolResults() []ToolResult {
189	toolResults := make([]ToolResult, 0)
190	for _, part := range m.Parts {
191		if c, ok := part.(ToolResult); ok {
192			toolResults = append(toolResults, c)
193		}
194	}
195	return toolResults
196}
197
198func (m *Message) IsFinished() bool {
199	for _, part := range m.Parts {
200		if _, ok := part.(Finish); ok {
201			return true
202		}
203	}
204	return false
205}
206
207func (m *Message) FinishPart() *Finish {
208	for _, part := range m.Parts {
209		if c, ok := part.(Finish); ok {
210			return &c
211		}
212	}
213	return nil
214}
215
216func (m *Message) FinishReason() FinishReason {
217	for _, part := range m.Parts {
218		if c, ok := part.(Finish); ok {
219			return c.Reason
220		}
221	}
222	return ""
223}
224
225func (m *Message) IsThinking() bool {
226	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
227		return true
228	}
229	return false
230}
231
232func (m *Message) AppendContent(delta string) {
233	found := false
234	for i, part := range m.Parts {
235		if c, ok := part.(TextContent); ok {
236			m.Parts[i] = TextContent{Text: c.Text + delta}
237			found = true
238		}
239	}
240	if !found {
241		m.Parts = append(m.Parts, TextContent{Text: delta})
242	}
243}
244
245func (m *Message) AppendReasoningContent(delta string) {
246	found := false
247	for i, part := range m.Parts {
248		if c, ok := part.(ReasoningContent); ok {
249			m.Parts[i] = ReasoningContent{
250				Thinking:   c.Thinking + delta,
251				Signature:  c.Signature,
252				StartedAt:  c.StartedAt,
253				FinishedAt: c.FinishedAt,
254			}
255			found = true
256		}
257	}
258	if !found {
259		m.Parts = append(m.Parts, ReasoningContent{
260			Thinking:  delta,
261			StartedAt: time.Now().Unix(),
262		})
263	}
264}
265
266func (m *Message) AppendReasoningSignature(signature string) {
267	for i, part := range m.Parts {
268		if c, ok := part.(ReasoningContent); ok {
269			m.Parts[i] = ReasoningContent{
270				Thinking:   c.Thinking,
271				Signature:  c.Signature + signature,
272				StartedAt:  c.StartedAt,
273				FinishedAt: c.FinishedAt,
274			}
275			return
276		}
277	}
278	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
279}
280
281func (m *Message) FinishThinking() {
282	for i, part := range m.Parts {
283		if c, ok := part.(ReasoningContent); ok {
284			if c.FinishedAt == 0 {
285				m.Parts[i] = ReasoningContent{
286					Thinking:   c.Thinking,
287					Signature:  c.Signature,
288					StartedAt:  c.StartedAt,
289					FinishedAt: time.Now().Unix(),
290				}
291			}
292			return
293		}
294	}
295}
296
297func (m *Message) ThinkingDuration() time.Duration {
298	reasoning := m.ReasoningContent()
299	if reasoning.StartedAt == 0 {
300		return 0
301	}
302
303	endTime := reasoning.FinishedAt
304	if endTime == 0 {
305		endTime = time.Now().Unix()
306	}
307
308	return time.Duration(endTime-reasoning.StartedAt) * time.Second
309}
310
311func (m *Message) FinishToolCall(toolCallID string) {
312	for i, part := range m.Parts {
313		if c, ok := part.(ToolCall); ok {
314			if c.ID == toolCallID {
315				m.Parts[i] = ToolCall{
316					ID:       c.ID,
317					Name:     c.Name,
318					Input:    c.Input,
319					Type:     c.Type,
320					Finished: true,
321				}
322				return
323			}
324		}
325	}
326}
327
328func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
329	for i, part := range m.Parts {
330		if c, ok := part.(ToolCall); ok {
331			if c.ID == toolCallID {
332				m.Parts[i] = ToolCall{
333					ID:       c.ID,
334					Name:     c.Name,
335					Input:    c.Input + inputDelta,
336					Type:     c.Type,
337					Finished: c.Finished,
338				}
339				return
340			}
341		}
342	}
343}
344
345func (m *Message) AddToolCall(tc ToolCall) {
346	for i, part := range m.Parts {
347		if c, ok := part.(ToolCall); ok {
348			if c.ID == tc.ID {
349				m.Parts[i] = tc
350				return
351			}
352		}
353	}
354	m.Parts = append(m.Parts, tc)
355}
356
357func (m *Message) SetToolCalls(tc []ToolCall) {
358	// remove any existing tool call part it could have multiple
359	parts := make([]ContentPart, 0)
360	for _, part := range m.Parts {
361		if _, ok := part.(ToolCall); ok {
362			continue
363		}
364		parts = append(parts, part)
365	}
366	m.Parts = parts
367	for _, toolCall := range tc {
368		m.Parts = append(m.Parts, toolCall)
369	}
370}
371
372func (m *Message) AddToolResult(tr ToolResult) {
373	m.Parts = append(m.Parts, tr)
374}
375
376func (m *Message) SetToolResults(tr []ToolResult) {
377	for _, toolResult := range tr {
378		m.Parts = append(m.Parts, toolResult)
379	}
380}
381
382func (m *Message) AddFinish(reason FinishReason, message, details string) {
383	// remove any existing finish part
384	for i, part := range m.Parts {
385		if _, ok := part.(Finish); ok {
386			m.Parts = slices.Delete(m.Parts, i, i+1)
387			break
388		}
389	}
390	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
391}
392
393func (m *Message) AddImageURL(url, detail string) {
394	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
395}
396
397func (m *Message) AddBinary(mimeType string, data []byte) {
398	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
399}
400
401func (m *Message) RetryContent() *RetryContent {
402	for _, part := range m.Parts {
403		if c, ok := part.(RetryContent); ok {
404			return &c
405		}
406	}
407	return nil
408}
409
410func (m *Message) AddRetry(error string, retryAfter int64) {
411	retry := Retry{
412		Error:      error,
413		RetryAfter: retryAfter,
414		Timestamp:  time.Now().Unix(),
415	}
416
417	found := false
418	for i, part := range m.Parts {
419		if c, ok := part.(RetryContent); ok {
420			m.Parts[i] = RetryContent{
421				Retries:  append(c.Retries, retry),
422				Retrying: c.Retrying,
423			}
424			found = true
425			break
426		}
427	}
428	if !found {
429		m.Parts = append(m.Parts, RetryContent{
430			Retries:  []Retry{retry},
431			Retrying: false,
432		})
433	}
434}
435
436func (m *Message) SetRetrying(retrying bool) {
437	found := false
438	for i, part := range m.Parts {
439		if c, ok := part.(RetryContent); ok {
440			m.Parts[i] = RetryContent{
441				Retries:  c.Retries,
442				Retrying: retrying,
443			}
444			found = true
445			break
446		}
447	}
448	if !found && retrying {
449		m.Parts = append(m.Parts, RetryContent{
450			Retries:  []Retry{},
451			Retrying: retrying,
452		})
453	}
454}
455
456func (m *Message) IsRetrying() bool {
457	if retryContent := m.RetryContent(); retryContent != nil {
458		return retryContent.Retrying
459	}
460	return false
461}
462
463func (m *Message) GetRetries() []Retry {
464	if retryContent := m.RetryContent(); retryContent != nil {
465		return retryContent.Retries
466	}
467	return []Retry{}
468}