message.go

  1package proto
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"slices"
  8	"time"
  9
 10	"charm.land/catwalk/pkg/catwalk"
 11)
 12
// CreateMessageParams represents parameters for creating a message.
//
// Fields map to the JSON keys "role", "parts", "model" and "provider";
// "provider" is omitted from the encoded output when empty.
//
// NOTE(review): Parts holds ContentPart interface values and no custom JSON
// methods for this type are visible in this file (unlike Message) — confirm
// how callers decode it.
type CreateMessageParams struct {
	Role     MessageRole   `json:"role"`
	Parts    []ContentPart `json:"parts"`
	Model    string        `json:"model"`
	Provider string        `json:"provider,omitempty"`
}
 20
// Message represents a message in the proto layer.
//
// A message is an ordered list of content parts (text, reasoning, tool
// calls, tool results, ...) plus identifying metadata. Its JSON form is
// produced and consumed by the MarshalJSON and UnmarshalJSON methods, which
// encode Parts through MarshalParts and UnmarshalParts.
//
// NOTE(review): CreatedAt and UpdatedAt are presumably Unix timestamps;
// their unit is not established in this file — confirm against the producer.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
 32
 33// MessageRole represents the role of a message sender.
 34type MessageRole string
 35
 36const (
 37	Assistant MessageRole = "assistant"
 38	User      MessageRole = "user"
 39	System    MessageRole = "system"
 40	Tool      MessageRole = "tool"
 41)
 42
 43// MarshalText implements the [encoding.TextMarshaler] interface.
 44func (r MessageRole) MarshalText() ([]byte, error) {
 45	return []byte(r), nil
 46}
 47
 48// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 49func (r *MessageRole) UnmarshalText(data []byte) error {
 50	*r = MessageRole(data)
 51	return nil
 52}
 53
 54// FinishReason represents why a message generation finished.
 55type FinishReason string
 56
 57const (
 58	FinishReasonEndTurn          FinishReason = "end_turn"
 59	FinishReasonMaxTokens        FinishReason = "max_tokens"
 60	FinishReasonToolUse          FinishReason = "tool_use"
 61	FinishReasonCanceled         FinishReason = "canceled"
 62	FinishReasonError            FinishReason = "error"
 63	FinishReasonPermissionDenied FinishReason = "permission_denied"
 64	FinishReasonUnknown          FinishReason = "unknown"
 65)
 66
 67// MarshalText implements the [encoding.TextMarshaler] interface.
 68func (fr FinishReason) MarshalText() ([]byte, error) {
 69	return []byte(fr), nil
 70}
 71
 72// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 73func (fr *FinishReason) UnmarshalText(data []byte) error {
 74	*fr = FinishReason(data)
 75	return nil
 76}
 77
// ContentPart is a part of a message's content.
//
// isPart is an unexported marker method with no behavior; the part types in
// this file implement it with an empty body so they can be stored in a
// Message's Parts slice.
type ContentPart interface {
	isPart()
}
 82
 83// ReasoningContent represents the reasoning/thinking part of a message.
 84type ReasoningContent struct {
 85	Thinking   string `json:"thinking"`
 86	Signature  string `json:"signature"`
 87	StartedAt  int64  `json:"started_at,omitempty"`
 88	FinishedAt int64  `json:"finished_at,omitempty"`
 89}
 90
 91// String returns the thinking content as a string.
 92func (tc ReasoningContent) String() string {
 93	return tc.Thinking
 94}
 95
 96func (ReasoningContent) isPart() {}
 97
 98// TextContent represents a text part of a message.
 99type TextContent struct {
100	Text string `json:"text"`
101}
102
103// String returns the text content as a string.
104func (tc TextContent) String() string {
105	return tc.Text
106}
107
108func (TextContent) isPart() {}
109
// ImageURLContent represents an image URL part of a message. Detail is
// omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL; Detail is not included.
func (iuc ImageURLContent) String() string {
	url := iuc.URL
	return url
}

// isPart marks ImageURLContent as a ContentPart.
func (ImageURLContent) isPart() {}
122
123// BinaryContent represents binary data in a message.
124type BinaryContent struct {
125	Path     string
126	MIMEType string
127	Data     []byte
128}
129
130// String returns a base64-encoded string of the binary data.
131func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
132	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
133	if p == catwalk.InferenceProviderOpenAI {
134		return "data:" + bc.MIMEType + ";base64," + base64Encoded
135	}
136	return base64Encoded
137}
138
139func (BinaryContent) isPart() {}
140
// ToolCall represents a tool call in a message.
//
// Input is accumulated as a string (see Message.AppendToolCallInput) and
// Finished is set by Message.FinishToolCall. Type and Finished are omitted
// from JSON when they hold their zero value.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`
	Finished bool   `json:"finished,omitempty"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
151
// ToolResult represents the result of a tool call.
//
// ToolCallID names the call this result belongs to (presumably matching a
// ToolCall.ID — confirm against the producer). All fields are always
// present in the JSON encoding.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
162
// Finish represents the end of a message generation.
//
// Time is a Unix-second timestamp as written by Message.AddFinish. Message
// and Details are omitted from JSON when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
172
173// MarshalJSON implements the [json.Marshaler] interface.
174func (m Message) MarshalJSON() ([]byte, error) {
175	parts, err := MarshalParts(m.Parts)
176	if err != nil {
177		return nil, err
178	}
179
180	type Alias Message
181	return json.Marshal(&struct {
182		Parts json.RawMessage `json:"parts"`
183		*Alias
184	}{
185		Parts: json.RawMessage(parts),
186		Alias: (*Alias)(&m),
187	})
188}
189
190// UnmarshalJSON implements the [json.Unmarshaler] interface.
191func (m *Message) UnmarshalJSON(data []byte) error {
192	type Alias Message
193	aux := &struct {
194		Parts json.RawMessage `json:"parts"`
195		*Alias
196	}{
197		Alias: (*Alias)(m),
198	}
199
200	if err := json.Unmarshal(data, &aux); err != nil {
201		return err
202	}
203
204	parts, err := UnmarshalParts([]byte(aux.Parts))
205	if err != nil {
206		return err
207	}
208
209	m.Parts = parts
210	return nil
211}
212
213// Content returns the first text content part.
214func (m *Message) Content() TextContent {
215	for _, part := range m.Parts {
216		if c, ok := part.(TextContent); ok {
217			return c
218		}
219	}
220	return TextContent{}
221}
222
223// ReasoningContent returns the first reasoning content part.
224func (m *Message) ReasoningContent() ReasoningContent {
225	for _, part := range m.Parts {
226		if c, ok := part.(ReasoningContent); ok {
227			return c
228		}
229	}
230	return ReasoningContent{}
231}
232
233// ImageURLContent returns all image URL content parts.
234func (m *Message) ImageURLContent() []ImageURLContent {
235	imageURLContents := make([]ImageURLContent, 0)
236	for _, part := range m.Parts {
237		if c, ok := part.(ImageURLContent); ok {
238			imageURLContents = append(imageURLContents, c)
239		}
240	}
241	return imageURLContents
242}
243
244// BinaryContent returns all binary content parts.
245func (m *Message) BinaryContent() []BinaryContent {
246	binaryContents := make([]BinaryContent, 0)
247	for _, part := range m.Parts {
248		if c, ok := part.(BinaryContent); ok {
249			binaryContents = append(binaryContents, c)
250		}
251	}
252	return binaryContents
253}
254
255// ToolCalls returns all tool call parts.
256func (m *Message) ToolCalls() []ToolCall {
257	toolCalls := make([]ToolCall, 0)
258	for _, part := range m.Parts {
259		if c, ok := part.(ToolCall); ok {
260			toolCalls = append(toolCalls, c)
261		}
262	}
263	return toolCalls
264}
265
266// ToolResults returns all tool result parts.
267func (m *Message) ToolResults() []ToolResult {
268	toolResults := make([]ToolResult, 0)
269	for _, part := range m.Parts {
270		if c, ok := part.(ToolResult); ok {
271			toolResults = append(toolResults, c)
272		}
273	}
274	return toolResults
275}
276
277// IsFinished returns true if the message has a finish part.
278func (m *Message) IsFinished() bool {
279	for _, part := range m.Parts {
280		if _, ok := part.(Finish); ok {
281			return true
282		}
283	}
284	return false
285}
286
287// FinishPart returns the finish part if present.
288func (m *Message) FinishPart() *Finish {
289	for _, part := range m.Parts {
290		if c, ok := part.(Finish); ok {
291			return &c
292		}
293	}
294	return nil
295}
296
297// FinishReason returns the finish reason if present.
298func (m *Message) FinishReason() FinishReason {
299	for _, part := range m.Parts {
300		if c, ok := part.(Finish); ok {
301			return c.Reason
302		}
303	}
304	return ""
305}
306
307// IsThinking returns true if the message is currently in a thinking state.
308func (m *Message) IsThinking() bool {
309	return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
310}
311
312// AppendContent appends text to the text content part.
313func (m *Message) AppendContent(delta string) {
314	found := false
315	for i, part := range m.Parts {
316		if c, ok := part.(TextContent); ok {
317			m.Parts[i] = TextContent{Text: c.Text + delta}
318			found = true
319		}
320	}
321	if !found {
322		m.Parts = append(m.Parts, TextContent{Text: delta})
323	}
324}
325
326// AppendReasoningContent appends text to the reasoning content part.
327func (m *Message) AppendReasoningContent(delta string) {
328	found := false
329	for i, part := range m.Parts {
330		if c, ok := part.(ReasoningContent); ok {
331			m.Parts[i] = ReasoningContent{
332				Thinking:   c.Thinking + delta,
333				Signature:  c.Signature,
334				StartedAt:  c.StartedAt,
335				FinishedAt: c.FinishedAt,
336			}
337			found = true
338		}
339	}
340	if !found {
341		m.Parts = append(m.Parts, ReasoningContent{
342			Thinking:  delta,
343			StartedAt: time.Now().Unix(),
344		})
345	}
346}
347
348// AppendReasoningSignature appends a signature to the reasoning content part.
349func (m *Message) AppendReasoningSignature(signature string) {
350	for i, part := range m.Parts {
351		if c, ok := part.(ReasoningContent); ok {
352			m.Parts[i] = ReasoningContent{
353				Thinking:   c.Thinking,
354				Signature:  c.Signature + signature,
355				StartedAt:  c.StartedAt,
356				FinishedAt: c.FinishedAt,
357			}
358			return
359		}
360	}
361	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
362}
363
364// FinishThinking marks the reasoning content as finished.
365func (m *Message) FinishThinking() {
366	for i, part := range m.Parts {
367		if c, ok := part.(ReasoningContent); ok {
368			if c.FinishedAt == 0 {
369				m.Parts[i] = ReasoningContent{
370					Thinking:   c.Thinking,
371					Signature:  c.Signature,
372					StartedAt:  c.StartedAt,
373					FinishedAt: time.Now().Unix(),
374				}
375			}
376			return
377		}
378	}
379}
380
381// ThinkingDuration returns the duration of the thinking phase.
382func (m *Message) ThinkingDuration() time.Duration {
383	reasoning := m.ReasoningContent()
384	if reasoning.StartedAt == 0 {
385		return 0
386	}
387
388	endTime := reasoning.FinishedAt
389	if endTime == 0 {
390		endTime = time.Now().Unix()
391	}
392
393	return time.Duration(endTime-reasoning.StartedAt) * time.Second
394}
395
396// FinishToolCall marks a tool call as finished.
397func (m *Message) FinishToolCall(toolCallID string) {
398	for i, part := range m.Parts {
399		if c, ok := part.(ToolCall); ok {
400			if c.ID == toolCallID {
401				m.Parts[i] = ToolCall{
402					ID:       c.ID,
403					Name:     c.Name,
404					Input:    c.Input,
405					Type:     c.Type,
406					Finished: true,
407				}
408				return
409			}
410		}
411	}
412}
413
414// AppendToolCallInput appends input to a tool call.
415func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
416	for i, part := range m.Parts {
417		if c, ok := part.(ToolCall); ok {
418			if c.ID == toolCallID {
419				m.Parts[i] = ToolCall{
420					ID:       c.ID,
421					Name:     c.Name,
422					Input:    c.Input + inputDelta,
423					Type:     c.Type,
424					Finished: c.Finished,
425				}
426				return
427			}
428		}
429	}
430}
431
432// AddToolCall adds or updates a tool call.
433func (m *Message) AddToolCall(tc ToolCall) {
434	for i, part := range m.Parts {
435		if c, ok := part.(ToolCall); ok {
436			if c.ID == tc.ID {
437				m.Parts[i] = tc
438				return
439			}
440		}
441	}
442	m.Parts = append(m.Parts, tc)
443}
444
445// SetToolCalls replaces all tool call parts.
446func (m *Message) SetToolCalls(tc []ToolCall) {
447	parts := make([]ContentPart, 0)
448	for _, part := range m.Parts {
449		if _, ok := part.(ToolCall); ok {
450			continue
451		}
452		parts = append(parts, part)
453	}
454	m.Parts = parts
455	for _, toolCall := range tc {
456		m.Parts = append(m.Parts, toolCall)
457	}
458}
459
460// AddToolResult adds a tool result.
461func (m *Message) AddToolResult(tr ToolResult) {
462	m.Parts = append(m.Parts, tr)
463}
464
465// SetToolResults adds multiple tool results.
466func (m *Message) SetToolResults(tr []ToolResult) {
467	for _, toolResult := range tr {
468		m.Parts = append(m.Parts, toolResult)
469	}
470}
471
472// AddFinish adds a finish part to the message.
473func (m *Message) AddFinish(reason FinishReason, message, details string) {
474	for i, part := range m.Parts {
475		if _, ok := part.(Finish); ok {
476			m.Parts = slices.Delete(m.Parts, i, i+1)
477			break
478		}
479	}
480	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
481}
482
483// AddImageURL adds an image URL part to the message.
484func (m *Message) AddImageURL(url, detail string) {
485	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
486}
487
488// AddBinary adds a binary content part to the message.
489func (m *Message) AddBinary(mimeType string, data []byte) {
490	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
491}
492
// partType is the discriminator string written to the "type" field of a
// serialized content part.
type partType string

// Discriminator values, one per concrete ContentPart type handled by
// MarshalParts and UnmarshalParts.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
504
// partWrapper is the JSON envelope MarshalParts writes for each part: the
// type discriminator under "type" and the part itself under "data".
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
509
510// MarshalParts marshals content parts to JSON.
511func MarshalParts(parts []ContentPart) ([]byte, error) {
512	wrappedParts := make([]partWrapper, len(parts))
513
514	for i, part := range parts {
515		var typ partType
516
517		switch part.(type) {
518		case ReasoningContent:
519			typ = reasoningType
520		case TextContent:
521			typ = textType
522		case ImageURLContent:
523			typ = imageURLType
524		case BinaryContent:
525			typ = binaryType
526		case ToolCall:
527			typ = toolCallType
528		case ToolResult:
529			typ = toolResultType
530		case Finish:
531			typ = finishType
532		default:
533			return nil, fmt.Errorf("unknown part type: %T", part)
534		}
535
536		wrappedParts[i] = partWrapper{
537			Type: typ,
538			Data: part,
539		}
540	}
541	return json.Marshal(wrappedParts)
542}
543
544// UnmarshalParts unmarshals content parts from JSON.
545func UnmarshalParts(data []byte) ([]ContentPart, error) {
546	temp := []json.RawMessage{}
547
548	if err := json.Unmarshal(data, &temp); err != nil {
549		return nil, err
550	}
551
552	parts := make([]ContentPart, 0)
553
554	for _, rawPart := range temp {
555		var wrapper struct {
556			Type partType        `json:"type"`
557			Data json.RawMessage `json:"data"`
558		}
559
560		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
561			return nil, err
562		}
563
564		switch wrapper.Type {
565		case reasoningType:
566			part := ReasoningContent{}
567			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
568				return nil, err
569			}
570			parts = append(parts, part)
571		case textType:
572			part := TextContent{}
573			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
574				return nil, err
575			}
576			parts = append(parts, part)
577		case imageURLType:
578			part := ImageURLContent{}
579			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
580				return nil, err
581			}
582			parts = append(parts, part)
583		case binaryType:
584			part := BinaryContent{}
585			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
586				return nil, err
587			}
588			parts = append(parts, part)
589		case toolCallType:
590			part := ToolCall{}
591			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
592				return nil, err
593			}
594			parts = append(parts, part)
595		case toolResultType:
596			part := ToolResult{}
597			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
598				return nil, err
599			}
600			parts = append(parts, part)
601		case finishType:
602			part := Finish{}
603			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
604				return nil, err
605			}
606			parts = append(parts, part)
607		default:
608			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
609		}
610	}
611
612	return parts, nil
613}
614
// Attachment represents a file attachment.
//
// In JSON, Content travels as a standard-base64 string under "content";
// the remaining fields use their struct tags.
type Attachment struct {
	FilePath string `json:"file_path"`
	FileName string `json:"file_name"`
	MimeType string `json:"mime_type"`
	Content  []byte `json:"content"`
}

// MarshalJSON implements the [json.Marshaler] interface. Content is written
// as a standard-base64 string; nil or empty Content encodes as "".
func (a Attachment) MarshalJSON() ([]byte, error) {
	// Alias has Attachment's fields but none of its methods, so encoding it
	// does not recurse back into this method.
	type Alias Attachment
	return json.Marshal(&struct {
		Content string `json:"content"`
		*Alias
	}{
		Content: base64.StdEncoding.EncodeToString(a.Content),
		Alias:   (*Alias)(&a),
	})
}

// UnmarshalJSON implements the [json.Unmarshaler] interface.
//
// A JSON null is a no-op, following the encoding/json convention. Otherwise
// the object's fields are decoded and "content" is base64-decoded into
// Content; a missing or empty "content" yields empty Content. It returns an
// error when the JSON is malformed or "content" is not valid standard
// base64; in the latter case the other fields have already been written.
func (a *Attachment) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	type Alias Attachment
	aux := &struct {
		Content string `json:"content"`
		*Alias
	}{
		Alias: (*Alias)(a),
	}

	// aux is already a pointer. Passing &aux would hand the decoder a
	// settable pointer that a JSON null resets to nil, making the
	// aux.Content access below panic.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}

	content, err := base64.StdEncoding.DecodeString(aux.Content)
	if err != nil {
		return fmt.Errorf("decoding attachment content: %w", err)
	}
	a.Content = content
	return nil
}