message.go

  1package proto
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"slices"
  8	"time"
  9
 10	"charm.land/catwalk/pkg/catwalk"
 11)
 12
 13// CreateMessageParams represents parameters for creating a message.
 14type CreateMessageParams struct {
 15	Role     MessageRole   `json:"role"`
 16	Parts    []ContentPart `json:"parts"`
 17	Model    string        `json:"model"`
 18	Provider string        `json:"provider,omitempty"`
 19}
 20
 21// Message represents a message in the proto layer.
 22type Message struct {
 23	ID        string        `json:"id"`
 24	Role      MessageRole   `json:"role"`
 25	SessionID string        `json:"session_id"`
 26	Parts     []ContentPart `json:"parts"`
 27	Model     string        `json:"model"`
 28	Provider  string        `json:"provider"`
 29	CreatedAt int64         `json:"created_at"`
 30	UpdatedAt int64         `json:"updated_at"`
 31}
 32
 33// MessageRole represents the role of a message sender.
 34type MessageRole string
 35
 36const (
 37	Assistant MessageRole = "assistant"
 38	User      MessageRole = "user"
 39	System    MessageRole = "system"
 40	Tool      MessageRole = "tool"
 41)
 42
 43// MarshalText implements the [encoding.TextMarshaler] interface.
 44func (r MessageRole) MarshalText() ([]byte, error) {
 45	return []byte(r), nil
 46}
 47
 48// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 49func (r *MessageRole) UnmarshalText(data []byte) error {
 50	*r = MessageRole(data)
 51	return nil
 52}
 53
 54// FinishReason represents why a message generation finished.
 55type FinishReason string
 56
 57const (
 58	FinishReasonEndTurn   FinishReason = "end_turn"
 59	FinishReasonMaxTokens FinishReason = "max_tokens"
 60	FinishReasonToolUse   FinishReason = "tool_use"
 61	FinishReasonCanceled  FinishReason = "canceled"
 62	FinishReasonError     FinishReason = "error"
 63	FinishReasonUnknown   FinishReason = "unknown"
 64)
 65
 66// MarshalText implements the [encoding.TextMarshaler] interface.
 67func (fr FinishReason) MarshalText() ([]byte, error) {
 68	return []byte(fr), nil
 69}
 70
 71// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 72func (fr *FinishReason) UnmarshalText(data []byte) error {
 73	*fr = FinishReason(data)
 74	return nil
 75}
 76
 77// ContentPart is a part of a message's content.
 78type ContentPart interface {
 79	isPart()
 80}
 81
 82// ReasoningContent represents the reasoning/thinking part of a message.
 83type ReasoningContent struct {
 84	Thinking   string `json:"thinking"`
 85	Signature  string `json:"signature"`
 86	StartedAt  int64  `json:"started_at,omitempty"`
 87	FinishedAt int64  `json:"finished_at,omitempty"`
 88}
 89
 90// String returns the thinking content as a string.
 91func (tc ReasoningContent) String() string {
 92	return tc.Thinking
 93}
 94
 95func (ReasoningContent) isPart() {}
 96
 97// TextContent represents a text part of a message.
 98type TextContent struct {
 99	Text string `json:"text"`
100}
101
102// String returns the text content as a string.
103func (tc TextContent) String() string {
104	return tc.Text
105}
106
107func (TextContent) isPart() {}
108
// ImageURLContent references an image by URL. Detail is an optional hint
// that is omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL; Detail is not included.
func (img ImageURLContent) String() string {
	url := img.URL
	return url
}

func (ImageURLContent) isPart() {}
121
122// BinaryContent represents binary data in a message.
123type BinaryContent struct {
124	Path     string
125	MIMEType string
126	Data     []byte
127}
128
129// String returns a base64-encoded string of the binary data.
130func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
131	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
132	if p == catwalk.InferenceProviderOpenAI {
133		return "data:" + bc.MIMEType + ";base64," + base64Encoded
134	}
135	return base64Encoded
136}
137
138func (BinaryContent) isPart() {}
139
// ToolCall is a request from the model to invoke a tool. Input holds the
// call arguments as a string, which may be streamed in pieces (see
// Message.AppendToolCallInput). Finished marks the call as complete, and
// Type is omitted from JSON when empty.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`
	Finished bool   `json:"finished,omitempty"`
}

func (ToolCall) isPart() {}
150
// ToolResult carries the outcome of the tool call identified by ToolCallID.
// Data and MIMEType are optional and omitted from JSON when empty.
// NOTE(review): Data presumably carries an encoded payload described by
// MIMEType — confirm with the producer.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Data       string `json:"data,omitempty"`
	MIMEType   string `json:"mime_type,omitempty"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
163
164// Finish represents the end of a message generation.
165type Finish struct {
166	Reason  FinishReason `json:"reason"`
167	Time    int64        `json:"time"`
168	Message string       `json:"message,omitempty"`
169	Details string       `json:"details,omitempty"`
170}
171
172func (Finish) isPart() {}
173
// MarshalJSON implements the [json.Marshaler] interface. Parts is encoded
// with MarshalParts, which tags each part with its type. All other fields use
// their struct tags. Any error from MarshalParts is returned as-is.
func (m Message) MarshalJSON() ([]byte, error) {
	encodedParts, err := MarshalParts(m.Parts)
	if err != nil {
		return nil, err
	}

	// Alias has Message's fields but not its methods, so json.Marshal does not
	// recurse back into this method. The outer Parts field is shallower than
	// Alias.Parts and therefore wins the "parts" key.
	type Alias Message
	payload := struct {
		Parts json.RawMessage `json:"parts"`
		*Alias
	}{
		Parts: encodedParts,
		Alias: (*Alias)(&m),
	}
	return json.Marshal(&payload)
}
190
// UnmarshalJSON implements the [json.Unmarshaler] interface and reverses
// MarshalJSON. The "parts" array is decoded with UnmarshalParts, and every
// other field is decoded through its struct tag.
//
// If the input is JSON null, or "parts" is absent, m.Parts is left unchanged.
// This follows the encoding/json convention that null and missing fields have
// no effect. Errors from json.Unmarshal or UnmarshalParts are returned as-is.
func (m *Message) UnmarshalJSON(data []byte) error {
	// Alias drops Message's methods, which prevents infinite recursion. The
	// outer Parts field shadows Alias.Parts for the "parts" key.
	type Alias Message
	// aux must be a struct value, not a pointer. Decoding JSON null into
	// &ptr sets ptr to nil, which made the old aux.Parts access panic.
	aux := struct {
		Parts json.RawMessage `json:"parts"`
		*Alias
	}{
		Alias: (*Alias)(m),
	}

	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	// An absent key (or a null document) leaves aux.Parts empty. UnmarshalParts
	// would reject empty input as invalid JSON.
	if len(aux.Parts) == 0 {
		return nil
	}

	parts, err := UnmarshalParts(aux.Parts)
	if err != nil {
		return err
	}

	m.Parts = parts
	return nil
}
213
214// Content returns the first text content part.
215func (m *Message) Content() TextContent {
216	for _, part := range m.Parts {
217		if c, ok := part.(TextContent); ok {
218			return c
219		}
220	}
221	return TextContent{}
222}
223
224// ReasoningContent returns the first reasoning content part.
225func (m *Message) ReasoningContent() ReasoningContent {
226	for _, part := range m.Parts {
227		if c, ok := part.(ReasoningContent); ok {
228			return c
229		}
230	}
231	return ReasoningContent{}
232}
233
234// ImageURLContent returns all image URL content parts.
235func (m *Message) ImageURLContent() []ImageURLContent {
236	imageURLContents := make([]ImageURLContent, 0)
237	for _, part := range m.Parts {
238		if c, ok := part.(ImageURLContent); ok {
239			imageURLContents = append(imageURLContents, c)
240		}
241	}
242	return imageURLContents
243}
244
245// BinaryContent returns all binary content parts.
246func (m *Message) BinaryContent() []BinaryContent {
247	binaryContents := make([]BinaryContent, 0)
248	for _, part := range m.Parts {
249		if c, ok := part.(BinaryContent); ok {
250			binaryContents = append(binaryContents, c)
251		}
252	}
253	return binaryContents
254}
255
256// ToolCalls returns all tool call parts.
257func (m *Message) ToolCalls() []ToolCall {
258	toolCalls := make([]ToolCall, 0)
259	for _, part := range m.Parts {
260		if c, ok := part.(ToolCall); ok {
261			toolCalls = append(toolCalls, c)
262		}
263	}
264	return toolCalls
265}
266
267// ToolResults returns all tool result parts.
268func (m *Message) ToolResults() []ToolResult {
269	toolResults := make([]ToolResult, 0)
270	for _, part := range m.Parts {
271		if c, ok := part.(ToolResult); ok {
272			toolResults = append(toolResults, c)
273		}
274	}
275	return toolResults
276}
277
// IsFinished reports whether the message contains at least one Finish part.
func (m *Message) IsFinished() bool {
	return slices.ContainsFunc(m.Parts, func(p ContentPart) bool {
		_, isFinish := p.(Finish)
		return isFinish
	})
}
287
288// FinishPart returns the finish part if present.
289func (m *Message) FinishPart() *Finish {
290	for _, part := range m.Parts {
291		if c, ok := part.(Finish); ok {
292			return &c
293		}
294	}
295	return nil
296}
297
298// FinishReason returns the finish reason if present.
299func (m *Message) FinishReason() FinishReason {
300	for _, part := range m.Parts {
301		if c, ok := part.(Finish); ok {
302			return c.Reason
303		}
304	}
305	return ""
306}
307
308// IsThinking returns true if the message is currently in a thinking state.
309func (m *Message) IsThinking() bool {
310	return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
311}
312
313// AppendContent appends text to the text content part.
314func (m *Message) AppendContent(delta string) {
315	found := false
316	for i, part := range m.Parts {
317		if c, ok := part.(TextContent); ok {
318			m.Parts[i] = TextContent{Text: c.Text + delta}
319			found = true
320		}
321	}
322	if !found {
323		m.Parts = append(m.Parts, TextContent{Text: delta})
324	}
325}
326
// AppendReasoningContent appends delta to the Thinking text of the message's
// first ReasoningContent part and keeps its other fields. If the message has
// no reasoning part, a new one is appended with Thinking set to delta and
// StartedAt set to the current Unix time in seconds.
//
// Only the first reasoning part is extended, which matches ReasoningContent
// and AppendReasoningSignature. The previous loop appended to every reasoning
// part.
func (m *Message) AppendReasoningContent(delta string) {
	for i, part := range m.Parts {
		if c, ok := part.(ReasoningContent); ok {
			// c is a copy; editing it preserves Signature/StartedAt/FinishedAt.
			c.Thinking += delta
			m.Parts[i] = c
			return
		}
	}
	m.Parts = append(m.Parts, ReasoningContent{
		Thinking:  delta,
		StartedAt: time.Now().Unix(),
	})
}
348
349// AppendReasoningSignature appends a signature to the reasoning content part.
350func (m *Message) AppendReasoningSignature(signature string) {
351	for i, part := range m.Parts {
352		if c, ok := part.(ReasoningContent); ok {
353			m.Parts[i] = ReasoningContent{
354				Thinking:   c.Thinking,
355				Signature:  c.Signature + signature,
356				StartedAt:  c.StartedAt,
357				FinishedAt: c.FinishedAt,
358			}
359			return
360		}
361	}
362	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
363}
364
// FinishThinking stamps FinishedAt on the message's first ReasoningContent
// part with the current Unix time in seconds. A part that already has a
// FinishedAt is left as is, and messages without a reasoning part are not
// modified.
func (m *Message) FinishThinking() {
	for i, part := range m.Parts {
		rc, ok := part.(ReasoningContent)
		if !ok {
			continue
		}
		if rc.FinishedAt == 0 {
			rc.FinishedAt = time.Now().Unix()
			m.Parts[i] = rc
		}
		return
	}
}
381
// ThinkingDuration returns how long the first reasoning part has been (or
// was) running, at one-second resolution. It returns 0 when reasoning never
// started. While reasoning is unfinished, the current time is used as the
// end.
func (m *Message) ThinkingDuration() time.Duration {
	r := m.ReasoningContent()
	if r.StartedAt == 0 {
		return 0
	}

	finished := r.FinishedAt
	if finished == 0 {
		finished = time.Now().Unix()
	}
	elapsedSeconds := finished - r.StartedAt
	return time.Duration(elapsedSeconds) * time.Second
}
396
397// FinishToolCall marks a tool call as finished.
398func (m *Message) FinishToolCall(toolCallID string) {
399	for i, part := range m.Parts {
400		if c, ok := part.(ToolCall); ok {
401			if c.ID == toolCallID {
402				m.Parts[i] = ToolCall{
403					ID:       c.ID,
404					Name:     c.Name,
405					Input:    c.Input,
406					Type:     c.Type,
407					Finished: true,
408				}
409				return
410			}
411		}
412	}
413}
414
415// AppendToolCallInput appends input to a tool call.
416func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
417	for i, part := range m.Parts {
418		if c, ok := part.(ToolCall); ok {
419			if c.ID == toolCallID {
420				m.Parts[i] = ToolCall{
421					ID:       c.ID,
422					Name:     c.Name,
423					Input:    c.Input + inputDelta,
424					Type:     c.Type,
425					Finished: c.Finished,
426				}
427				return
428			}
429		}
430	}
431}
432
// AddToolCall replaces the first ToolCall part with the same ID as tc, or
// appends tc when no such tool call exists.
func (m *Message) AddToolCall(tc ToolCall) {
	idx := slices.IndexFunc(m.Parts, func(p ContentPart) bool {
		existing, ok := p.(ToolCall)
		return ok && existing.ID == tc.ID
	})
	if idx >= 0 {
		m.Parts[idx] = tc
		return
	}
	m.Parts = append(m.Parts, tc)
}
445
446// SetToolCalls replaces all tool call parts.
447func (m *Message) SetToolCalls(tc []ToolCall) {
448	parts := make([]ContentPart, 0)
449	for _, part := range m.Parts {
450		if _, ok := part.(ToolCall); ok {
451			continue
452		}
453		parts = append(parts, part)
454	}
455	m.Parts = parts
456	for _, toolCall := range tc {
457		m.Parts = append(m.Parts, toolCall)
458	}
459}
460
461// AddToolResult adds a tool result.
462func (m *Message) AddToolResult(tr ToolResult) {
463	m.Parts = append(m.Parts, tr)
464}
465
466// SetToolResults adds multiple tool results.
467func (m *Message) SetToolResults(tr []ToolResult) {
468	for _, toolResult := range tr {
469		m.Parts = append(m.Parts, toolResult)
470	}
471}
472
// AddFinish removes the first existing Finish part, if any. It then appends a
// new Finish with the given reason, message and details, timestamped with
// the current Unix time in seconds.
func (m *Message) AddFinish(reason FinishReason, message, details string) {
	isFinish := func(p ContentPart) bool {
		_, ok := p.(Finish)
		return ok
	}
	if i := slices.IndexFunc(m.Parts, isFinish); i >= 0 {
		m.Parts = slices.Delete(m.Parts, i, i+1)
	}

	finish := Finish{
		Reason:  reason,
		Time:    time.Now().Unix(),
		Message: message,
		Details: details,
	}
	m.Parts = append(m.Parts, finish)
}
483
484// AddImageURL adds an image URL part to the message.
485func (m *Message) AddImageURL(url, detail string) {
486	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
487}
488
489// AddBinary adds a binary content part to the message.
490func (m *Message) AddBinary(mimeType string, data []byte) {
491	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
492}
493
494type partType string
495
496const (
497	reasoningType  partType = "reasoning"
498	textType       partType = "text"
499	imageURLType   partType = "image_url"
500	binaryType     partType = "binary"
501	toolCallType   partType = "tool_call"
502	toolResultType partType = "tool_result"
503	finishType     partType = "finish"
504)
505
506type partWrapper struct {
507	Type partType    `json:"type"`
508	Data ContentPart `json:"data"`
509}
510
// MarshalParts encodes parts as a JSON array of {"type": ..., "data": ...}
// envelopes, which UnmarshalParts can decode back into concrete parts. A nil
// or empty input encodes as "[]". It returns an error naming the Go type of
// the first part that is not a known ContentPart implementation (including a
// nil part).
func MarshalParts(parts []ContentPart) ([]byte, error) {
	wrapped := make([]partWrapper, 0, len(parts))
	for _, p := range parts {
		typ, known := partTypeOf(p)
		if !known {
			return nil, fmt.Errorf("unknown part type: %T", p)
		}
		wrapped = append(wrapped, partWrapper{Type: typ, Data: p})
	}
	return json.Marshal(wrapped)
}

// partTypeOf returns the wire tag for part and whether its type is known.
func partTypeOf(part ContentPart) (partType, bool) {
	switch part.(type) {
	case ReasoningContent:
		return reasoningType, true
	case TextContent:
		return textType, true
	case ImageURLContent:
		return imageURLType, true
	case BinaryContent:
		return binaryType, true
	case ToolCall:
		return toolCallType, true
	case ToolResult:
		return toolResultType, true
	case Finish:
		return finishType, true
	}
	return "", false
}
544
// UnmarshalParts decodes the JSON produced by MarshalParts. Each element's
// "type" tag selects the concrete ContentPart, and its "data" is decoded into
// that type. It returns an error if data is not a JSON array, if an element
// or its data is malformed, or if a type tag is not recognized.
func UnmarshalParts(data []byte) ([]ContentPart, error) {
	var rawParts []json.RawMessage
	if err := json.Unmarshal(data, &rawParts); err != nil {
		return nil, err
	}

	parts := make([]ContentPart, 0, len(rawParts))
	for _, rawPart := range rawParts {
		var wrapper struct {
			Type partType        `json:"type"`
			Data json.RawMessage `json:"data"`
		}
		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
			return nil, err
		}

		var (
			part ContentPart
			err  error
		)
		switch wrapper.Type {
		case reasoningType:
			part, err = decodePart[ReasoningContent](wrapper.Data)
		case textType:
			part, err = decodePart[TextContent](wrapper.Data)
		case imageURLType:
			part, err = decodePart[ImageURLContent](wrapper.Data)
		case binaryType:
			part, err = decodePart[BinaryContent](wrapper.Data)
		case toolCallType:
			part, err = decodePart[ToolCall](wrapper.Data)
		case toolResultType:
			part, err = decodePart[ToolResult](wrapper.Data)
		case finishType:
			part, err = decodePart[Finish](wrapper.Data)
		default:
			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
		}
		if err != nil {
			return nil, err
		}
		parts = append(parts, part)
	}

	return parts, nil
}

// decodePart unmarshals raw into a fresh T and returns it boxed as a
// ContentPart. It replaces seven copies of the same decode-and-append branch.
func decodePart[T ContentPart](raw json.RawMessage) (ContentPart, error) {
	var part T
	if err := json.Unmarshal(raw, &part); err != nil {
		return nil, err
	}
	return part, nil
}
615
616// Attachment represents a file attachment.
617type Attachment struct {
618	FilePath string `json:"file_path"`
619	FileName string `json:"file_name"`
620	MimeType string `json:"mime_type"`
621	Content  []byte `json:"content"`
622}
623
624// MarshalJSON implements the [json.Marshaler] interface.
625func (a Attachment) MarshalJSON() ([]byte, error) {
626	type Alias Attachment
627	return json.Marshal(&struct {
628		Content string `json:"content"`
629		*Alias
630	}{
631		Content: base64.StdEncoding.EncodeToString(a.Content),
632		Alias:   (*Alias)(&a),
633	})
634}
635
636// UnmarshalJSON implements the [json.Unmarshaler] interface.
637func (a *Attachment) UnmarshalJSON(data []byte) error {
638	type Alias Attachment
639	aux := &struct {
640		Content string `json:"content"`
641		*Alias
642	}{
643		Alias: (*Alias)(a),
644	}
645	if err := json.Unmarshal(data, &aux); err != nil {
646		return err
647	}
648	content, err := base64.StdEncoding.DecodeString(aux.Content)
649	if err != nil {
650		return err
651	}
652	a.Content = content
653	return nil
654}