1package proto
2
3import (
4 "encoding/base64"
5 "encoding/json"
6 "fmt"
7 "slices"
8 "time"
9
10 "charm.land/catwalk/pkg/catwalk"
11)
12
// CreateMessageParams represents parameters for creating a message.
//
// NOTE(review): unlike Message, no MarshalJSON/UnmarshalJSON for this type
// appears in this file. If none exists elsewhere in the package, Parts is
// encoded without the {"type", "data"} envelope that MarshalParts adds, and
// encoding/json cannot decode it back because ContentPart is an interface —
// confirm this type is never round-tripped through JSON with parts.
type CreateMessageParams struct {
	Role     MessageRole   `json:"role"`
	Parts    []ContentPart `json:"parts"`
	Model    string        `json:"model"`
	Provider string        `json:"provider,omitempty"` // omitted from JSON when empty
}
20
// Message represents a message in the proto layer.
//
// A message holds an ordered list of typed content parts (text, reasoning,
// image URLs, binary data, tool calls, tool results and a finish marker). The
// helper methods in this file read Parts and mutate it in place.
//
// Message has custom MarshalJSON/UnmarshalJSON methods that route Parts
// through MarshalParts/UnmarshalParts, so each part keeps its concrete type
// across a JSON round trip.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	// CreatedAt and UpdatedAt are not set anywhere in this file; presumably
	// Unix seconds like the part timestamps — TODO confirm with the producer.
	CreatedAt int64 `json:"created_at"`
	UpdatedAt int64 `json:"updated_at"`
}
32
// MessageRole represents the role of a message sender.
type MessageRole string

// Known message roles.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)

// MarshalText implements the [encoding.TextMarshaler] interface. The role is
// emitted verbatim as its underlying string and never fails.
func (r MessageRole) MarshalText() ([]byte, error) {
	text := string(r)
	return []byte(text), nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. The input
// is stored verbatim; it is not checked against the known roles, so
// unrecognized values are accepted.
func (r *MessageRole) UnmarshalText(data []byte) error {
	role := MessageRole(data)
	*r = role
	return nil
}
53
// FinishReason represents why a message generation finished.
type FinishReason string

// Known finish reasons.
const (
	FinishReasonEndTurn   FinishReason = "end_turn"
	FinishReasonMaxTokens FinishReason = "max_tokens"
	FinishReasonToolUse   FinishReason = "tool_use"
	FinishReasonCanceled  FinishReason = "canceled"
	FinishReasonError     FinishReason = "error"
	FinishReasonUnknown   FinishReason = "unknown"
)

// MarshalText implements the [encoding.TextMarshaler] interface. The reason is
// emitted verbatim as its underlying string and never fails.
func (fr FinishReason) MarshalText() ([]byte, error) {
	text := string(fr)
	return []byte(text), nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. The input
// is stored verbatim without validation against the known reasons.
func (fr *FinishReason) UnmarshalText(data []byte) error {
	reason := FinishReason(data)
	*fr = reason
	return nil
}
76
// ContentPart is a part of a message's content.
//
// The unexported isPart method closes the interface: only types declared in
// this package can implement it. The concrete parts understood by
// MarshalParts and UnmarshalParts are ReasoningContent, TextContent,
// ImageURLContent, BinaryContent, ToolCall, ToolResult and Finish.
type ContentPart interface {
	isPart()
}
81
// ReasoningContent represents the reasoning/thinking part of a message.
type ReasoningContent struct {
	Thinking  string `json:"thinking"`
	Signature string `json:"signature"`
	// StartedAt and FinishedAt are Unix seconds (the Message helpers set them
	// from time.Now().Unix()); zero means "not set" and is omitted from JSON.
	StartedAt  int64 `json:"started_at,omitempty"`
	FinishedAt int64 `json:"finished_at,omitempty"`
}

// String returns the reasoning text, ignoring the signature and timestamps.
func (rc ReasoningContent) String() string {
	thinking := rc.Thinking
	return thinking
}

// isPart marks ReasoningContent as a ContentPart.
func (ReasoningContent) isPart() {}
96
// TextContent represents a text part of a message.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the raw text of the part.
func (t TextContent) String() string {
	text := t.Text
	return text
}

// isPart marks TextContent as a ContentPart.
func (TextContent) isPart() {}
108
// ImageURLContent represents an image URL part of a message.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"` // omitted from JSON when empty
}

// String returns the image URL; Detail is not included.
func (img ImageURLContent) String() string {
	url := img.URL
	return url
}

// isPart marks ImageURLContent as a ContentPart.
func (ImageURLContent) isPart() {}
121
122// BinaryContent represents binary data in a message.
123type BinaryContent struct {
124 Path string
125 MIMEType string
126 Data []byte
127}
128
129// String returns a base64-encoded string of the binary data.
130func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
131 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
132 if p == catwalk.InferenceProviderOpenAI {
133 return "data:" + bc.MIMEType + ";base64," + base64Encoded
134 }
135 return base64Encoded
136}
137
138func (BinaryContent) isPart() {}
139
// ToolCall represents a tool call in a message.
//
// Message.FinishToolCall, Message.AppendToolCallInput and Message.AddToolCall
// all locate a call by its ID. Input can be built up incrementally with
// AppendToolCallInput; Finished is set by FinishToolCall.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`     // NOTE(review): meaning not defined in this file; omitted when empty
	Finished bool   `json:"finished,omitempty"` // omitted from JSON when false
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
150
// ToolResult represents the result of a tool call.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"` // presumably the ID of the ToolCall this result answers — confirm with producers
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"` // opaque string; its format is not defined in this file
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
161
// Finish represents the end of a message generation.
//
// Message.AddFinish removes the first existing Finish part before appending a
// new one, and sets Time from time.Now().Unix() (Unix seconds).
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"` // omitted from JSON when empty
	Details string       `json:"details,omitempty"` // omitted from JSON when empty
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
171
// MarshalJSON implements the [json.Marshaler] interface.
//
// Parts is an interface slice, so it is encoded with MarshalParts, which tags
// every element with its concrete type. All other fields use their struct
// tags. Any error from MarshalParts is returned unchanged.
func (m Message) MarshalJSON() ([]byte, error) {
	// Base drops Message's methods so json.Marshal below does not recurse.
	type Base Message

	encodedParts, err := MarshalParts(m.Parts)
	if err != nil {
		return nil, err
	}

	// The outer Parts field shadows Base.Parts, so the tagged encoding wins.
	wire := struct {
		Parts json.RawMessage `json:"parts"`
		*Base
	}{
		Parts: encodedParts,
		Base:  (*Base)(&m),
	}
	return json.Marshal(&wire)
}
188
// UnmarshalJSON implements the [json.Unmarshaler] interface.
//
// Parts is decoded with UnmarshalParts so every part is restored to its
// concrete ContentPart type; the other fields use their struct tags. A JSON
// null, or an object without a "parts" key, leaves m.Parts unchanged. This is
// the same treatment encoding/json gives to null values and missing keys.
func (m *Message) UnmarshalJSON(data []byte) error {
	// Alias drops Message's methods so json.Unmarshal below does not recurse.
	type Alias Message
	// aux must be a struct value, not a pointer. Decoding a JSON null into a
	// pointer variable sets it to nil, and the aux.Parts access below would
	// then panic.
	aux := struct {
		Parts json.RawMessage `json:"parts"`
		*Alias
	}{
		Alias: (*Alias)(m),
	}

	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	// An absent "parts" key (or a top-level null) leaves aux.Parts empty.
	// UnmarshalParts would reject that as "unexpected end of JSON input".
	if len(aux.Parts) == 0 {
		return nil
	}

	parts, err := UnmarshalParts(aux.Parts)
	if err != nil {
		return fmt.Errorf("decoding message parts: %w", err)
	}

	m.Parts = parts
	return nil
}
211
212// Content returns the first text content part.
213func (m *Message) Content() TextContent {
214 for _, part := range m.Parts {
215 if c, ok := part.(TextContent); ok {
216 return c
217 }
218 }
219 return TextContent{}
220}
221
222// ReasoningContent returns the first reasoning content part.
223func (m *Message) ReasoningContent() ReasoningContent {
224 for _, part := range m.Parts {
225 if c, ok := part.(ReasoningContent); ok {
226 return c
227 }
228 }
229 return ReasoningContent{}
230}
231
232// ImageURLContent returns all image URL content parts.
233func (m *Message) ImageURLContent() []ImageURLContent {
234 imageURLContents := make([]ImageURLContent, 0)
235 for _, part := range m.Parts {
236 if c, ok := part.(ImageURLContent); ok {
237 imageURLContents = append(imageURLContents, c)
238 }
239 }
240 return imageURLContents
241}
242
243// BinaryContent returns all binary content parts.
244func (m *Message) BinaryContent() []BinaryContent {
245 binaryContents := make([]BinaryContent, 0)
246 for _, part := range m.Parts {
247 if c, ok := part.(BinaryContent); ok {
248 binaryContents = append(binaryContents, c)
249 }
250 }
251 return binaryContents
252}
253
254// ToolCalls returns all tool call parts.
255func (m *Message) ToolCalls() []ToolCall {
256 toolCalls := make([]ToolCall, 0)
257 for _, part := range m.Parts {
258 if c, ok := part.(ToolCall); ok {
259 toolCalls = append(toolCalls, c)
260 }
261 }
262 return toolCalls
263}
264
265// ToolResults returns all tool result parts.
266func (m *Message) ToolResults() []ToolResult {
267 toolResults := make([]ToolResult, 0)
268 for _, part := range m.Parts {
269 if c, ok := part.(ToolResult); ok {
270 toolResults = append(toolResults, c)
271 }
272 }
273 return toolResults
274}
275
276// IsFinished returns true if the message has a finish part.
277func (m *Message) IsFinished() bool {
278 for _, part := range m.Parts {
279 if _, ok := part.(Finish); ok {
280 return true
281 }
282 }
283 return false
284}
285
286// FinishPart returns the finish part if present.
287func (m *Message) FinishPart() *Finish {
288 for _, part := range m.Parts {
289 if c, ok := part.(Finish); ok {
290 return &c
291 }
292 }
293 return nil
294}
295
296// FinishReason returns the finish reason if present.
297func (m *Message) FinishReason() FinishReason {
298 for _, part := range m.Parts {
299 if c, ok := part.(Finish); ok {
300 return c.Reason
301 }
302 }
303 return ""
304}
305
306// IsThinking returns true if the message is currently in a thinking state.
307func (m *Message) IsThinking() bool {
308 return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
309}
310
311// AppendContent appends text to the text content part.
312func (m *Message) AppendContent(delta string) {
313 found := false
314 for i, part := range m.Parts {
315 if c, ok := part.(TextContent); ok {
316 m.Parts[i] = TextContent{Text: c.Text + delta}
317 found = true
318 }
319 }
320 if !found {
321 m.Parts = append(m.Parts, TextContent{Text: delta})
322 }
323}
324
// AppendReasoningContent appends delta to the message's reasoning text.
//
// Only the first ReasoningContent part is extended. This is the part that
// Message.ReasoningContent returns and that AppendReasoningSignature and
// FinishThinking update, so a streamed delta is never duplicated across
// several reasoning parts. The part's Signature, StartedAt and FinishedAt are
// preserved. If the message has no reasoning part yet, a new one is appended
// with StartedAt set to the current Unix time in seconds.
func (m *Message) AppendReasoningContent(delta string) {
	for i, part := range m.Parts {
		if c, ok := part.(ReasoningContent); ok {
			c.Thinking += delta
			m.Parts[i] = c
			return
		}
	}
	m.Parts = append(m.Parts, ReasoningContent{
		Thinking:  delta,
		StartedAt: time.Now().Unix(),
	})
}
346
347// AppendReasoningSignature appends a signature to the reasoning content part.
348func (m *Message) AppendReasoningSignature(signature string) {
349 for i, part := range m.Parts {
350 if c, ok := part.(ReasoningContent); ok {
351 m.Parts[i] = ReasoningContent{
352 Thinking: c.Thinking,
353 Signature: c.Signature + signature,
354 StartedAt: c.StartedAt,
355 FinishedAt: c.FinishedAt,
356 }
357 return
358 }
359 }
360 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
361}
362
// FinishThinking stamps the first ReasoningContent part's FinishedAt with the
// current Unix time in seconds. A part that already has a FinishedAt keeps it.
// Messages without a reasoning part are left untouched.
func (m *Message) FinishThinking() {
	for i, p := range m.Parts {
		rc, isReasoning := p.(ReasoningContent)
		if !isReasoning {
			continue
		}
		if rc.FinishedAt == 0 {
			rc.FinishedAt = time.Now().Unix()
			m.Parts[i] = rc
		}
		return
	}
}
379
380// ThinkingDuration returns the duration of the thinking phase.
381func (m *Message) ThinkingDuration() time.Duration {
382 reasoning := m.ReasoningContent()
383 if reasoning.StartedAt == 0 {
384 return 0
385 }
386
387 endTime := reasoning.FinishedAt
388 if endTime == 0 {
389 endTime = time.Now().Unix()
390 }
391
392 return time.Duration(endTime-reasoning.StartedAt) * time.Second
393}
394
395// FinishToolCall marks a tool call as finished.
396func (m *Message) FinishToolCall(toolCallID string) {
397 for i, part := range m.Parts {
398 if c, ok := part.(ToolCall); ok {
399 if c.ID == toolCallID {
400 m.Parts[i] = ToolCall{
401 ID: c.ID,
402 Name: c.Name,
403 Input: c.Input,
404 Type: c.Type,
405 Finished: true,
406 }
407 return
408 }
409 }
410 }
411}
412
413// AppendToolCallInput appends input to a tool call.
414func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
415 for i, part := range m.Parts {
416 if c, ok := part.(ToolCall); ok {
417 if c.ID == toolCallID {
418 m.Parts[i] = ToolCall{
419 ID: c.ID,
420 Name: c.Name,
421 Input: c.Input + inputDelta,
422 Type: c.Type,
423 Finished: c.Finished,
424 }
425 return
426 }
427 }
428 }
429}
430
431// AddToolCall adds or updates a tool call.
432func (m *Message) AddToolCall(tc ToolCall) {
433 for i, part := range m.Parts {
434 if c, ok := part.(ToolCall); ok {
435 if c.ID == tc.ID {
436 m.Parts[i] = tc
437 return
438 }
439 }
440 }
441 m.Parts = append(m.Parts, tc)
442}
443
444// SetToolCalls replaces all tool call parts.
445func (m *Message) SetToolCalls(tc []ToolCall) {
446 parts := make([]ContentPart, 0)
447 for _, part := range m.Parts {
448 if _, ok := part.(ToolCall); ok {
449 continue
450 }
451 parts = append(parts, part)
452 }
453 m.Parts = parts
454 for _, toolCall := range tc {
455 m.Parts = append(m.Parts, toolCall)
456 }
457}
458
459// AddToolResult adds a tool result.
460func (m *Message) AddToolResult(tr ToolResult) {
461 m.Parts = append(m.Parts, tr)
462}
463
464// SetToolResults adds multiple tool results.
465func (m *Message) SetToolResults(tr []ToolResult) {
466 for _, toolResult := range tr {
467 m.Parts = append(m.Parts, toolResult)
468 }
469}
470
// AddFinish removes the first existing Finish part, if any. It then appends a
// new Finish with the given reason, message and details, timestamped with the
// current Unix time in seconds.
func (m *Message) AddFinish(reason FinishReason, message, details string) {
	isFinish := func(p ContentPart) bool {
		_, ok := p.(Finish)
		return ok
	}
	if idx := slices.IndexFunc(m.Parts, isFinish); idx >= 0 {
		m.Parts = slices.Delete(m.Parts, idx, idx+1)
	}

	m.Parts = append(m.Parts, Finish{
		Reason:  reason,
		Time:    time.Now().Unix(),
		Message: message,
		Details: details,
	})
}
481
482// AddImageURL adds an image URL part to the message.
483func (m *Message) AddImageURL(url, detail string) {
484 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
485}
486
487// AddBinary adds a binary content part to the message.
488func (m *Message) AddBinary(mimeType string, data []byte) {
489 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
490}
491
// partType is the discriminator written to the "type" field of each element
// produced by MarshalParts. UnmarshalParts uses it to choose the concrete
// ContentPart that "data" is decoded into. The string values are part of the
// serialized format: renaming one makes previously encoded parts of that kind
// fail with "unknown part type".
type partType string

const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
503
// partWrapper is the JSON envelope MarshalParts emits for each part. Type
// names the concrete part, and Data is that part encoded with its own JSON
// encoding. Decoding does not use this type, because encoding/json cannot
// unmarshal into an interface field; UnmarshalParts reads "data" as
// json.RawMessage instead.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
508
// MarshalParts encodes parts as a JSON array of {"type": ..., "data": ...}
// objects, in input order, so that UnmarshalParts can restore each concrete
// type. A nil or empty input encodes as []. A part of a type this function does
// not know yields an "unknown part type" error and no output.
func MarshalParts(parts []ContentPart) ([]byte, error) {
	envelopes := make([]partWrapper, 0, len(parts))
	for _, p := range parts {
		var kind partType
		switch p.(type) {
		case TextContent:
			kind = textType
		case ReasoningContent:
			kind = reasoningType
		case ImageURLContent:
			kind = imageURLType
		case BinaryContent:
			kind = binaryType
		case ToolCall:
			kind = toolCallType
		case ToolResult:
			kind = toolResultType
		case Finish:
			kind = finishType
		default:
			return nil, fmt.Errorf("unknown part type: %T", p)
		}
		envelopes = append(envelopes, partWrapper{Type: kind, Data: p})
	}
	return json.Marshal(envelopes)
}
542
// UnmarshalParts decodes the JSON produced by MarshalParts back into content
// parts.
//
// data must be a JSON array of {"type": ..., "data": ...} objects. Each
// element's "type" selects the concrete ContentPart that its "data" is decoded
// into, and order is preserved. A JSON null decodes to an empty, non-nil
// slice. An unrecognized type, or an element whose payload fails to decode,
// aborts with an error. Decode errors are wrapped with the element's index
// (and type, when known).
func UnmarshalParts(data []byte) ([]ContentPart, error) {
	var rawParts []json.RawMessage
	if err := json.Unmarshal(data, &rawParts); err != nil {
		return nil, err
	}

	parts := make([]ContentPart, 0, len(rawParts))
	for i, rawPart := range rawParts {
		var wrapper struct {
			Type partType        `json:"type"`
			Data json.RawMessage `json:"data"`
		}
		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
			return nil, fmt.Errorf("decoding part %d: %w", i, err)
		}

		var (
			part ContentPart
			err  error
		)
		switch wrapper.Type {
		case reasoningType:
			part, err = unmarshalPartData[ReasoningContent](wrapper.Data)
		case textType:
			part, err = unmarshalPartData[TextContent](wrapper.Data)
		case imageURLType:
			part, err = unmarshalPartData[ImageURLContent](wrapper.Data)
		case binaryType:
			part, err = unmarshalPartData[BinaryContent](wrapper.Data)
		case toolCallType:
			part, err = unmarshalPartData[ToolCall](wrapper.Data)
		case toolResultType:
			part, err = unmarshalPartData[ToolResult](wrapper.Data)
		case finishType:
			part, err = unmarshalPartData[Finish](wrapper.Data)
		default:
			return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
		}
		if err != nil {
			return nil, fmt.Errorf("decoding %s part %d: %w", wrapper.Type, i, err)
		}
		parts = append(parts, part)
	}

	return parts, nil
}

// unmarshalPartData decodes one part payload into a value of concrete type T
// and returns it as a ContentPart.
func unmarshalPartData[T ContentPart](data json.RawMessage) (ContentPart, error) {
	var part T
	if err := json.Unmarshal(data, &part); err != nil {
		return nil, err
	}
	return part, nil
}
613
614// Attachment represents a file attachment.
615type Attachment struct {
616 FilePath string `json:"file_path"`
617 FileName string `json:"file_name"`
618 MimeType string `json:"mime_type"`
619 Content []byte `json:"content"`
620}
621
622// MarshalJSON implements the [json.Marshaler] interface.
623func (a Attachment) MarshalJSON() ([]byte, error) {
624 type Alias Attachment
625 return json.Marshal(&struct {
626 Content string `json:"content"`
627 *Alias
628 }{
629 Content: base64.StdEncoding.EncodeToString(a.Content),
630 Alias: (*Alias)(&a),
631 })
632}
633
634// UnmarshalJSON implements the [json.Unmarshaler] interface.
635func (a *Attachment) UnmarshalJSON(data []byte) error {
636 type Alias Attachment
637 aux := &struct {
638 Content string `json:"content"`
639 *Alias
640 }{
641 Alias: (*Alias)(a),
642 }
643 if err := json.Unmarshal(data, &aux); err != nil {
644 return err
645 }
646 content, err := base64.StdEncoding.DecodeString(aux.Content)
647 if err != nil {
648 return err
649 }
650 a.Content = content
651 return nil
652}