message.go

  1package proto
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"slices"
  8	"time"
  9
 10	"charm.land/catwalk/pkg/catwalk"
 11)
 12
 13// CreateMessageParams represents parameters for creating a message.
 14type CreateMessageParams struct {
 15	Role     MessageRole   `json:"role"`
 16	Parts    []ContentPart `json:"parts"`
 17	Model    string        `json:"model"`
 18	Provider string        `json:"provider,omitempty"`
 19}
 20
// Message represents a message in the proto layer.
//
// ID identifies the message and SessionID the session it belongs to. Role is
// the sender role. Parts is the ordered content of the message; it holds
// ContentPart interface values and is encoded and decoded by the custom
// MarshalJSON and UnmarshalJSON methods on this type. Model and Provider name
// the model and provider associated with the message.
//
// CreatedAt and UpdatedAt are integer timestamps.
// NOTE(review): presumably Unix seconds, like the other timestamps set in this
// file via time.Now().Unix() — confirm against the code that fills them in.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
 32
 33// MessageRole represents the role of a message sender.
 34type MessageRole string
 35
 36const (
 37	Assistant MessageRole = "assistant"
 38	User      MessageRole = "user"
 39	System    MessageRole = "system"
 40	Tool      MessageRole = "tool"
 41)
 42
 43// MarshalText implements the [encoding.TextMarshaler] interface.
 44func (r MessageRole) MarshalText() ([]byte, error) {
 45	return []byte(r), nil
 46}
 47
 48// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 49func (r *MessageRole) UnmarshalText(data []byte) error {
 50	*r = MessageRole(data)
 51	return nil
 52}
 53
 54// FinishReason represents why a message generation finished.
 55type FinishReason string
 56
 57const (
 58	FinishReasonEndTurn   FinishReason = "end_turn"
 59	FinishReasonMaxTokens FinishReason = "max_tokens"
 60	FinishReasonToolUse   FinishReason = "tool_use"
 61	FinishReasonCanceled  FinishReason = "canceled"
 62	FinishReasonError     FinishReason = "error"
 63	FinishReasonUnknown   FinishReason = "unknown"
 64)
 65
 66// MarshalText implements the [encoding.TextMarshaler] interface.
 67func (fr FinishReason) MarshalText() ([]byte, error) {
 68	return []byte(fr), nil
 69}
 70
 71// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
 72func (fr *FinishReason) UnmarshalText(data []byte) error {
 73	*fr = FinishReason(data)
 74	return nil
 75}
 76
// ContentPart is a part of a message's content.
//
// The only method is the unexported isPart marker, so types declared outside
// this package cannot implement the interface directly. The concrete part
// types in this file all implement it with a value receiver.
type ContentPart interface {
	// isPart is a marker with no behavior.
	isPart()
}
 81
 82// ReasoningContent represents the reasoning/thinking part of a message.
 83type ReasoningContent struct {
 84	Thinking   string `json:"thinking"`
 85	Signature  string `json:"signature"`
 86	StartedAt  int64  `json:"started_at,omitempty"`
 87	FinishedAt int64  `json:"finished_at,omitempty"`
 88}
 89
 90// String returns the thinking content as a string.
 91func (tc ReasoningContent) String() string {
 92	return tc.Thinking
 93}
 94
 95func (ReasoningContent) isPart() {}
 96
 97// TextContent represents a text part of a message.
 98type TextContent struct {
 99	Text string `json:"text"`
100}
101
102// String returns the text content as a string.
103func (tc TextContent) String() string {
104	return tc.Text
105}
106
107func (TextContent) isPart() {}
108
// ImageURLContent is a message part that refers to an image by URL.
//
// Detail is an optional string that is omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL; Detail is not included.
func (img ImageURLContent) String() string {
	url := img.URL
	return url
}

// isPart marks ImageURLContent as a ContentPart.
func (ImageURLContent) isPart() {}
121
122// BinaryContent represents binary data in a message.
123type BinaryContent struct {
124	Path     string
125	MIMEType string
126	Data     []byte
127}
128
129// String returns a base64-encoded string of the binary data.
130func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
131	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
132	if p == catwalk.InferenceProviderOpenAI {
133		return "data:" + bc.MIMEType + ";base64," + base64Encoded
134	}
135	return base64Encoded
136}
137
138func (BinaryContent) isPart() {}
139
// ToolCall represents a tool call in a message.
//
// ID identifies the call; the Message helpers in this file look tool calls up
// by it. Name is the tool name. Input is a string that AppendToolCallInput
// grows by concatenation. Type is an optional string omitted from JSON when
// empty. Finished is set to true by FinishToolCall and is omitted from JSON
// while false.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`
	Finished bool   `json:"finished,omitempty"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
150
// ToolResult represents the result of a tool call.
//
// ToolCallID carries the ID of a tool call and Name a tool name. Content and
// Metadata are free-form strings, and IsError flags the result as an error.
// All five fields are always present in the JSON form (none uses omitempty).
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
161
// Finish represents the end of a message generation.
//
// Reason says why generation stopped. Time is a Unix-second timestamp
// (AddFinish sets it from time.Now().Unix()). Message and Details are optional
// strings that are omitted from JSON when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
171
172// MarshalJSON implements the [json.Marshaler] interface.
173func (m Message) MarshalJSON() ([]byte, error) {
174	parts, err := MarshalParts(m.Parts)
175	if err != nil {
176		return nil, err
177	}
178
179	type Alias Message
180	return json.Marshal(&struct {
181		Parts json.RawMessage `json:"parts"`
182		*Alias
183	}{
184		Parts: json.RawMessage(parts),
185		Alias: (*Alias)(&m),
186	})
187}
188
189// UnmarshalJSON implements the [json.Unmarshaler] interface.
190func (m *Message) UnmarshalJSON(data []byte) error {
191	type Alias Message
192	aux := &struct {
193		Parts json.RawMessage `json:"parts"`
194		*Alias
195	}{
196		Alias: (*Alias)(m),
197	}
198
199	if err := json.Unmarshal(data, &aux); err != nil {
200		return err
201	}
202
203	parts, err := UnmarshalParts([]byte(aux.Parts))
204	if err != nil {
205		return err
206	}
207
208	m.Parts = parts
209	return nil
210}
211
212// Content returns the first text content part.
213func (m *Message) Content() TextContent {
214	for _, part := range m.Parts {
215		if c, ok := part.(TextContent); ok {
216			return c
217		}
218	}
219	return TextContent{}
220}
221
222// ReasoningContent returns the first reasoning content part.
223func (m *Message) ReasoningContent() ReasoningContent {
224	for _, part := range m.Parts {
225		if c, ok := part.(ReasoningContent); ok {
226			return c
227		}
228	}
229	return ReasoningContent{}
230}
231
232// ImageURLContent returns all image URL content parts.
233func (m *Message) ImageURLContent() []ImageURLContent {
234	imageURLContents := make([]ImageURLContent, 0)
235	for _, part := range m.Parts {
236		if c, ok := part.(ImageURLContent); ok {
237			imageURLContents = append(imageURLContents, c)
238		}
239	}
240	return imageURLContents
241}
242
243// BinaryContent returns all binary content parts.
244func (m *Message) BinaryContent() []BinaryContent {
245	binaryContents := make([]BinaryContent, 0)
246	for _, part := range m.Parts {
247		if c, ok := part.(BinaryContent); ok {
248			binaryContents = append(binaryContents, c)
249		}
250	}
251	return binaryContents
252}
253
254// ToolCalls returns all tool call parts.
255func (m *Message) ToolCalls() []ToolCall {
256	toolCalls := make([]ToolCall, 0)
257	for _, part := range m.Parts {
258		if c, ok := part.(ToolCall); ok {
259			toolCalls = append(toolCalls, c)
260		}
261	}
262	return toolCalls
263}
264
265// ToolResults returns all tool result parts.
266func (m *Message) ToolResults() []ToolResult {
267	toolResults := make([]ToolResult, 0)
268	for _, part := range m.Parts {
269		if c, ok := part.(ToolResult); ok {
270			toolResults = append(toolResults, c)
271		}
272	}
273	return toolResults
274}
275
276// IsFinished returns true if the message has a finish part.
277func (m *Message) IsFinished() bool {
278	for _, part := range m.Parts {
279		if _, ok := part.(Finish); ok {
280			return true
281		}
282	}
283	return false
284}
285
286// FinishPart returns the finish part if present.
287func (m *Message) FinishPart() *Finish {
288	for _, part := range m.Parts {
289		if c, ok := part.(Finish); ok {
290			return &c
291		}
292	}
293	return nil
294}
295
296// FinishReason returns the finish reason if present.
297func (m *Message) FinishReason() FinishReason {
298	for _, part := range m.Parts {
299		if c, ok := part.(Finish); ok {
300			return c.Reason
301		}
302	}
303	return ""
304}
305
306// IsThinking returns true if the message is currently in a thinking state.
307func (m *Message) IsThinking() bool {
308	return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
309}
310
311// AppendContent appends text to the text content part.
312func (m *Message) AppendContent(delta string) {
313	found := false
314	for i, part := range m.Parts {
315		if c, ok := part.(TextContent); ok {
316			m.Parts[i] = TextContent{Text: c.Text + delta}
317			found = true
318		}
319	}
320	if !found {
321		m.Parts = append(m.Parts, TextContent{Text: delta})
322	}
323}
324
325// AppendReasoningContent appends text to the reasoning content part.
326func (m *Message) AppendReasoningContent(delta string) {
327	found := false
328	for i, part := range m.Parts {
329		if c, ok := part.(ReasoningContent); ok {
330			m.Parts[i] = ReasoningContent{
331				Thinking:   c.Thinking + delta,
332				Signature:  c.Signature,
333				StartedAt:  c.StartedAt,
334				FinishedAt: c.FinishedAt,
335			}
336			found = true
337		}
338	}
339	if !found {
340		m.Parts = append(m.Parts, ReasoningContent{
341			Thinking:  delta,
342			StartedAt: time.Now().Unix(),
343		})
344	}
345}
346
347// AppendReasoningSignature appends a signature to the reasoning content part.
348func (m *Message) AppendReasoningSignature(signature string) {
349	for i, part := range m.Parts {
350		if c, ok := part.(ReasoningContent); ok {
351			m.Parts[i] = ReasoningContent{
352				Thinking:   c.Thinking,
353				Signature:  c.Signature + signature,
354				StartedAt:  c.StartedAt,
355				FinishedAt: c.FinishedAt,
356			}
357			return
358		}
359	}
360	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
361}
362
363// FinishThinking marks the reasoning content as finished.
364func (m *Message) FinishThinking() {
365	for i, part := range m.Parts {
366		if c, ok := part.(ReasoningContent); ok {
367			if c.FinishedAt == 0 {
368				m.Parts[i] = ReasoningContent{
369					Thinking:   c.Thinking,
370					Signature:  c.Signature,
371					StartedAt:  c.StartedAt,
372					FinishedAt: time.Now().Unix(),
373				}
374			}
375			return
376		}
377	}
378}
379
380// ThinkingDuration returns the duration of the thinking phase.
381func (m *Message) ThinkingDuration() time.Duration {
382	reasoning := m.ReasoningContent()
383	if reasoning.StartedAt == 0 {
384		return 0
385	}
386
387	endTime := reasoning.FinishedAt
388	if endTime == 0 {
389		endTime = time.Now().Unix()
390	}
391
392	return time.Duration(endTime-reasoning.StartedAt) * time.Second
393}
394
395// FinishToolCall marks a tool call as finished.
396func (m *Message) FinishToolCall(toolCallID string) {
397	for i, part := range m.Parts {
398		if c, ok := part.(ToolCall); ok {
399			if c.ID == toolCallID {
400				m.Parts[i] = ToolCall{
401					ID:       c.ID,
402					Name:     c.Name,
403					Input:    c.Input,
404					Type:     c.Type,
405					Finished: true,
406				}
407				return
408			}
409		}
410	}
411}
412
413// AppendToolCallInput appends input to a tool call.
414func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
415	for i, part := range m.Parts {
416		if c, ok := part.(ToolCall); ok {
417			if c.ID == toolCallID {
418				m.Parts[i] = ToolCall{
419					ID:       c.ID,
420					Name:     c.Name,
421					Input:    c.Input + inputDelta,
422					Type:     c.Type,
423					Finished: c.Finished,
424				}
425				return
426			}
427		}
428	}
429}
430
431// AddToolCall adds or updates a tool call.
432func (m *Message) AddToolCall(tc ToolCall) {
433	for i, part := range m.Parts {
434		if c, ok := part.(ToolCall); ok {
435			if c.ID == tc.ID {
436				m.Parts[i] = tc
437				return
438			}
439		}
440	}
441	m.Parts = append(m.Parts, tc)
442}
443
444// SetToolCalls replaces all tool call parts.
445func (m *Message) SetToolCalls(tc []ToolCall) {
446	parts := make([]ContentPart, 0)
447	for _, part := range m.Parts {
448		if _, ok := part.(ToolCall); ok {
449			continue
450		}
451		parts = append(parts, part)
452	}
453	m.Parts = parts
454	for _, toolCall := range tc {
455		m.Parts = append(m.Parts, toolCall)
456	}
457}
458
459// AddToolResult adds a tool result.
460func (m *Message) AddToolResult(tr ToolResult) {
461	m.Parts = append(m.Parts, tr)
462}
463
464// SetToolResults adds multiple tool results.
465func (m *Message) SetToolResults(tr []ToolResult) {
466	for _, toolResult := range tr {
467		m.Parts = append(m.Parts, toolResult)
468	}
469}
470
471// AddFinish adds a finish part to the message.
472func (m *Message) AddFinish(reason FinishReason, message, details string) {
473	for i, part := range m.Parts {
474		if _, ok := part.(Finish); ok {
475			m.Parts = slices.Delete(m.Parts, i, i+1)
476			break
477		}
478	}
479	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
480}
481
482// AddImageURL adds an image URL part to the message.
483func (m *Message) AddImageURL(url, detail string) {
484	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
485}
486
487// AddBinary adds a binary content part to the message.
488func (m *Message) AddBinary(mimeType string, data []byte) {
489	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
490}
491
// partType is the discriminator stored in the "type" field of an encoded
// content part. MarshalParts writes it and UnmarshalParts switches on it to
// pick the concrete ContentPart type.
type partType string

// One tag per concrete ContentPart type handled by MarshalParts and
// UnmarshalParts. The string values are part of the encoded JSON.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
503
// partWrapper is the JSON envelope MarshalParts writes for each content part:
// a "type" tag naming the concrete part type and a "data" field holding the
// part itself.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
508
509// MarshalParts marshals content parts to JSON.
510func MarshalParts(parts []ContentPart) ([]byte, error) {
511	wrappedParts := make([]partWrapper, len(parts))
512
513	for i, part := range parts {
514		var typ partType
515
516		switch part.(type) {
517		case ReasoningContent:
518			typ = reasoningType
519		case TextContent:
520			typ = textType
521		case ImageURLContent:
522			typ = imageURLType
523		case BinaryContent:
524			typ = binaryType
525		case ToolCall:
526			typ = toolCallType
527		case ToolResult:
528			typ = toolResultType
529		case Finish:
530			typ = finishType
531		default:
532			return nil, fmt.Errorf("unknown part type: %T", part)
533		}
534
535		wrappedParts[i] = partWrapper{
536			Type: typ,
537			Data: part,
538		}
539	}
540	return json.Marshal(wrappedParts)
541}
542
543// UnmarshalParts unmarshals content parts from JSON.
544func UnmarshalParts(data []byte) ([]ContentPart, error) {
545	temp := []json.RawMessage{}
546
547	if err := json.Unmarshal(data, &temp); err != nil {
548		return nil, err
549	}
550
551	parts := make([]ContentPart, 0)
552
553	for _, rawPart := range temp {
554		var wrapper struct {
555			Type partType        `json:"type"`
556			Data json.RawMessage `json:"data"`
557		}
558
559		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
560			return nil, err
561		}
562
563		switch wrapper.Type {
564		case reasoningType:
565			part := ReasoningContent{}
566			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
567				return nil, err
568			}
569			parts = append(parts, part)
570		case textType:
571			part := TextContent{}
572			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
573				return nil, err
574			}
575			parts = append(parts, part)
576		case imageURLType:
577			part := ImageURLContent{}
578			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
579				return nil, err
580			}
581			parts = append(parts, part)
582		case binaryType:
583			part := BinaryContent{}
584			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
585				return nil, err
586			}
587			parts = append(parts, part)
588		case toolCallType:
589			part := ToolCall{}
590			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
591				return nil, err
592			}
593			parts = append(parts, part)
594		case toolResultType:
595			part := ToolResult{}
596			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
597				return nil, err
598			}
599			parts = append(parts, part)
600		case finishType:
601			part := Finish{}
602			if err := json.Unmarshal(wrapper.Data, &part); err != nil {
603				return nil, err
604			}
605			parts = append(parts, part)
606		default:
607			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
608		}
609	}
610
611	return parts, nil
612}
613
// Attachment is a file attachment: its path, its name, its MIME type and its
// raw content bytes.
type Attachment struct {
	FilePath string `json:"file_path"`
	FileName string `json:"file_name"`
	MimeType string `json:"mime_type"`
	Content  []byte `json:"content"`
}

// MarshalJSON implements the [json.Marshaler] interface.
//
// Content is written as a standard-base64 string (an empty string when there
// is no content); the other fields are encoded from their struct tags.
func (a Attachment) MarshalJSON() ([]byte, error) {
	// Alias has Attachment's fields but none of its methods, so encoding it
	// does not call back into this method. The envelope's own Content field
	// shadows the one promoted from Alias.
	type Alias Attachment
	type envelope struct {
		Content string `json:"content"`
		*Alias
	}

	encoded := base64.StdEncoding.EncodeToString(a.Content)
	return json.Marshal(envelope{
		Content: encoded,
		Alias:   (*Alias)(&a),
	})
}

// UnmarshalJSON implements the [json.Unmarshaler] interface.
//
// The "content" value is read as a string and decoded from standard base64
// into Content; the other fields are decoded from their struct tags. JSON and
// base64 errors are returned unchanged.
func (a *Attachment) UnmarshalJSON(data []byte) error {
	type Alias Attachment
	var aux struct {
		Content string `json:"content"`
		*Alias
	}
	aux.Alias = (*Alias)(a)

	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	decoded, err := base64.StdEncoding.DecodeString(aux.Content)
	if err != nil {
		return err
	}

	a.Content = decoded
	return nil
}