message.go

  1package proto
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"slices"
  8	"time"
  9
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11)
 12
 13type CreateMessageParams struct {
 14	Role     MessageRole   `json:"role"`
 15	Parts    []ContentPart `json:"parts"`
 16	Model    string        `json:"model"`
 17	Provider string        `json:"provider,omitempty"`
 18}
 19
 20type Message struct {
 21	ID        string        `json:"id"`
 22	Role      MessageRole   `json:"role"`
 23	SessionID string        `json:"session_id"`
 24	Parts     []ContentPart `json:"parts"`
 25	Model     string        `json:"model"`
 26	Provider  string        `json:"provider"`
 27	CreatedAt int64         `json:"created_at"`
 28	UpdatedAt int64         `json:"updated_at"`
 29}
 30type MessageRole string
 31
 32const (
 33	Assistant MessageRole = "assistant"
 34	User      MessageRole = "user"
 35	System    MessageRole = "system"
 36	Tool      MessageRole = "tool"
 37)
 38
 39func (r MessageRole) MarshalText() ([]byte, error) {
 40	return []byte(r), nil
 41}
 42
 43func (r *MessageRole) UnmarshalText(data []byte) error {
 44	*r = MessageRole(data)
 45	return nil
 46}
 47
 48type FinishReason string
 49
 50const (
 51	FinishReasonEndTurn          FinishReason = "end_turn"
 52	FinishReasonMaxTokens        FinishReason = "max_tokens"
 53	FinishReasonToolUse          FinishReason = "tool_use"
 54	FinishReasonCanceled         FinishReason = "canceled"
 55	FinishReasonError            FinishReason = "error"
 56	FinishReasonPermissionDenied FinishReason = "permission_denied"
 57
 58	// Should never happen
 59	FinishReasonUnknown FinishReason = "unknown"
 60)
 61
 62func (fr FinishReason) MarshalText() ([]byte, error) {
 63	return []byte(fr), nil
 64}
 65
 66func (fr *FinishReason) UnmarshalText(data []byte) error {
 67	*fr = FinishReason(data)
 68	return nil
 69}
 70
 71type ContentPart interface {
 72	isPart()
 73}
 74
 75type ReasoningContent struct {
 76	Thinking   string `json:"thinking"`
 77	Signature  string `json:"signature"`
 78	StartedAt  int64  `json:"started_at,omitempty"`
 79	FinishedAt int64  `json:"finished_at,omitempty"`
 80}
 81
 82func (tc ReasoningContent) String() string {
 83	return tc.Thinking
 84}
 85func (ReasoningContent) isPart() {}
 86
 87type TextContent struct {
 88	Text string `json:"text"`
 89}
 90
 91func (tc TextContent) String() string {
 92	return tc.Text
 93}
 94
 95func (TextContent) isPart() {}
 96
 97type ImageURLContent struct {
 98	URL    string `json:"url"`
 99	Detail string `json:"detail,omitempty"`
100}
101
102func (iuc ImageURLContent) String() string {
103	return iuc.URL
104}
105
106func (ImageURLContent) isPart() {}
107
108type BinaryContent struct {
109	Path     string
110	MIMEType string
111	Data     []byte
112}
113
114func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
115	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
116	if p == catwalk.InferenceProviderOpenAI {
117		return "data:" + bc.MIMEType + ";base64," + base64Encoded
118	}
119	return base64Encoded
120}
121
122func (BinaryContent) isPart() {}
123
// ToolCall is a content part describing a tool invocation.
type ToolCall struct {
	// ID identifies the call; tool results refer to it.
	ID string `json:"id"`
	// Name is the name of the tool being called.
	Name string `json:"name"`
	// Input is the call input as a string; it may be built up incrementally.
	Input string `json:"input"`
	// Type is the kind of the call.
	Type string `json:"type"`
	// Finished reports whether the call has been marked as finished.
	Finished bool `json:"finished"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
133
// ToolResult is a content part holding the outcome of a tool call.
type ToolResult struct {
	// ToolCallID is the ID of the tool call this result belongs to.
	ToolCallID string `json:"tool_call_id"`
	// Name is the name of the tool.
	Name string `json:"name"`
	// Content is the result content.
	Content string `json:"content"`
	// Metadata is additional data attached to the result.
	Metadata string `json:"metadata"`
	// IsError reports whether the result represents an error.
	IsError bool `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
143
144type Finish struct {
145	Reason  FinishReason `json:"reason"`
146	Time    int64        `json:"time"`
147	Message string       `json:"message,omitempty"`
148	Details string       `json:"details,omitempty"`
149}
150
151func (Finish) isPart() {}
152
153// MarshalJSON implements the [json.Marshaler] interface.
154func (m Message) MarshalJSON() ([]byte, error) {
155	// We need to handle the Parts specially since they're ContentPart interfaces
156	// which can't be directly marshaled by the standard JSON package.
157	parts, err := MarshallParts(m.Parts)
158	if err != nil {
159		return nil, err
160	}
161
162	// Create an alias to avoid infinite recursion
163	type Alias Message
164	return json.Marshal(&struct {
165		Parts json.RawMessage `json:"parts"`
166		*Alias
167	}{
168		Parts: json.RawMessage(parts),
169		Alias: (*Alias)(&m),
170	})
171}
172
173// UnmarshalJSON implements the [json.Unmarshaler] interface.
174func (m *Message) UnmarshalJSON(data []byte) error {
175	// Create an alias to avoid infinite recursion
176	type Alias Message
177	aux := &struct {
178		Parts json.RawMessage `json:"parts"`
179		*Alias
180	}{
181		Alias: (*Alias)(m),
182	}
183
184	if err := json.Unmarshal(data, &aux); err != nil {
185		return err
186	}
187
188	// Unmarshal the parts using our custom function
189	parts, err := UnmarshallParts([]byte(aux.Parts))
190	if err != nil {
191		return err
192	}
193
194	m.Parts = parts
195	return nil
196}
197
198func (m *Message) Content() TextContent {
199	for _, part := range m.Parts {
200		if c, ok := part.(TextContent); ok {
201			return c
202		}
203	}
204	return TextContent{}
205}
206
207func (m *Message) ReasoningContent() ReasoningContent {
208	for _, part := range m.Parts {
209		if c, ok := part.(ReasoningContent); ok {
210			return c
211		}
212	}
213	return ReasoningContent{}
214}
215
216func (m *Message) ImageURLContent() []ImageURLContent {
217	imageURLContents := make([]ImageURLContent, 0)
218	for _, part := range m.Parts {
219		if c, ok := part.(ImageURLContent); ok {
220			imageURLContents = append(imageURLContents, c)
221		}
222	}
223	return imageURLContents
224}
225
226func (m *Message) BinaryContent() []BinaryContent {
227	binaryContents := make([]BinaryContent, 0)
228	for _, part := range m.Parts {
229		if c, ok := part.(BinaryContent); ok {
230			binaryContents = append(binaryContents, c)
231		}
232	}
233	return binaryContents
234}
235
236func (m *Message) ToolCalls() []ToolCall {
237	toolCalls := make([]ToolCall, 0)
238	for _, part := range m.Parts {
239		if c, ok := part.(ToolCall); ok {
240			toolCalls = append(toolCalls, c)
241		}
242	}
243	return toolCalls
244}
245
246func (m *Message) ToolResults() []ToolResult {
247	toolResults := make([]ToolResult, 0)
248	for _, part := range m.Parts {
249		if c, ok := part.(ToolResult); ok {
250			toolResults = append(toolResults, c)
251		}
252	}
253	return toolResults
254}
255
256func (m *Message) IsFinished() bool {
257	for _, part := range m.Parts {
258		if _, ok := part.(Finish); ok {
259			return true
260		}
261	}
262	return false
263}
264
265func (m *Message) FinishPart() *Finish {
266	for _, part := range m.Parts {
267		if c, ok := part.(Finish); ok {
268			return &c
269		}
270	}
271	return nil
272}
273
274func (m *Message) FinishReason() FinishReason {
275	for _, part := range m.Parts {
276		if c, ok := part.(Finish); ok {
277			return c.Reason
278		}
279	}
280	return ""
281}
282
283func (m *Message) IsThinking() bool {
284	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
285		return true
286	}
287	return false
288}
289
290func (m *Message) AppendContent(delta string) {
291	found := false
292	for i, part := range m.Parts {
293		if c, ok := part.(TextContent); ok {
294			m.Parts[i] = TextContent{Text: c.Text + delta}
295			found = true
296		}
297	}
298	if !found {
299		m.Parts = append(m.Parts, TextContent{Text: delta})
300	}
301}
302
303func (m *Message) AppendReasoningContent(delta string) {
304	found := false
305	for i, part := range m.Parts {
306		if c, ok := part.(ReasoningContent); ok {
307			m.Parts[i] = ReasoningContent{
308				Thinking:   c.Thinking + delta,
309				Signature:  c.Signature,
310				StartedAt:  c.StartedAt,
311				FinishedAt: c.FinishedAt,
312			}
313			found = true
314		}
315	}
316	if !found {
317		m.Parts = append(m.Parts, ReasoningContent{
318			Thinking:  delta,
319			StartedAt: time.Now().Unix(),
320		})
321	}
322}
323
324func (m *Message) AppendReasoningSignature(signature string) {
325	for i, part := range m.Parts {
326		if c, ok := part.(ReasoningContent); ok {
327			m.Parts[i] = ReasoningContent{
328				Thinking:   c.Thinking,
329				Signature:  c.Signature + signature,
330				StartedAt:  c.StartedAt,
331				FinishedAt: c.FinishedAt,
332			}
333			return
334		}
335	}
336	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
337}
338
339func (m *Message) FinishThinking() {
340	for i, part := range m.Parts {
341		if c, ok := part.(ReasoningContent); ok {
342			if c.FinishedAt == 0 {
343				m.Parts[i] = ReasoningContent{
344					Thinking:   c.Thinking,
345					Signature:  c.Signature,
346					StartedAt:  c.StartedAt,
347					FinishedAt: time.Now().Unix(),
348				}
349			}
350			return
351		}
352	}
353}
354
355func (m *Message) ThinkingDuration() time.Duration {
356	reasoning := m.ReasoningContent()
357	if reasoning.StartedAt == 0 {
358		return 0
359	}
360
361	endTime := reasoning.FinishedAt
362	if endTime == 0 {
363		endTime = time.Now().Unix()
364	}
365
366	return time.Duration(endTime-reasoning.StartedAt) * time.Second
367}
368
369func (m *Message) FinishToolCall(toolCallID string) {
370	for i, part := range m.Parts {
371		if c, ok := part.(ToolCall); ok {
372			if c.ID == toolCallID {
373				m.Parts[i] = ToolCall{
374					ID:       c.ID,
375					Name:     c.Name,
376					Input:    c.Input,
377					Type:     c.Type,
378					Finished: true,
379				}
380				return
381			}
382		}
383	}
384}
385
386func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
387	for i, part := range m.Parts {
388		if c, ok := part.(ToolCall); ok {
389			if c.ID == toolCallID {
390				m.Parts[i] = ToolCall{
391					ID:       c.ID,
392					Name:     c.Name,
393					Input:    c.Input + inputDelta,
394					Type:     c.Type,
395					Finished: c.Finished,
396				}
397				return
398			}
399		}
400	}
401}
402
403func (m *Message) AddToolCall(tc ToolCall) {
404	for i, part := range m.Parts {
405		if c, ok := part.(ToolCall); ok {
406			if c.ID == tc.ID {
407				m.Parts[i] = tc
408				return
409			}
410		}
411	}
412	m.Parts = append(m.Parts, tc)
413}
414
415func (m *Message) SetToolCalls(tc []ToolCall) {
416	// remove any existing tool call part it could have multiple
417	parts := make([]ContentPart, 0)
418	for _, part := range m.Parts {
419		if _, ok := part.(ToolCall); ok {
420			continue
421		}
422		parts = append(parts, part)
423	}
424	m.Parts = parts
425	for _, toolCall := range tc {
426		m.Parts = append(m.Parts, toolCall)
427	}
428}
429
430func (m *Message) AddToolResult(tr ToolResult) {
431	m.Parts = append(m.Parts, tr)
432}
433
434func (m *Message) SetToolResults(tr []ToolResult) {
435	for _, toolResult := range tr {
436		m.Parts = append(m.Parts, toolResult)
437	}
438}
439
440func (m *Message) AddFinish(reason FinishReason, message, details string) {
441	// remove any existing finish part
442	for i, part := range m.Parts {
443		if _, ok := part.(Finish); ok {
444			m.Parts = slices.Delete(m.Parts, i, i+1)
445			break
446		}
447	}
448	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
449}
450
451func (m *Message) AddImageURL(url, detail string) {
452	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
453}
454
455func (m *Message) AddBinary(mimeType string, data []byte) {
456	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
457}
458
// partType is the tag written next to each encoded content part so that the
// concrete part type can be recovered when decoding.
type partType string

// Tags for every supported content part type.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
470
471type partWrapper struct {
472	Type partType    `json:"type"`
473	Data ContentPart `json:"data"`
474}
475
476func MarshallParts(parts []ContentPart) ([]byte, error) {
477	wrappedParts := make([]partWrapper, len(parts))
478
479	for i, part := range parts {
480		var typ partType
481
482		switch part.(type) {
483		case ReasoningContent:
484			typ = reasoningType
485		case TextContent:
486			typ = textType
487		case ImageURLContent:
488			typ = imageURLType
489		case BinaryContent:
490			typ = binaryType
491		case ToolCall:
492			typ = toolCallType
493		case ToolResult:
494			typ = toolResultType
495		case Finish:
496			typ = finishType
497		default:
498			return nil, fmt.Errorf("unknown part type: %T", part)
499		}
500
501		wrappedParts[i] = partWrapper{
502			Type: typ,
503			Data: part,
504		}
505	}
506	return json.Marshal(wrappedParts)
507}
508
509func UnmarshallParts(data []byte) ([]ContentPart, error) {
510	temp := []json.RawMessage{}
511
512	if err := json.Unmarshal(data, &temp); err != nil {
513		return nil, err
514	}
515
516	parts := make([]ContentPart, 0)
517
518	for _, rawPart := range temp {
519		var wrapper struct {
520			Type partType        `json:"type"`
521			Data json.RawMessage `json:"data"`
522		}
523
524		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
525			return nil, err
526		}
527
528		switch wrapper.Type {
529		case reasoningType:
530			part := ReasoningContent{}
531			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
532				return nil, err
533			}
534			parts = append(parts, part)
535		case textType:
536			part := TextContent{}
537			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
538				return nil, err
539			}
540			parts = append(parts, part)
541		case imageURLType:
542			part := ImageURLContent{}
543			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
544				return nil, err
545			}
546		case binaryType:
547			part := BinaryContent{}
548			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
549				return nil, err
550			}
551			parts = append(parts, part)
552		case toolCallType:
553			part := ToolCall{}
554			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
555				return nil, err
556			}
557			parts = append(parts, part)
558		case toolResultType:
559			part := ToolResult{}
560			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
561				return nil, err
562			}
563			parts = append(parts, part)
564		case finishType:
565			part := Finish{}
566			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
567				return nil, err
568			}
569			parts = append(parts, part)
570		default:
571			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
572		}
573	}
574
575	return parts, nil
576}
577
// Attachment describes a file by path, name and MIME type together with its
// raw content. In JSON the content is carried as a standard base64 string.
type Attachment struct {
	// FilePath is the path of the file.
	FilePath string `json:"file_path"`
	// FileName is the name of the file.
	FileName string `json:"file_name"`
	// MimeType is the media type of Content.
	MimeType string `json:"mime_type"`
	// Content is the raw file content.
	Content []byte `json:"content"`
}

// MarshalJSON implements the [json.Marshaler] interface. Content is written
// as a standard base64 string under the "content" key.
func (a Attachment) MarshalJSON() ([]byte, error) {
	// Alias has Attachment's fields but none of its methods, which keeps
	// json.Marshal from calling back into this method.
	type Alias Attachment
	return json.Marshal(&struct {
		// Content shadows the embedded Alias.Content because it is less nested.
		Content string `json:"content"`
		*Alias
	}{
		Content: base64.StdEncoding.EncodeToString(a.Content),
		Alias:   (*Alias)(&a),
	})
}

// UnmarshalJSON implements the [json.Unmarshaler] interface. The "content"
// string is decoded from standard base64 into Content. A JSON null is a
// no-op, following the json.Unmarshaler convention. An error is returned
// when data is not a valid attachment object or the content is not valid
// base64.
func (a *Attachment) UnmarshalJSON(data []byte) error {
	// encoding/json hands the bare literal to Unmarshalers. Without this
	// guard a null would previously nil out the helper pointer and panic.
	if string(data) == "null" {
		return nil
	}

	// Alias has Attachment's fields but none of its methods, which keeps
	// json.Unmarshal from calling back into this method.
	type Alias Attachment
	aux := struct {
		// Content shadows the embedded Alias.Content because it is less nested.
		Content string `json:"content"`
		*Alias
	}{
		Alias: (*Alias)(a),
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	content, err := base64.StdEncoding.DecodeString(aux.Content)
	if err != nil {
		return err
	}
	a.Content = content
	return nil
}