1package proto
2
3import (
4 "encoding/base64"
5 "encoding/json"
6 "fmt"
7 "slices"
8 "time"
9
10 "charm.land/catwalk/pkg/catwalk"
11)
12
13// CreateMessageParams represents parameters for creating a message.
14type CreateMessageParams struct {
15 Role MessageRole `json:"role"`
16 Parts []ContentPart `json:"parts"`
17 Model string `json:"model"`
18 Provider string `json:"provider,omitempty"`
19}
20
21// Message represents a message in the proto layer.
22type Message struct {
23 ID string `json:"id"`
24 Role MessageRole `json:"role"`
25 SessionID string `json:"session_id"`
26 Parts []ContentPart `json:"parts"`
27 Model string `json:"model"`
28 Provider string `json:"provider"`
29 CreatedAt int64 `json:"created_at"`
30 UpdatedAt int64 `json:"updated_at"`
31}
32
// MessageRole identifies the author of a message. Through its text
// (un)marshalers it encodes as its plain string value, e.g. "assistant".
type MessageRole string

// Known message roles.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)

// MarshalText implements the [encoding.TextMarshaler] interface by returning
// the role's underlying string as bytes.
func (r MessageRole) MarshalText() ([]byte, error) {
	text := []byte(string(r))
	return text, nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. The
// input is stored verbatim; values other than the constants above are
// accepted, not rejected.
func (r *MessageRole) UnmarshalText(data []byte) error {
	role := MessageRole(string(data))
	*r = role
	return nil
}
53
// FinishReason records why generation of a message stopped. Through its text
// (un)marshalers it encodes as its plain string value, e.g. "end_turn".
type FinishReason string

// Known finish reasons.
const (
	FinishReasonEndTurn   FinishReason = "end_turn"
	FinishReasonMaxTokens FinishReason = "max_tokens"
	FinishReasonToolUse   FinishReason = "tool_use"
	FinishReasonCanceled  FinishReason = "canceled"
	FinishReasonError     FinishReason = "error"
	FinishReasonUnknown   FinishReason = "unknown"
)

// MarshalText implements the [encoding.TextMarshaler] interface by returning
// the reason's underlying string as bytes.
func (fr FinishReason) MarshalText() ([]byte, error) {
	text := []byte(string(fr))
	return text, nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. The
// input is stored verbatim; unrecognized values are accepted, not rejected.
func (fr *FinishReason) UnmarshalText(data []byte) error {
	reason := FinishReason(string(data))
	*fr = reason
	return nil
}
76
// ContentPart is implemented by every kind of content a Message can hold
// (text, reasoning, images, binary data, tool calls/results, finish markers).
// The unexported isPart marker method keeps types declared in other packages
// from implementing it directly.
type ContentPart interface {
	isPart()
}
81
// ReasoningContent is the model's reasoning ("thinking") output for a message.
// StartedAt and FinishedAt are Unix-second timestamps (as written by the
// Message helpers in this file) and are omitted from JSON when zero.
type ReasoningContent struct {
	Thinking   string `json:"thinking"`
	Signature  string `json:"signature"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`
}

// String returns the raw reasoning text.
func (rc ReasoningContent) String() string {
	thinking := rc.Thinking
	return thinking
}

func (ReasoningContent) isPart() {}
96
// TextContent is a plain-text part of a message.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text unchanged.
func (c TextContent) String() string {
	text := c.Text
	return text
}

func (TextContent) isPart() {}
108
// ImageURLContent references an image by URL. Detail is optional and is
// omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL.
func (c ImageURLContent) String() string {
	url := c.URL
	return url
}

func (ImageURLContent) isPart() {}
121
122// BinaryContent represents binary data in a message.
123type BinaryContent struct {
124 Path string
125 MIMEType string
126 Data []byte
127}
128
129// String returns a base64-encoded string of the binary data.
130func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
131 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
132 if p == catwalk.InferenceProviderOpenAI {
133 return "data:" + bc.MIMEType + ";base64," + base64Encoded
134 }
135 return base64Encoded
136}
137
138func (BinaryContent) isPart() {}
139
// ToolCall is a request from the model to invoke a tool. Input holds the call
// arguments as text (extended incrementally by Message.AppendToolCallInput),
// and Finished is set by Message.FinishToolCall. Type and Finished are
// omitted from JSON when zero.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`
	Finished bool   `json:"finished,omitempty"`
}

func (ToolCall) isPart() {}
150
// ToolResult is the outcome of a tool call, linked back to the call by
// ToolCallID. Data and MIMEType are optional and omitted from JSON when
// empty; IsError marks a failed invocation.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Data       string `json:"data,omitempty"`
	MIMEType   string `json:"mime_type,omitempty"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
163
164// Finish represents the end of a message generation.
165type Finish struct {
166 Reason FinishReason `json:"reason"`
167 Time int64 `json:"time"`
168 Message string `json:"message,omitempty"`
169 Details string `json:"details,omitempty"`
170}
171
172func (Finish) isPart() {}
173
174// MarshalJSON implements the [json.Marshaler] interface.
175func (m Message) MarshalJSON() ([]byte, error) {
176 parts, err := MarshalParts(m.Parts)
177 if err != nil {
178 return nil, err
179 }
180
181 type Alias Message
182 return json.Marshal(&struct {
183 Parts json.RawMessage `json:"parts"`
184 *Alias
185 }{
186 Parts: json.RawMessage(parts),
187 Alias: (*Alias)(&m),
188 })
189}
190
191// UnmarshalJSON implements the [json.Unmarshaler] interface.
192func (m *Message) UnmarshalJSON(data []byte) error {
193 type Alias Message
194 aux := &struct {
195 Parts json.RawMessage `json:"parts"`
196 *Alias
197 }{
198 Alias: (*Alias)(m),
199 }
200
201 if err := json.Unmarshal(data, &aux); err != nil {
202 return err
203 }
204
205 parts, err := UnmarshalParts([]byte(aux.Parts))
206 if err != nil {
207 return err
208 }
209
210 m.Parts = parts
211 return nil
212}
213
214// Content returns the first text content part.
215func (m *Message) Content() TextContent {
216 for _, part := range m.Parts {
217 if c, ok := part.(TextContent); ok {
218 return c
219 }
220 }
221 return TextContent{}
222}
223
224// ReasoningContent returns the first reasoning content part.
225func (m *Message) ReasoningContent() ReasoningContent {
226 for _, part := range m.Parts {
227 if c, ok := part.(ReasoningContent); ok {
228 return c
229 }
230 }
231 return ReasoningContent{}
232}
233
234// ImageURLContent returns all image URL content parts.
235func (m *Message) ImageURLContent() []ImageURLContent {
236 imageURLContents := make([]ImageURLContent, 0)
237 for _, part := range m.Parts {
238 if c, ok := part.(ImageURLContent); ok {
239 imageURLContents = append(imageURLContents, c)
240 }
241 }
242 return imageURLContents
243}
244
245// BinaryContent returns all binary content parts.
246func (m *Message) BinaryContent() []BinaryContent {
247 binaryContents := make([]BinaryContent, 0)
248 for _, part := range m.Parts {
249 if c, ok := part.(BinaryContent); ok {
250 binaryContents = append(binaryContents, c)
251 }
252 }
253 return binaryContents
254}
255
256// ToolCalls returns all tool call parts.
257func (m *Message) ToolCalls() []ToolCall {
258 toolCalls := make([]ToolCall, 0)
259 for _, part := range m.Parts {
260 if c, ok := part.(ToolCall); ok {
261 toolCalls = append(toolCalls, c)
262 }
263 }
264 return toolCalls
265}
266
267// ToolResults returns all tool result parts.
268func (m *Message) ToolResults() []ToolResult {
269 toolResults := make([]ToolResult, 0)
270 for _, part := range m.Parts {
271 if c, ok := part.(ToolResult); ok {
272 toolResults = append(toolResults, c)
273 }
274 }
275 return toolResults
276}
277
278// IsFinished returns true if the message has a finish part.
279func (m *Message) IsFinished() bool {
280 for _, part := range m.Parts {
281 if _, ok := part.(Finish); ok {
282 return true
283 }
284 }
285 return false
286}
287
288// FinishPart returns the finish part if present.
289func (m *Message) FinishPart() *Finish {
290 for _, part := range m.Parts {
291 if c, ok := part.(Finish); ok {
292 return &c
293 }
294 }
295 return nil
296}
297
298// FinishReason returns the finish reason if present.
299func (m *Message) FinishReason() FinishReason {
300 for _, part := range m.Parts {
301 if c, ok := part.(Finish); ok {
302 return c.Reason
303 }
304 }
305 return ""
306}
307
308// IsThinking returns true if the message is currently in a thinking state.
309func (m *Message) IsThinking() bool {
310 return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
311}
312
313// AppendContent appends text to the text content part.
314func (m *Message) AppendContent(delta string) {
315 found := false
316 for i, part := range m.Parts {
317 if c, ok := part.(TextContent); ok {
318 m.Parts[i] = TextContent{Text: c.Text + delta}
319 found = true
320 }
321 }
322 if !found {
323 m.Parts = append(m.Parts, TextContent{Text: delta})
324 }
325}
326
327// AppendReasoningContent appends text to the reasoning content part.
328func (m *Message) AppendReasoningContent(delta string) {
329 found := false
330 for i, part := range m.Parts {
331 if c, ok := part.(ReasoningContent); ok {
332 m.Parts[i] = ReasoningContent{
333 Thinking: c.Thinking + delta,
334 Signature: c.Signature,
335 StartedAt: c.StartedAt,
336 FinishedAt: c.FinishedAt,
337 }
338 found = true
339 }
340 }
341 if !found {
342 m.Parts = append(m.Parts, ReasoningContent{
343 Thinking: delta,
344 StartedAt: time.Now().Unix(),
345 })
346 }
347}
348
349// AppendReasoningSignature appends a signature to the reasoning content part.
350func (m *Message) AppendReasoningSignature(signature string) {
351 for i, part := range m.Parts {
352 if c, ok := part.(ReasoningContent); ok {
353 m.Parts[i] = ReasoningContent{
354 Thinking: c.Thinking,
355 Signature: c.Signature + signature,
356 StartedAt: c.StartedAt,
357 FinishedAt: c.FinishedAt,
358 }
359 return
360 }
361 }
362 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
363}
364
365// FinishThinking marks the reasoning content as finished.
366func (m *Message) FinishThinking() {
367 for i, part := range m.Parts {
368 if c, ok := part.(ReasoningContent); ok {
369 if c.FinishedAt == 0 {
370 m.Parts[i] = ReasoningContent{
371 Thinking: c.Thinking,
372 Signature: c.Signature,
373 StartedAt: c.StartedAt,
374 FinishedAt: time.Now().Unix(),
375 }
376 }
377 return
378 }
379 }
380}
381
382// ThinkingDuration returns the duration of the thinking phase.
383func (m *Message) ThinkingDuration() time.Duration {
384 reasoning := m.ReasoningContent()
385 if reasoning.StartedAt == 0 {
386 return 0
387 }
388
389 endTime := reasoning.FinishedAt
390 if endTime == 0 {
391 endTime = time.Now().Unix()
392 }
393
394 return time.Duration(endTime-reasoning.StartedAt) * time.Second
395}
396
397// FinishToolCall marks a tool call as finished.
398func (m *Message) FinishToolCall(toolCallID string) {
399 for i, part := range m.Parts {
400 if c, ok := part.(ToolCall); ok {
401 if c.ID == toolCallID {
402 m.Parts[i] = ToolCall{
403 ID: c.ID,
404 Name: c.Name,
405 Input: c.Input,
406 Type: c.Type,
407 Finished: true,
408 }
409 return
410 }
411 }
412 }
413}
414
415// AppendToolCallInput appends input to a tool call.
416func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
417 for i, part := range m.Parts {
418 if c, ok := part.(ToolCall); ok {
419 if c.ID == toolCallID {
420 m.Parts[i] = ToolCall{
421 ID: c.ID,
422 Name: c.Name,
423 Input: c.Input + inputDelta,
424 Type: c.Type,
425 Finished: c.Finished,
426 }
427 return
428 }
429 }
430 }
431}
432
433// AddToolCall adds or updates a tool call.
434func (m *Message) AddToolCall(tc ToolCall) {
435 for i, part := range m.Parts {
436 if c, ok := part.(ToolCall); ok {
437 if c.ID == tc.ID {
438 m.Parts[i] = tc
439 return
440 }
441 }
442 }
443 m.Parts = append(m.Parts, tc)
444}
445
446// SetToolCalls replaces all tool call parts.
447func (m *Message) SetToolCalls(tc []ToolCall) {
448 parts := make([]ContentPart, 0)
449 for _, part := range m.Parts {
450 if _, ok := part.(ToolCall); ok {
451 continue
452 }
453 parts = append(parts, part)
454 }
455 m.Parts = parts
456 for _, toolCall := range tc {
457 m.Parts = append(m.Parts, toolCall)
458 }
459}
460
461// AddToolResult adds a tool result.
462func (m *Message) AddToolResult(tr ToolResult) {
463 m.Parts = append(m.Parts, tr)
464}
465
466// SetToolResults adds multiple tool results.
467func (m *Message) SetToolResults(tr []ToolResult) {
468 for _, toolResult := range tr {
469 m.Parts = append(m.Parts, toolResult)
470 }
471}
472
473// AddFinish adds a finish part to the message.
474func (m *Message) AddFinish(reason FinishReason, message, details string) {
475 for i, part := range m.Parts {
476 if _, ok := part.(Finish); ok {
477 m.Parts = slices.Delete(m.Parts, i, i+1)
478 break
479 }
480 }
481 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
482}
483
484// AddImageURL adds an image URL part to the message.
485func (m *Message) AddImageURL(url, detail string) {
486 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
487}
488
489// AddBinary adds a binary content part to the message.
490func (m *Message) AddBinary(mimeType string, data []byte) {
491 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
492}
493
// partType is the discriminator written to the "type" field of each
// serialized content part by MarshalParts and read back by UnmarshalParts.
type partType string

// Discriminator values, one per concrete ContentPart type.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
505
506type partWrapper struct {
507 Type partType `json:"type"`
508 Data ContentPart `json:"data"`
509}
510
511// MarshalParts marshals content parts to JSON.
512func MarshalParts(parts []ContentPart) ([]byte, error) {
513 wrappedParts := make([]partWrapper, len(parts))
514
515 for i, part := range parts {
516 var typ partType
517
518 switch part.(type) {
519 case ReasoningContent:
520 typ = reasoningType
521 case TextContent:
522 typ = textType
523 case ImageURLContent:
524 typ = imageURLType
525 case BinaryContent:
526 typ = binaryType
527 case ToolCall:
528 typ = toolCallType
529 case ToolResult:
530 typ = toolResultType
531 case Finish:
532 typ = finishType
533 default:
534 return nil, fmt.Errorf("unknown part type: %T", part)
535 }
536
537 wrappedParts[i] = partWrapper{
538 Type: typ,
539 Data: part,
540 }
541 }
542 return json.Marshal(wrappedParts)
543}
544
545// UnmarshalParts unmarshals content parts from JSON.
546func UnmarshalParts(data []byte) ([]ContentPart, error) {
547 temp := []json.RawMessage{}
548
549 if err := json.Unmarshal(data, &temp); err != nil {
550 return nil, err
551 }
552
553 parts := make([]ContentPart, 0)
554
555 for _, rawPart := range temp {
556 var wrapper struct {
557 Type partType `json:"type"`
558 Data json.RawMessage `json:"data"`
559 }
560
561 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
562 return nil, err
563 }
564
565 switch wrapper.Type {
566 case reasoningType:
567 part := ReasoningContent{}
568 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
569 return nil, err
570 }
571 parts = append(parts, part)
572 case textType:
573 part := TextContent{}
574 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
575 return nil, err
576 }
577 parts = append(parts, part)
578 case imageURLType:
579 part := ImageURLContent{}
580 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
581 return nil, err
582 }
583 parts = append(parts, part)
584 case binaryType:
585 part := BinaryContent{}
586 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
587 return nil, err
588 }
589 parts = append(parts, part)
590 case toolCallType:
591 part := ToolCall{}
592 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
593 return nil, err
594 }
595 parts = append(parts, part)
596 case toolResultType:
597 part := ToolResult{}
598 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
599 return nil, err
600 }
601 parts = append(parts, part)
602 case finishType:
603 part := Finish{}
604 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
605 return nil, err
606 }
607 parts = append(parts, part)
608 default:
609 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
610 }
611 }
612
613 return parts, nil
614}
615
// Attachment is a file attached to a message. In JSON, Content is carried as
// a standard base64 string under the "content" key (see MarshalJSON and
// UnmarshalJSON below).
type Attachment struct {
	FilePath string `json:"file_path"`
	FileName string `json:"file_name"`
	MimeType string `json:"mime_type"`
	Content  []byte `json:"content"`
}

// MarshalJSON implements the [json.Marshaler] interface. Content is always
// written as a base64 string, so a nil Content encodes as "" rather than null.
func (a Attachment) MarshalJSON() ([]byte, error) {
	// Alias has none of Attachment's methods, so json.Marshal does not
	// recurse into this method; the outer Content shadows Alias.Content.
	type Alias Attachment
	return json.Marshal(&struct {
		Content string `json:"content"`
		*Alias
	}{
		Content: base64.StdEncoding.EncodeToString(a.Content),
		Alias:   (*Alias)(&a),
	})
}

// UnmarshalJSON implements the [json.Unmarshaler] interface. It decodes the
// base64 "content" string into Content and fills the other fields through
// their struct tags. An absent or null "content" (or a top-level JSON null)
// sets Content to an empty slice. Invalid base64 is returned as an error.
func (a *Attachment) UnmarshalJSON(data []byte) error {
	type Alias Attachment
	aux := &struct {
		Content string `json:"content"`
		*Alias
	}{
		Alias: (*Alias)(a),
	}
	// Decode into aux itself, not &aux. With &aux, a JSON null (e.g. a null
	// array element or struct field of type Attachment) sets aux to nil and
	// the aux.Content read below panics.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}
	content, err := base64.StdEncoding.DecodeString(aux.Content)
	if err != nil {
		return err
	}
	a.Content = content
	return nil
}