1package proto
2
3import (
4 "encoding/base64"
5 "encoding/json"
6 "fmt"
7 "slices"
8 "time"
9
10 "charm.land/catwalk/pkg/catwalk"
11)
12
// CreateMessageParams represents parameters for creating a message.
//
// Provider is the only optional field on the wire: it is omitted from the
// JSON encoding when empty. No custom JSON methods for this type are visible
// in this file, so Parts is handled by the default encoding/json rules.
//
// NOTE(review): ContentPart is an interface, so decoding a non-empty "parts"
// array into this struct looks like it would fail without a custom
// UnmarshalJSON — confirm whether one exists elsewhere.
type CreateMessageParams struct {
	Role     MessageRole   `json:"role"`
	Parts    []ContentPart `json:"parts"`
	Model    string        `json:"model"`
	Provider string        `json:"provider,omitempty"`
}
20
// Message represents a message in the proto layer.
//
// Parts holds the message's content parts in order. Message has custom
// MarshalJSON and UnmarshalJSON methods that encode Parts through
// MarshalParts and UnmarshalParts so that every part carries a type tag.
// CreatedAt and UpdatedAt are plain integers — presumably Unix timestamps;
// confirm the unit against the code that sets them.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
32
// MessageRole identifies the sender of a message.
type MessageRole string

// Roles defined by this package.
const (
	Assistant = MessageRole("assistant")
	User      = MessageRole("user")
	System    = MessageRole("system")
	Tool      = MessageRole("tool")
)

// MarshalText implements the [encoding.TextMarshaler] interface. The role is
// emitted as its underlying string and the error is always nil.
func (r MessageRole) MarshalText() ([]byte, error) {
	text := []byte(r)
	return text, nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. Any
// input is accepted verbatim; no check against the known roles is made.
func (r *MessageRole) UnmarshalText(data []byte) error {
	role := MessageRole(data)
	*r = role
	return nil
}
53
// FinishReason describes why generation of a message stopped.
type FinishReason string

// Finish reasons defined by this package.
const (
	FinishReasonEndTurn          = FinishReason("end_turn")
	FinishReasonMaxTokens        = FinishReason("max_tokens")
	FinishReasonToolUse          = FinishReason("tool_use")
	FinishReasonCanceled         = FinishReason("canceled")
	FinishReasonError            = FinishReason("error")
	FinishReasonPermissionDenied = FinishReason("permission_denied")
	FinishReasonUnknown          = FinishReason("unknown")
)

// MarshalText implements the [encoding.TextMarshaler] interface. The reason
// is emitted as its underlying string and the error is always nil.
func (fr FinishReason) MarshalText() ([]byte, error) {
	text := []byte(fr)
	return text, nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. Any
// input is accepted verbatim; no check against the known reasons is made.
func (fr *FinishReason) UnmarshalText(data []byte) error {
	reason := FinishReason(data)
	*fr = reason
	return nil
}
77
// ContentPart is a part of a message's content.
//
// Its only method is the unexported marker isPart, so code outside this
// package cannot declare that method itself; the interface carries no
// behavior beyond tagging a type as a part.
type ContentPart interface {
	isPart()
}
82
// ReasoningContent is the reasoning/thinking part of a message. StartedAt and
// FinishedAt are omitted from the JSON encoding when zero.
type ReasoningContent struct {
	Thinking   string `json:"thinking"`
	Signature  string `json:"signature"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`
}

// isPart marks ReasoningContent as a ContentPart.
func (ReasoningContent) isPart() {}

// String returns the thinking text; the signature and timestamps are not
// included.
func (rc ReasoningContent) String() string {
	thinking := rc.Thinking
	return thinking
}
97
// TextContent is a plain-text part of a message.
type TextContent struct {
	Text string `json:"text"`
}

// isPart marks TextContent as a ContentPart.
func (TextContent) isPart() {}

// String returns the part's text.
func (tc TextContent) String() string {
	text := tc.Text
	return text
}
109
// ImageURLContent is an image-by-URL part of a message. Detail is omitted
// from the JSON encoding when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// isPart marks ImageURLContent as a ContentPart.
func (ImageURLContent) isPart() {}

// String returns the image URL; Detail is not included.
func (iuc ImageURLContent) String() string {
	url := iuc.URL
	return url
}
122
123// BinaryContent represents binary data in a message.
124type BinaryContent struct {
125 Path string
126 MIMEType string
127 Data []byte
128}
129
130// String returns a base64-encoded string of the binary data.
131func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
132 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
133 if p == catwalk.InferenceProviderOpenAI {
134 return "data:" + bc.MIMEType + ";base64," + base64Encoded
135 }
136 return base64Encoded
137}
138
139func (BinaryContent) isPart() {}
140
// ToolCall represents a tool call in a message.
//
// ID is the key the Message helpers (FinishToolCall, AppendToolCallInput,
// AddToolCall) use to locate a call. Input is accumulated as a string —
// presumably the JSON-encoded arguments; confirm with the producer. Type and
// Finished are omitted from the JSON encoding when they hold their zero value.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`
	Finished bool   `json:"finished,omitempty"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
151
// ToolResult represents the result of a tool call.
//
// ToolCallID presumably matches the ID of the ToolCall that produced this
// result — nothing in this file enforces the link. All fields are always
// present in the JSON encoding (none use omitempty).
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
162
// Finish represents the end of a message generation.
//
// Time is a Unix timestamp in seconds when the part is created through
// Message.AddFinish, which stamps it with time.Now().Unix(). Message and
// Details are omitted from the JSON encoding when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
172
173// MarshalJSON implements the [json.Marshaler] interface.
174func (m Message) MarshalJSON() ([]byte, error) {
175 parts, err := MarshalParts(m.Parts)
176 if err != nil {
177 return nil, err
178 }
179
180 type Alias Message
181 return json.Marshal(&struct {
182 Parts json.RawMessage `json:"parts"`
183 *Alias
184 }{
185 Parts: json.RawMessage(parts),
186 Alias: (*Alias)(&m),
187 })
188}
189
190// UnmarshalJSON implements the [json.Unmarshaler] interface.
191func (m *Message) UnmarshalJSON(data []byte) error {
192 type Alias Message
193 aux := &struct {
194 Parts json.RawMessage `json:"parts"`
195 *Alias
196 }{
197 Alias: (*Alias)(m),
198 }
199
200 if err := json.Unmarshal(data, &aux); err != nil {
201 return err
202 }
203
204 parts, err := UnmarshalParts([]byte(aux.Parts))
205 if err != nil {
206 return err
207 }
208
209 m.Parts = parts
210 return nil
211}
212
213// Content returns the first text content part.
214func (m *Message) Content() TextContent {
215 for _, part := range m.Parts {
216 if c, ok := part.(TextContent); ok {
217 return c
218 }
219 }
220 return TextContent{}
221}
222
223// ReasoningContent returns the first reasoning content part.
224func (m *Message) ReasoningContent() ReasoningContent {
225 for _, part := range m.Parts {
226 if c, ok := part.(ReasoningContent); ok {
227 return c
228 }
229 }
230 return ReasoningContent{}
231}
232
233// ImageURLContent returns all image URL content parts.
234func (m *Message) ImageURLContent() []ImageURLContent {
235 imageURLContents := make([]ImageURLContent, 0)
236 for _, part := range m.Parts {
237 if c, ok := part.(ImageURLContent); ok {
238 imageURLContents = append(imageURLContents, c)
239 }
240 }
241 return imageURLContents
242}
243
244// BinaryContent returns all binary content parts.
245func (m *Message) BinaryContent() []BinaryContent {
246 binaryContents := make([]BinaryContent, 0)
247 for _, part := range m.Parts {
248 if c, ok := part.(BinaryContent); ok {
249 binaryContents = append(binaryContents, c)
250 }
251 }
252 return binaryContents
253}
254
255// ToolCalls returns all tool call parts.
256func (m *Message) ToolCalls() []ToolCall {
257 toolCalls := make([]ToolCall, 0)
258 for _, part := range m.Parts {
259 if c, ok := part.(ToolCall); ok {
260 toolCalls = append(toolCalls, c)
261 }
262 }
263 return toolCalls
264}
265
266// ToolResults returns all tool result parts.
267func (m *Message) ToolResults() []ToolResult {
268 toolResults := make([]ToolResult, 0)
269 for _, part := range m.Parts {
270 if c, ok := part.(ToolResult); ok {
271 toolResults = append(toolResults, c)
272 }
273 }
274 return toolResults
275}
276
277// IsFinished returns true if the message has a finish part.
278func (m *Message) IsFinished() bool {
279 for _, part := range m.Parts {
280 if _, ok := part.(Finish); ok {
281 return true
282 }
283 }
284 return false
285}
286
287// FinishPart returns the finish part if present.
288func (m *Message) FinishPart() *Finish {
289 for _, part := range m.Parts {
290 if c, ok := part.(Finish); ok {
291 return &c
292 }
293 }
294 return nil
295}
296
297// FinishReason returns the finish reason if present.
298func (m *Message) FinishReason() FinishReason {
299 for _, part := range m.Parts {
300 if c, ok := part.(Finish); ok {
301 return c.Reason
302 }
303 }
304 return ""
305}
306
307// IsThinking returns true if the message is currently in a thinking state.
308func (m *Message) IsThinking() bool {
309 return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
310}
311
312// AppendContent appends text to the text content part.
313func (m *Message) AppendContent(delta string) {
314 found := false
315 for i, part := range m.Parts {
316 if c, ok := part.(TextContent); ok {
317 m.Parts[i] = TextContent{Text: c.Text + delta}
318 found = true
319 }
320 }
321 if !found {
322 m.Parts = append(m.Parts, TextContent{Text: delta})
323 }
324}
325
326// AppendReasoningContent appends text to the reasoning content part.
327func (m *Message) AppendReasoningContent(delta string) {
328 found := false
329 for i, part := range m.Parts {
330 if c, ok := part.(ReasoningContent); ok {
331 m.Parts[i] = ReasoningContent{
332 Thinking: c.Thinking + delta,
333 Signature: c.Signature,
334 StartedAt: c.StartedAt,
335 FinishedAt: c.FinishedAt,
336 }
337 found = true
338 }
339 }
340 if !found {
341 m.Parts = append(m.Parts, ReasoningContent{
342 Thinking: delta,
343 StartedAt: time.Now().Unix(),
344 })
345 }
346}
347
348// AppendReasoningSignature appends a signature to the reasoning content part.
349func (m *Message) AppendReasoningSignature(signature string) {
350 for i, part := range m.Parts {
351 if c, ok := part.(ReasoningContent); ok {
352 m.Parts[i] = ReasoningContent{
353 Thinking: c.Thinking,
354 Signature: c.Signature + signature,
355 StartedAt: c.StartedAt,
356 FinishedAt: c.FinishedAt,
357 }
358 return
359 }
360 }
361 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
362}
363
364// FinishThinking marks the reasoning content as finished.
365func (m *Message) FinishThinking() {
366 for i, part := range m.Parts {
367 if c, ok := part.(ReasoningContent); ok {
368 if c.FinishedAt == 0 {
369 m.Parts[i] = ReasoningContent{
370 Thinking: c.Thinking,
371 Signature: c.Signature,
372 StartedAt: c.StartedAt,
373 FinishedAt: time.Now().Unix(),
374 }
375 }
376 return
377 }
378 }
379}
380
381// ThinkingDuration returns the duration of the thinking phase.
382func (m *Message) ThinkingDuration() time.Duration {
383 reasoning := m.ReasoningContent()
384 if reasoning.StartedAt == 0 {
385 return 0
386 }
387
388 endTime := reasoning.FinishedAt
389 if endTime == 0 {
390 endTime = time.Now().Unix()
391 }
392
393 return time.Duration(endTime-reasoning.StartedAt) * time.Second
394}
395
396// FinishToolCall marks a tool call as finished.
397func (m *Message) FinishToolCall(toolCallID string) {
398 for i, part := range m.Parts {
399 if c, ok := part.(ToolCall); ok {
400 if c.ID == toolCallID {
401 m.Parts[i] = ToolCall{
402 ID: c.ID,
403 Name: c.Name,
404 Input: c.Input,
405 Type: c.Type,
406 Finished: true,
407 }
408 return
409 }
410 }
411 }
412}
413
414// AppendToolCallInput appends input to a tool call.
415func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
416 for i, part := range m.Parts {
417 if c, ok := part.(ToolCall); ok {
418 if c.ID == toolCallID {
419 m.Parts[i] = ToolCall{
420 ID: c.ID,
421 Name: c.Name,
422 Input: c.Input + inputDelta,
423 Type: c.Type,
424 Finished: c.Finished,
425 }
426 return
427 }
428 }
429 }
430}
431
432// AddToolCall adds or updates a tool call.
433func (m *Message) AddToolCall(tc ToolCall) {
434 for i, part := range m.Parts {
435 if c, ok := part.(ToolCall); ok {
436 if c.ID == tc.ID {
437 m.Parts[i] = tc
438 return
439 }
440 }
441 }
442 m.Parts = append(m.Parts, tc)
443}
444
445// SetToolCalls replaces all tool call parts.
446func (m *Message) SetToolCalls(tc []ToolCall) {
447 parts := make([]ContentPart, 0)
448 for _, part := range m.Parts {
449 if _, ok := part.(ToolCall); ok {
450 continue
451 }
452 parts = append(parts, part)
453 }
454 m.Parts = parts
455 for _, toolCall := range tc {
456 m.Parts = append(m.Parts, toolCall)
457 }
458}
459
460// AddToolResult adds a tool result.
461func (m *Message) AddToolResult(tr ToolResult) {
462 m.Parts = append(m.Parts, tr)
463}
464
465// SetToolResults adds multiple tool results.
466func (m *Message) SetToolResults(tr []ToolResult) {
467 for _, toolResult := range tr {
468 m.Parts = append(m.Parts, toolResult)
469 }
470}
471
472// AddFinish adds a finish part to the message.
473func (m *Message) AddFinish(reason FinishReason, message, details string) {
474 for i, part := range m.Parts {
475 if _, ok := part.(Finish); ok {
476 m.Parts = slices.Delete(m.Parts, i, i+1)
477 break
478 }
479 }
480 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
481}
482
483// AddImageURL adds an image URL part to the message.
484func (m *Message) AddImageURL(url, detail string) {
485 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
486}
487
488// AddBinary adds a binary content part to the message.
489func (m *Message) AddBinary(mimeType string, data []byte) {
490 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
491}
492
// partType is the discriminator written to the "type" field of an encoded
// content part.
type partType string

// Discriminator values, one per concrete ContentPart type. MarshalParts writes
// them and UnmarshalParts switches on them, so they are part of the wire
// format.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
504
// partWrapper is the encoded envelope for one content part: the type
// discriminator alongside the part itself. MarshalParts builds these; the
// part in Data is encoded according to its concrete type.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
509
510// MarshalParts marshals content parts to JSON.
511func MarshalParts(parts []ContentPart) ([]byte, error) {
512 wrappedParts := make([]partWrapper, len(parts))
513
514 for i, part := range parts {
515 var typ partType
516
517 switch part.(type) {
518 case ReasoningContent:
519 typ = reasoningType
520 case TextContent:
521 typ = textType
522 case ImageURLContent:
523 typ = imageURLType
524 case BinaryContent:
525 typ = binaryType
526 case ToolCall:
527 typ = toolCallType
528 case ToolResult:
529 typ = toolResultType
530 case Finish:
531 typ = finishType
532 default:
533 return nil, fmt.Errorf("unknown part type: %T", part)
534 }
535
536 wrappedParts[i] = partWrapper{
537 Type: typ,
538 Data: part,
539 }
540 }
541 return json.Marshal(wrappedParts)
542}
543
544// UnmarshalParts unmarshals content parts from JSON.
545func UnmarshalParts(data []byte) ([]ContentPart, error) {
546 temp := []json.RawMessage{}
547
548 if err := json.Unmarshal(data, &temp); err != nil {
549 return nil, err
550 }
551
552 parts := make([]ContentPart, 0)
553
554 for _, rawPart := range temp {
555 var wrapper struct {
556 Type partType `json:"type"`
557 Data json.RawMessage `json:"data"`
558 }
559
560 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
561 return nil, err
562 }
563
564 switch wrapper.Type {
565 case reasoningType:
566 part := ReasoningContent{}
567 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
568 return nil, err
569 }
570 parts = append(parts, part)
571 case textType:
572 part := TextContent{}
573 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
574 return nil, err
575 }
576 parts = append(parts, part)
577 case imageURLType:
578 part := ImageURLContent{}
579 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
580 return nil, err
581 }
582 parts = append(parts, part)
583 case binaryType:
584 part := BinaryContent{}
585 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
586 return nil, err
587 }
588 parts = append(parts, part)
589 case toolCallType:
590 part := ToolCall{}
591 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
592 return nil, err
593 }
594 parts = append(parts, part)
595 case toolResultType:
596 part := ToolResult{}
597 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
598 return nil, err
599 }
600 parts = append(parts, part)
601 case finishType:
602 part := Finish{}
603 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
604 return nil, err
605 }
606 parts = append(parts, part)
607 default:
608 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
609 }
610 }
611
612 return parts, nil
613}
614
// Attachment represents a file attachment. On the wire Content is a standard
// base64 string (see MarshalJSON and UnmarshalJSON).
type Attachment struct {
	FilePath string `json:"file_path"`
	FileName string `json:"file_name"`
	MimeType string `json:"mime_type"`
	Content  []byte `json:"content"`
}

// MarshalJSON implements the [json.Marshaler] interface. Content is written
// as a standard base64 string (an empty string when Content is nil or empty);
// the other fields use their struct tags.
func (a Attachment) MarshalJSON() ([]byte, error) {
	// Alias has Attachment's fields but not its methods, so json.Marshal does
	// not call back into this method.
	type Alias Attachment
	// The outer Content field shadows Alias.Content because it is less nested.
	return json.Marshal(&struct {
		Content string `json:"content"`
		*Alias
	}{
		Content: base64.StdEncoding.EncodeToString(a.Content),
		Alias:   (*Alias)(&a),
	})
}

// UnmarshalJSON implements the [json.Unmarshaler] interface.
//
// Following the encoding/json convention, a JSON null is a no-op. Otherwise
// the plain fields are decoded straight into a and "content" is decoded from
// standard base64; a missing or empty "content" yields an empty Content.
// Errors from json.Unmarshal or the base64 decoder are returned unchanged; on
// error the plain fields of a may already have been updated.
func (a *Attachment) UnmarshalJSON(data []byte) error {
	// Convention for Unmarshalers: null leaves the receiver untouched.
	if string(data) == "null" {
		return nil
	}

	type Alias Attachment
	// aux is a struct value, not a pointer: unmarshaling null through a
	// pointer-to-pointer would set the pointer to nil and the field access
	// below would panic.
	aux := struct {
		Content string `json:"content"`
		*Alias
	}{
		Alias: (*Alias)(a),
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	content, err := base64.StdEncoding.DecodeString(aux.Content)
	if err != nil {
		return err
	}
	a.Content = content
	return nil
}