content.go

  1package message
  2
  3import (
  4	"encoding/base64"
  5	"errors"
  6	"slices"
  7	"time"
  8
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/fantasy/ai"
 11	"github.com/charmbracelet/fantasy/anthropic"
 12)
 13
// MessageRole identifies which participant a Message belongs to. It is the
// type of the Message.Role field.
type MessageRole string

// The roles a Message can carry.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)
 22
// FinishReason describes why a message stopped being produced. It is stored in
// the Reason field of a Finish part.
type FinishReason string

// The finish reasons known to this package.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// FinishReasonUnknown should never happen.
	FinishReasonUnknown FinishReason = "unknown"
)
 36
// ContentPart is implemented by every value that can be stored in
// Message.Parts. Its only method, isPart, is an unexported marker with no
// behavior; callers recover the concrete part with a type assertion.
type ContentPart interface {
	isPart()
}
 40
 41type ReasoningContent struct {
 42	Thinking   string `json:"thinking"`
 43	Signature  string `json:"signature"`
 44	StartedAt  int64  `json:"started_at,omitempty"`
 45	FinishedAt int64  `json:"finished_at,omitempty"`
 46}
 47
 48func (tc ReasoningContent) String() string {
 49	return tc.Thinking
 50}
 51func (ReasoningContent) isPart() {}
 52
 53type TextContent struct {
 54	Text string `json:"text"`
 55}
 56
 57func (tc TextContent) String() string {
 58	return tc.Text
 59}
 60
 61func (TextContent) isPart() {}
 62
 63type ImageURLContent struct {
 64	URL    string `json:"url"`
 65	Detail string `json:"detail,omitempty"`
 66}
 67
 68func (iuc ImageURLContent) String() string {
 69	return iuc.URL
 70}
 71
 72func (ImageURLContent) isPart() {}
 73
 74type BinaryContent struct {
 75	Path     string
 76	MIMEType string
 77	Data     []byte
 78}
 79
 80func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
 81	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 82	if p == catwalk.InferenceProviderOpenAI {
 83		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 84	}
 85	return base64Encoded
 86}
 87
 88func (BinaryContent) isPart() {}
 89
// ToolCall is the content part describing one tool invocation attached to a
// message.
//
// ID identifies the call within the message; the Message methods
// FinishToolCall, AppendToolCallInput and AddToolCall look calls up by it.
// Input is the call's argument payload as a string and may be built up in
// pieces via AppendToolCallInput. Finished is set by FinishToolCall.
// ProviderExecuted is forwarded unchanged by ToAIMessage.
type ToolCall struct {
	ID               string `json:"id"`
	Name             string `json:"name"`
	Input            string `json:"input"`
	ProviderExecuted bool   `json:"provider_executed"`
	Type             string `json:"type"`
	Finished         bool   `json:"finished"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
100
// ToolResult is the content part carrying the outcome of a tool invocation.
//
// ToolCallID is forwarded by ToAIMessage as the ID of the call being answered
// (presumably matching ToolCall.ID — confirm against callers). When IsError is
// set, Content is used as the error message. Otherwise, a non-empty Data is
// sent as media output typed by MIMEType, and Content is sent as text output
// when Data is empty.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Data       string `json:"data"`
	MIMEType   string `json:"mime_type"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
112
// Finish is the content part that marks a message as finished and records
// why. Time is a Unix timestamp in seconds (AddFinish fills it from
// time.Now().Unix()); Message and Details are optional free-form text.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
121
// Message is a single conversation message. Its payload lives in Parts, an
// ordered list of content parts (text, reasoning, images, binary data, tool
// calls, tool results and a finish marker); the methods on *Message read and
// mutate that list.
//
// NOTE(review): CreatedAt and UpdatedAt are timestamps whose unit is not
// established in this file — presumably Unix seconds like the other
// timestamps here; confirm against the code that sets them.
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	Parts     []ContentPart
	Model     string
	Provider  string
	CreatedAt int64
	UpdatedAt int64
}
132
133func (m *Message) Content() TextContent {
134	for _, part := range m.Parts {
135		if c, ok := part.(TextContent); ok {
136			return c
137		}
138	}
139	return TextContent{}
140}
141
142func (m *Message) ReasoningContent() ReasoningContent {
143	for _, part := range m.Parts {
144		if c, ok := part.(ReasoningContent); ok {
145			return c
146		}
147	}
148	return ReasoningContent{}
149}
150
151func (m *Message) ImageURLContent() []ImageURLContent {
152	imageURLContents := make([]ImageURLContent, 0)
153	for _, part := range m.Parts {
154		if c, ok := part.(ImageURLContent); ok {
155			imageURLContents = append(imageURLContents, c)
156		}
157	}
158	return imageURLContents
159}
160
161func (m *Message) BinaryContent() []BinaryContent {
162	binaryContents := make([]BinaryContent, 0)
163	for _, part := range m.Parts {
164		if c, ok := part.(BinaryContent); ok {
165			binaryContents = append(binaryContents, c)
166		}
167	}
168	return binaryContents
169}
170
171func (m *Message) ToolCalls() []ToolCall {
172	toolCalls := make([]ToolCall, 0)
173	for _, part := range m.Parts {
174		if c, ok := part.(ToolCall); ok {
175			toolCalls = append(toolCalls, c)
176		}
177	}
178	return toolCalls
179}
180
181func (m *Message) ToolResults() []ToolResult {
182	toolResults := make([]ToolResult, 0)
183	for _, part := range m.Parts {
184		if c, ok := part.(ToolResult); ok {
185			toolResults = append(toolResults, c)
186		}
187	}
188	return toolResults
189}
190
191func (m *Message) IsFinished() bool {
192	for _, part := range m.Parts {
193		if _, ok := part.(Finish); ok {
194			return true
195		}
196	}
197	return false
198}
199
200func (m *Message) FinishPart() *Finish {
201	for _, part := range m.Parts {
202		if c, ok := part.(Finish); ok {
203			return &c
204		}
205	}
206	return nil
207}
208
209func (m *Message) FinishReason() FinishReason {
210	for _, part := range m.Parts {
211		if c, ok := part.(Finish); ok {
212			return c.Reason
213		}
214	}
215	return ""
216}
217
218func (m *Message) IsThinking() bool {
219	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
220		return true
221	}
222	return false
223}
224
225func (m *Message) AppendContent(delta string) {
226	found := false
227	for i, part := range m.Parts {
228		if c, ok := part.(TextContent); ok {
229			m.Parts[i] = TextContent{Text: c.Text + delta}
230			found = true
231		}
232	}
233	if !found {
234		m.Parts = append(m.Parts, TextContent{Text: delta})
235	}
236}
237
238func (m *Message) AppendReasoningContent(delta string) {
239	found := false
240	for i, part := range m.Parts {
241		if c, ok := part.(ReasoningContent); ok {
242			m.Parts[i] = ReasoningContent{
243				Thinking:   c.Thinking + delta,
244				Signature:  c.Signature,
245				StartedAt:  c.StartedAt,
246				FinishedAt: c.FinishedAt,
247			}
248			found = true
249		}
250	}
251	if !found {
252		m.Parts = append(m.Parts, ReasoningContent{
253			Thinking:  delta,
254			StartedAt: time.Now().Unix(),
255		})
256	}
257}
258
259func (m *Message) AppendReasoningSignature(signature string) {
260	for i, part := range m.Parts {
261		if c, ok := part.(ReasoningContent); ok {
262			m.Parts[i] = ReasoningContent{
263				Thinking:   c.Thinking,
264				Signature:  c.Signature + signature,
265				StartedAt:  c.StartedAt,
266				FinishedAt: c.FinishedAt,
267			}
268			return
269		}
270	}
271	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
272}
273
274func (m *Message) FinishThinking() {
275	for i, part := range m.Parts {
276		if c, ok := part.(ReasoningContent); ok {
277			if c.FinishedAt == 0 {
278				m.Parts[i] = ReasoningContent{
279					Thinking:   c.Thinking,
280					Signature:  c.Signature,
281					StartedAt:  c.StartedAt,
282					FinishedAt: time.Now().Unix(),
283				}
284			}
285			return
286		}
287	}
288}
289
290func (m *Message) ThinkingDuration() time.Duration {
291	reasoning := m.ReasoningContent()
292	if reasoning.StartedAt == 0 {
293		return 0
294	}
295
296	endTime := reasoning.FinishedAt
297	if endTime == 0 {
298		endTime = time.Now().Unix()
299	}
300
301	return time.Duration(endTime-reasoning.StartedAt) * time.Second
302}
303
304func (m *Message) FinishToolCall(toolCallID string) {
305	for i, part := range m.Parts {
306		if c, ok := part.(ToolCall); ok {
307			if c.ID == toolCallID {
308				m.Parts[i] = ToolCall{
309					ID:       c.ID,
310					Name:     c.Name,
311					Input:    c.Input,
312					Type:     c.Type,
313					Finished: true,
314				}
315				return
316			}
317		}
318	}
319}
320
321func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
322	for i, part := range m.Parts {
323		if c, ok := part.(ToolCall); ok {
324			if c.ID == toolCallID {
325				m.Parts[i] = ToolCall{
326					ID:       c.ID,
327					Name:     c.Name,
328					Input:    c.Input + inputDelta,
329					Type:     c.Type,
330					Finished: c.Finished,
331				}
332				return
333			}
334		}
335	}
336}
337
338func (m *Message) AddToolCall(tc ToolCall) {
339	for i, part := range m.Parts {
340		if c, ok := part.(ToolCall); ok {
341			if c.ID == tc.ID {
342				m.Parts[i] = tc
343				return
344			}
345		}
346	}
347	m.Parts = append(m.Parts, tc)
348}
349
350func (m *Message) SetToolCalls(tc []ToolCall) {
351	// remove any existing tool call part it could have multiple
352	parts := make([]ContentPart, 0)
353	for _, part := range m.Parts {
354		if _, ok := part.(ToolCall); ok {
355			continue
356		}
357		parts = append(parts, part)
358	}
359	m.Parts = parts
360	for _, toolCall := range tc {
361		m.Parts = append(m.Parts, toolCall)
362	}
363}
364
365func (m *Message) AddToolResult(tr ToolResult) {
366	m.Parts = append(m.Parts, tr)
367}
368
369func (m *Message) SetToolResults(tr []ToolResult) {
370	for _, toolResult := range tr {
371		m.Parts = append(m.Parts, toolResult)
372	}
373}
374
375func (m *Message) AddFinish(reason FinishReason, message, details string) {
376	// remove any existing finish part
377	for i, part := range m.Parts {
378		if _, ok := part.(Finish); ok {
379			m.Parts = slices.Delete(m.Parts, i, i+1)
380			break
381		}
382	}
383	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
384}
385
386func (m *Message) AddImageURL(url, detail string) {
387	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
388}
389
390func (m *Message) AddBinary(mimeType string, data []byte) {
391	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
392}
393
394func (m *Message) ToAIMessage() []ai.Message {
395	var messages []ai.Message
396	switch m.Role {
397	case User:
398		var parts []ai.MessagePart
399		if m.Content().Text != "" {
400			parts = append(parts, ai.TextPart{Text: m.Content().Text})
401		}
402		for _, content := range m.BinaryContent() {
403			parts = append(parts, ai.FilePart{
404				Filename:  content.Path,
405				Data:      content.Data,
406				MediaType: content.MIMEType,
407			})
408		}
409		messages = append(messages, ai.Message{
410			Role:    ai.MessageRoleUser,
411			Content: parts,
412		})
413	case Assistant:
414		var parts []ai.MessagePart
415		if m.Content().Text != "" {
416			parts = append(parts, ai.TextPart{Text: m.Content().Text})
417		}
418		reasoning := m.ReasoningContent()
419		if reasoning.Thinking != "" {
420			reasoningPart := ai.ReasoningPart{Text: reasoning.Thinking, ProviderOptions: ai.ProviderOptions{}}
421			if reasoning.Signature != "" {
422				reasoningPart.ProviderOptions["anthropic"] = &anthropic.ReasoningOptionMetadata{
423					Signature: reasoning.Signature,
424				}
425			}
426			parts = append(parts, reasoningPart)
427		}
428		for _, call := range m.ToolCalls() {
429			parts = append(parts, ai.ToolCallPart{
430				ToolCallID:       call.ID,
431				ToolName:         call.Name,
432				Input:            call.Input,
433				ProviderExecuted: call.ProviderExecuted,
434			})
435		}
436		messages = append(messages, ai.Message{
437			Role:    ai.MessageRoleAssistant,
438			Content: parts,
439		})
440	case Tool:
441		var parts []ai.MessagePart
442		for _, result := range m.ToolResults() {
443			var content ai.ToolResultOutputContent
444			if result.IsError {
445				content = ai.ToolResultOutputContentError{
446					Error: errors.New(result.Content),
447				}
448			} else if result.Data != "" {
449				content = ai.ToolResultOutputContentMedia{
450					Data:      result.Data,
451					MediaType: result.MIMEType,
452				}
453			} else {
454				content = ai.ToolResultOutputContentText{
455					Text: result.Content,
456				}
457			}
458			parts = append(parts, ai.ToolResultPart{
459				ToolCallID: result.ToolCallID,
460				Output:     content,
461			})
462		}
463		messages = append(messages, ai.Message{
464			Role:    ai.MessageRoleTool,
465			Content: parts,
466		})
467	}
468	return messages
469}