message.go

  1package proto
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"slices"
  8	"time"
  9
 10	"charm.land/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/message"
 12)
 13
// CreateMessageParams represents parameters for creating a message.
//
// Role, Parts and Model are always written to the JSON object; Provider is
// left out when it is the empty string. Parts holds [ContentPart] interface
// values and this type declares no custom JSON methods.
// NOTE(review): without a custom decoder a non-empty "parts" array presumably
// cannot be decoded into the interface slice — confirm how callers decode it.
type CreateMessageParams struct {
	Role     MessageRole   `json:"role"`
	Parts    []ContentPart `json:"parts"`
	Model    string        `json:"model"`
	Provider string        `json:"provider,omitempty"`
}
 21
// Message represents a message in the proto layer.
//
// Parts holds the message content as [ContentPart] values. JSON encoding and
// decoding go through [Message.MarshalJSON] and [Message.UnmarshalJSON],
// which route Parts through [MarshalParts] and [UnmarshalParts].
// NOTE(review): CreatedAt and UpdatedAt are presumably Unix seconds like the
// other timestamps in this file — confirm with the producer.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
 33
 34// MessageRole represents the role of a message sender.
 35type MessageRole string
 36
 37const (
 38	Assistant MessageRole = "assistant"
 39	User      MessageRole = "user"
 40	System    MessageRole = "system"
 41	Tool      MessageRole = "tool"
 42)
 43
 44// MarshalText implements the [encoding.TextMarshaler] interface.
 45func (r MessageRole) MarshalText() ([]byte, error) {
 46	return []byte(r), nil
 47}
 48
 49// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 50func (r *MessageRole) UnmarshalText(data []byte) error {
 51	*r = MessageRole(data)
 52	return nil
 53}
 54
 55// FinishReason represents why a message generation finished.
 56type FinishReason string
 57
 58const (
 59	FinishReasonEndTurn   FinishReason = "end_turn"
 60	FinishReasonMaxTokens FinishReason = "max_tokens"
 61	FinishReasonToolUse   FinishReason = "tool_use"
 62	FinishReasonCanceled  FinishReason = "canceled"
 63	FinishReasonError     FinishReason = "error"
 64	FinishReasonUnknown   FinishReason = "unknown"
 65)
 66
 67// MarshalText implements the [encoding.TextMarshaler] interface.
 68func (fr FinishReason) MarshalText() ([]byte, error) {
 69	return []byte(fr), nil
 70}
 71
 72// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 73func (fr *FinishReason) UnmarshalText(data []byte) error {
 74	*fr = FinishReason(data)
 75	return nil
 76}
 77
// ContentPart is a part of a message's content.
//
// The only method is the unexported marker isPart, so types outside this
// package cannot declare it themselves.
type ContentPart interface {
	// isPart is a marker with no behavior; every implementation in this file
	// has an empty body.
	isPart()
}
 82
 83// ReasoningContent represents the reasoning/thinking part of a message.
 84type ReasoningContent struct {
 85	Thinking   string `json:"thinking"`
 86	Signature  string `json:"signature"`
 87	StartedAt  int64  `json:"started_at,omitempty"`
 88	FinishedAt int64  `json:"finished_at,omitempty"`
 89}
 90
 91// String returns the thinking content as a string.
 92func (tc ReasoningContent) String() string {
 93	return tc.Thinking
 94}
 95
 96func (ReasoningContent) isPart() {}
 97
 98// TextContent represents a text part of a message.
 99type TextContent struct {
100	Text string `json:"text"`
101}
102
103// String returns the text content as a string.
104func (tc TextContent) String() string {
105	return tc.Text
106}
107
108func (TextContent) isPart() {}
109
// ImageURLContent represents an image URL part of a message.
//
// Detail is left out of the JSON encoding when it is empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the URL field; Detail is not included.
func (iuc ImageURLContent) String() string {
	url := iuc.URL
	return url
}

// isPart marks ImageURLContent as a message content part.
func (ImageURLContent) isPart() {}
122
123// BinaryContent represents binary data in a message.
124type BinaryContent struct {
125	Path     string
126	MIMEType string
127	Data     []byte
128}
129
130// String returns a base64-encoded string of the binary data.
131func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
132	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
133	if p == catwalk.InferenceProviderOpenAI {
134		return "data:" + bc.MIMEType + ";base64," + base64Encoded
135	}
136	return base64Encoded
137}
138
139func (BinaryContent) isPart() {}
140
// ToolCall represents a tool call in a message.
//
// ID is the key the Message helpers match on ([Message.AddToolCall],
// [Message.AppendToolCallInput], [Message.FinishToolCall]). Input is built up
// by concatenating deltas and Finished is set by FinishToolCall. Type and
// Finished are left out of the JSON encoding when empty/false.
// NOTE(review): Input presumably carries the tool arguments as JSON text —
// confirm with the producer.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`
	Finished bool   `json:"finished,omitempty"`
}

// isPart marks ToolCall as a message content part.
func (ToolCall) isPart() {}
151
// ToolResult represents the result of a tool call.
//
// Data and MIMEType are left out of the JSON encoding when empty; every
// other field, including Metadata and IsError, is always written.
// NOTE(review): ToolCallID presumably matches the ID of the originating
// [ToolCall], and the encoding of Data is not visible here — confirm both.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Data       string `json:"data,omitempty"`
	MIMEType   string `json:"mime_type,omitempty"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a message content part.
func (ToolResult) isPart() {}
164
// Finish represents the end of a message generation.
//
// Time is a Unix timestamp in seconds when the part is created through
// [Message.AddFinish]. Message and Details are left out of the JSON encoding
// when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a message content part.
func (Finish) isPart() {}
174
175// MarshalJSON implements the [json.Marshaler] interface.
176func (m Message) MarshalJSON() ([]byte, error) {
177	parts, err := MarshalParts(m.Parts)
178	if err != nil {
179		return nil, err
180	}
181
182	type Alias Message
183	return json.Marshal(&struct {
184		Parts json.RawMessage `json:"parts"`
185		*Alias
186	}{
187		Parts: json.RawMessage(parts),
188		Alias: (*Alias)(&m),
189	})
190}
191
192// UnmarshalJSON implements the [json.Unmarshaler] interface.
193func (m *Message) UnmarshalJSON(data []byte) error {
194	type Alias Message
195	aux := &struct {
196		Parts json.RawMessage `json:"parts"`
197		*Alias
198	}{
199		Alias: (*Alias)(m),
200	}
201
202	if err := json.Unmarshal(data, &aux); err != nil {
203		return err
204	}
205
206	parts, err := UnmarshalParts([]byte(aux.Parts))
207	if err != nil {
208		return err
209	}
210
211	m.Parts = parts
212	return nil
213}
214
215// Content returns the first text content part.
216func (m *Message) Content() TextContent {
217	for _, part := range m.Parts {
218		if c, ok := part.(TextContent); ok {
219			return c
220		}
221	}
222	return TextContent{}
223}
224
225// ReasoningContent returns the first reasoning content part.
226func (m *Message) ReasoningContent() ReasoningContent {
227	for _, part := range m.Parts {
228		if c, ok := part.(ReasoningContent); ok {
229			return c
230		}
231	}
232	return ReasoningContent{}
233}
234
235// ImageURLContent returns all image URL content parts.
236func (m *Message) ImageURLContent() []ImageURLContent {
237	imageURLContents := make([]ImageURLContent, 0)
238	for _, part := range m.Parts {
239		if c, ok := part.(ImageURLContent); ok {
240			imageURLContents = append(imageURLContents, c)
241		}
242	}
243	return imageURLContents
244}
245
246// BinaryContent returns all binary content parts.
247func (m *Message) BinaryContent() []BinaryContent {
248	binaryContents := make([]BinaryContent, 0)
249	for _, part := range m.Parts {
250		if c, ok := part.(BinaryContent); ok {
251			binaryContents = append(binaryContents, c)
252		}
253	}
254	return binaryContents
255}
256
257// ToolCalls returns all tool call parts.
258func (m *Message) ToolCalls() []ToolCall {
259	toolCalls := make([]ToolCall, 0)
260	for _, part := range m.Parts {
261		if c, ok := part.(ToolCall); ok {
262			toolCalls = append(toolCalls, c)
263		}
264	}
265	return toolCalls
266}
267
268// ToolResults returns all tool result parts.
269func (m *Message) ToolResults() []ToolResult {
270	toolResults := make([]ToolResult, 0)
271	for _, part := range m.Parts {
272		if c, ok := part.(ToolResult); ok {
273			toolResults = append(toolResults, c)
274		}
275	}
276	return toolResults
277}
278
279// IsFinished returns true if the message has a finish part.
280func (m *Message) IsFinished() bool {
281	for _, part := range m.Parts {
282		if _, ok := part.(Finish); ok {
283			return true
284		}
285	}
286	return false
287}
288
289// FinishPart returns the finish part if present.
290func (m *Message) FinishPart() *Finish {
291	for _, part := range m.Parts {
292		if c, ok := part.(Finish); ok {
293			return &c
294		}
295	}
296	return nil
297}
298
299// FinishReason returns the finish reason if present.
300func (m *Message) FinishReason() FinishReason {
301	for _, part := range m.Parts {
302		if c, ok := part.(Finish); ok {
303			return c.Reason
304		}
305	}
306	return ""
307}
308
309// IsThinking returns true if the message is currently in a thinking state.
310func (m *Message) IsThinking() bool {
311	return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
312}
313
314// AppendContent appends text to the text content part.
315func (m *Message) AppendContent(delta string) {
316	found := false
317	for i, part := range m.Parts {
318		if c, ok := part.(TextContent); ok {
319			m.Parts[i] = TextContent{Text: c.Text + delta}
320			found = true
321		}
322	}
323	if !found {
324		m.Parts = append(m.Parts, TextContent{Text: delta})
325	}
326}
327
328// AppendReasoningContent appends text to the reasoning content part.
329func (m *Message) AppendReasoningContent(delta string) {
330	found := false
331	for i, part := range m.Parts {
332		if c, ok := part.(ReasoningContent); ok {
333			m.Parts[i] = ReasoningContent{
334				Thinking:   c.Thinking + delta,
335				Signature:  c.Signature,
336				StartedAt:  c.StartedAt,
337				FinishedAt: c.FinishedAt,
338			}
339			found = true
340		}
341	}
342	if !found {
343		m.Parts = append(m.Parts, ReasoningContent{
344			Thinking:  delta,
345			StartedAt: time.Now().Unix(),
346		})
347	}
348}
349
350// AppendReasoningSignature appends a signature to the reasoning content part.
351func (m *Message) AppendReasoningSignature(signature string) {
352	for i, part := range m.Parts {
353		if c, ok := part.(ReasoningContent); ok {
354			m.Parts[i] = ReasoningContent{
355				Thinking:   c.Thinking,
356				Signature:  c.Signature + signature,
357				StartedAt:  c.StartedAt,
358				FinishedAt: c.FinishedAt,
359			}
360			return
361		}
362	}
363	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
364}
365
366// FinishThinking marks the reasoning content as finished.
367func (m *Message) FinishThinking() {
368	for i, part := range m.Parts {
369		if c, ok := part.(ReasoningContent); ok {
370			if c.FinishedAt == 0 {
371				m.Parts[i] = ReasoningContent{
372					Thinking:   c.Thinking,
373					Signature:  c.Signature,
374					StartedAt:  c.StartedAt,
375					FinishedAt: time.Now().Unix(),
376				}
377			}
378			return
379		}
380	}
381}
382
383// ThinkingDuration returns the duration of the thinking phase.
384func (m *Message) ThinkingDuration() time.Duration {
385	reasoning := m.ReasoningContent()
386	if reasoning.StartedAt == 0 {
387		return 0
388	}
389
390	endTime := reasoning.FinishedAt
391	if endTime == 0 {
392		endTime = time.Now().Unix()
393	}
394
395	return time.Duration(endTime-reasoning.StartedAt) * time.Second
396}
397
398// FinishToolCall marks a tool call as finished.
399func (m *Message) FinishToolCall(toolCallID string) {
400	for i, part := range m.Parts {
401		if c, ok := part.(ToolCall); ok {
402			if c.ID == toolCallID {
403				m.Parts[i] = ToolCall{
404					ID:       c.ID,
405					Name:     c.Name,
406					Input:    c.Input,
407					Type:     c.Type,
408					Finished: true,
409				}
410				return
411			}
412		}
413	}
414}
415
416// AppendToolCallInput appends input to a tool call.
417func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
418	for i, part := range m.Parts {
419		if c, ok := part.(ToolCall); ok {
420			if c.ID == toolCallID {
421				m.Parts[i] = ToolCall{
422					ID:       c.ID,
423					Name:     c.Name,
424					Input:    c.Input + inputDelta,
425					Type:     c.Type,
426					Finished: c.Finished,
427				}
428				return
429			}
430		}
431	}
432}
433
434// AddToolCall adds or updates a tool call.
435func (m *Message) AddToolCall(tc ToolCall) {
436	for i, part := range m.Parts {
437		if c, ok := part.(ToolCall); ok {
438			if c.ID == tc.ID {
439				m.Parts[i] = tc
440				return
441			}
442		}
443	}
444	m.Parts = append(m.Parts, tc)
445}
446
447// SetToolCalls replaces all tool call parts.
448func (m *Message) SetToolCalls(tc []ToolCall) {
449	parts := make([]ContentPart, 0)
450	for _, part := range m.Parts {
451		if _, ok := part.(ToolCall); ok {
452			continue
453		}
454		parts = append(parts, part)
455	}
456	m.Parts = parts
457	for _, toolCall := range tc {
458		m.Parts = append(m.Parts, toolCall)
459	}
460}
461
462// AddToolResult adds a tool result.
463func (m *Message) AddToolResult(tr ToolResult) {
464	m.Parts = append(m.Parts, tr)
465}
466
467// SetToolResults adds multiple tool results.
468func (m *Message) SetToolResults(tr []ToolResult) {
469	for _, toolResult := range tr {
470		m.Parts = append(m.Parts, toolResult)
471	}
472}
473
474// AddFinish adds a finish part to the message.
475func (m *Message) AddFinish(reason FinishReason, message, details string) {
476	for i, part := range m.Parts {
477		if _, ok := part.(Finish); ok {
478			m.Parts = slices.Delete(m.Parts, i, i+1)
479			break
480		}
481	}
482	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
483}
484
485// AddImageURL adds an image URL part to the message.
486func (m *Message) AddImageURL(url, detail string) {
487	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
488}
489
490// AddBinary adds a binary content part to the message.
491func (m *Message) AddBinary(mimeType string, data []byte) {
492	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
493}
494
// partType is the discriminator stored in the "type" field of an encoded
// content part. MarshalParts writes it and UnmarshalParts switches on it to
// pick the concrete ContentPart type.
type partType string

// One partType per concrete ContentPart: ReasoningContent, TextContent,
// ImageURLContent, BinaryContent, ToolCall, ToolResult and Finish, in that
// order.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
506
// partWrapper is the JSON envelope MarshalParts writes for each content
// part: Type names the concrete kind and Data holds the part itself.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
511
512// MarshalParts marshals content parts to JSON.
513func MarshalParts(parts []ContentPart) ([]byte, error) {
514	wrappedParts := make([]partWrapper, len(parts))
515
516	for i, part := range parts {
517		var typ partType
518
519		switch part.(type) {
520		case ReasoningContent:
521			typ = reasoningType
522		case TextContent:
523			typ = textType
524		case ImageURLContent:
525			typ = imageURLType
526		case BinaryContent:
527			typ = binaryType
528		case ToolCall:
529			typ = toolCallType
530		case ToolResult:
531			typ = toolResultType
532		case Finish:
533			typ = finishType
534		default:
535			return nil, fmt.Errorf("unknown part type: %T", part)
536		}
537
538		wrappedParts[i] = partWrapper{
539			Type: typ,
540			Data: part,
541		}
542	}
543	return json.Marshal(wrappedParts)
544}
545
546// UnmarshalParts unmarshals content parts from JSON.
547func UnmarshalParts(data []byte) ([]ContentPart, error) {
548	temp := []json.RawMessage{}
549
550	if err := json.Unmarshal(data, &temp); err != nil {
551		return nil, err
552	}
553
554	parts := make([]ContentPart, 0)
555
556	for _, rawPart := range temp {
557		var wrapper struct {
558			Type partType        `json:"type"`
559			Data json.RawMessage `json:"data"`
560		}
561
562		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
563			return nil, err
564		}
565
566		switch wrapper.Type {
567		case reasoningType:
568			part := ReasoningContent{}
569			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
570				return nil, err
571			}
572			parts = append(parts, part)
573		case textType:
574			part := TextContent{}
575			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
576				return nil, err
577			}
578			parts = append(parts, part)
579		case imageURLType:
580			part := ImageURLContent{}
581			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
582				return nil, err
583			}
584			parts = append(parts, part)
585		case binaryType:
586			part := BinaryContent{}
587			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
588				return nil, err
589			}
590			parts = append(parts, part)
591		case toolCallType:
592			part := ToolCall{}
593			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
594				return nil, err
595			}
596			parts = append(parts, part)
597		case toolResultType:
598			part := ToolResult{}
599			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
600				return nil, err
601			}
602			parts = append(parts, part)
603		case finishType:
604			part := Finish{}
605			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
606				return nil, err
607			}
608			parts = append(parts, part)
609		default:
610			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
611		}
612	}
613
614	return parts, nil
615}
616
// Attachment represents a file attachment.
//
// JSON encoding and decoding go through [Attachment.MarshalJSON] and
// [Attachment.UnmarshalJSON], which carry Content as a standard base64
// string under the "content" key.
type Attachment struct {
	FilePath string `json:"file_path"`
	FileName string `json:"file_name"`
	MimeType string `json:"mime_type"`
	Content  []byte `json:"content"`
}
624
625// ToMessage converts a proto Attachment to a [message.Attachment].
626func (a Attachment) ToMessage() message.Attachment {
627	return message.Attachment{
628		FilePath: a.FilePath,
629		FileName: a.FileName,
630		MimeType: a.MimeType,
631		Content:  a.Content,
632	}
633}
634
635// AttachmentFromMessage converts a [message.Attachment] to a proto
636// Attachment.
637func AttachmentFromMessage(a message.Attachment) Attachment {
638	return Attachment{
639		FilePath: a.FilePath,
640		FileName: a.FileName,
641		MimeType: a.MimeType,
642		Content:  a.Content,
643	}
644}
645
646// AttachmentsToMessage converts a slice of proto Attachments to a slice
647// of [message.Attachment].
648func AttachmentsToMessage(as []Attachment) []message.Attachment {
649	out := make([]message.Attachment, len(as))
650	for i, a := range as {
651		out[i] = a.ToMessage()
652	}
653	return out
654}
655
656// AttachmentsFromMessage converts a slice of [message.Attachment] to a
657// slice of proto Attachments.
658func AttachmentsFromMessage(as []message.Attachment) []Attachment {
659	out := make([]Attachment, len(as))
660	for i, a := range as {
661		out[i] = AttachmentFromMessage(a)
662	}
663	return out
664}
665
666// MarshalJSON implements the [json.Marshaler] interface.
667func (a Attachment) MarshalJSON() ([]byte, error) {
668	type Alias Attachment
669	return json.Marshal(&struct {
670		Content string `json:"content"`
671		*Alias
672	}{
673		Content: base64.StdEncoding.EncodeToString(a.Content),
674		Alias:   (*Alias)(&a),
675	})
676}
677
678// UnmarshalJSON implements the [json.Unmarshaler] interface.
679func (a *Attachment) UnmarshalJSON(data []byte) error {
680	type Alias Attachment
681	aux := &struct {
682		Content string `json:"content"`
683		*Alias
684	}{
685		Alias: (*Alias)(a),
686	}
687	if err := json.Unmarshal(data, &aux); err != nil {
688		return err
689	}
690	content, err := base64.StdEncoding.DecodeString(aux.Content)
691	if err != nil {
692		return err
693	}
694	a.Content = content
695	return nil
696}