1package proto
2
3import (
4 "encoding/base64"
5 "encoding/json"
6 "fmt"
7 "slices"
8 "time"
9
10 "charm.land/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
// CreateMessageParams represents parameters for creating a message.
//
// Role, Parts and Model are always present in the JSON form; Provider is
// omitted when empty.
//
// NOTE(review): unlike Message, this type has no custom MarshalJSON /
// UnmarshalJSON. Parts would therefore be encoded without the type tags
// that MarshalParts adds, and encoding/json cannot decode a non-empty
// "parts" array into the interface-typed []ContentPart. Confirm whether
// this type is ever sent over JSON with parts.
type CreateMessageParams struct {
	Role     MessageRole   `json:"role"`
	Parts    []ContentPart `json:"parts"`
	Model    string        `json:"model"`
	Provider string        `json:"provider,omitempty"`
}
21
// Message represents a message in the proto layer.
//
// Parts holds values of the interface type ContentPart. Message therefore
// has custom MarshalJSON / UnmarshalJSON methods. Each part is wrapped with
// a type tag by MarshalParts and rebuilt from that tag by UnmarshalParts.
// The remaining fields are encoded by their struct tags.
//
// CreatedAt and UpdatedAt are int64 timestamps, presumably Unix seconds;
// confirm with the producer.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
33
// MessageRole identifies who sent a message. The known roles are the
// constants below; the type itself is an open string.
type MessageRole string

const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)

// MarshalText implements the [encoding.TextMarshaler] interface. The role
// is emitted as its underlying string, unchanged.
func (r MessageRole) MarshalText() ([]byte, error) {
	text := []byte(string(r))
	return text, nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. The
// input is stored verbatim and is not checked against the known roles, so
// it never fails.
func (r *MessageRole) UnmarshalText(data []byte) error {
	role := MessageRole(string(data))
	*r = role
	return nil
}
54
// FinishReason records why message generation stopped. The known reasons
// are the constants below; the type itself is an open string.
type FinishReason string

const (
	FinishReasonEndTurn   FinishReason = "end_turn"
	FinishReasonMaxTokens FinishReason = "max_tokens"
	FinishReasonToolUse   FinishReason = "tool_use"
	FinishReasonCanceled  FinishReason = "canceled"
	FinishReasonError     FinishReason = "error"
	FinishReasonUnknown   FinishReason = "unknown"
)

// MarshalText implements the [encoding.TextMarshaler] interface by
// emitting the reason's underlying string.
func (fr FinishReason) MarshalText() ([]byte, error) {
	out := []byte(string(fr))
	return out, nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface. Any
// input is accepted and stored as-is; unknown values are not rejected.
func (fr *FinishReason) UnmarshalText(data []byte) error {
	*fr = FinishReason(string(data))
	return nil
}
77
// ContentPart is a part of a message's content.
//
// The concrete part types in this file are ReasoningContent, TextContent,
// ImageURLContent, BinaryContent, ToolCall, ToolResult and Finish. isPart
// is unexported, so types declared in other packages cannot implement
// ContentPart by defining their own isPart method.
type ContentPart interface {
	isPart()
}
82
// ReasoningContent is the reasoning ("thinking") part of a message.
//
// StartedAt and FinishedAt are int64 timestamps and are omitted from JSON
// when zero.
type ReasoningContent struct {
	Thinking   string `json:"thinking"`
	Signature  string `json:"signature"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`
}

// String yields the reasoning text; the signature and timestamps are not
// included.
func (rc ReasoningContent) String() string {
	thinking := rc.Thinking
	return thinking
}

// isPart marks ReasoningContent as a ContentPart.
func (ReasoningContent) isPart() {}
97
// TextContent is a plain-text part of a message.
type TextContent struct {
	Text string `json:"text"`
}

// String yields the part's text unchanged.
func (t TextContent) String() string {
	text := t.Text
	return text
}

// isPart marks TextContent as a ContentPart.
func (TextContent) isPart() {}
109
// ImageURLContent is a message part that references an image by URL.
// Detail is optional and omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String yields the image URL; Detail is not included.
func (img ImageURLContent) String() string {
	link := img.URL
	return link
}

// isPart marks ImageURLContent as a ContentPart.
func (ImageURLContent) isPart() {}
122
123// BinaryContent represents binary data in a message.
124type BinaryContent struct {
125 Path string
126 MIMEType string
127 Data []byte
128}
129
130// String returns a base64-encoded string of the binary data.
131func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
132 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
133 if p == catwalk.InferenceProviderOpenAI {
134 return "data:" + bc.MIMEType + ";base64," + base64Encoded
135 }
136 return base64Encoded
137}
138
139func (BinaryContent) isPart() {}
140
// ToolCall represents a tool call in a message.
//
// ID is the key used by AddToolCall, FinishToolCall and
// AppendToolCallInput to find an existing call. Input is kept as a single
// string that AppendToolCallInput extends piece by piece. Finished is set
// by FinishToolCall. Type and Finished are omitted from JSON when empty.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type,omitempty"`
	Finished bool   `json:"finished,omitempty"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
151
// ToolResult represents the result of a tool call.
//
// ToolCallID presumably matches the ID of the originating ToolCall; this
// is not enforced here. Data and MIMEType are omitted from JSON when empty
// and presumably carry encoded non-text output; confirm with the producer.
// Metadata is an opaque string whose format is not visible in this file.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Data       string `json:"data,omitempty"`
	MIMEType   string `json:"mime_type,omitempty"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
164
// Finish represents the end of a message generation.
//
// Time is set by Message.AddFinish to time.Now().Unix(), i.e. Unix
// seconds. Message and Details are optional and omitted from JSON when
// empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
174
175// MarshalJSON implements the [json.Marshaler] interface.
176func (m Message) MarshalJSON() ([]byte, error) {
177 parts, err := MarshalParts(m.Parts)
178 if err != nil {
179 return nil, err
180 }
181
182 type Alias Message
183 return json.Marshal(&struct {
184 Parts json.RawMessage `json:"parts"`
185 *Alias
186 }{
187 Parts: json.RawMessage(parts),
188 Alias: (*Alias)(&m),
189 })
190}
191
192// UnmarshalJSON implements the [json.Unmarshaler] interface.
193func (m *Message) UnmarshalJSON(data []byte) error {
194 type Alias Message
195 aux := &struct {
196 Parts json.RawMessage `json:"parts"`
197 *Alias
198 }{
199 Alias: (*Alias)(m),
200 }
201
202 if err := json.Unmarshal(data, &aux); err != nil {
203 return err
204 }
205
206 parts, err := UnmarshalParts([]byte(aux.Parts))
207 if err != nil {
208 return err
209 }
210
211 m.Parts = parts
212 return nil
213}
214
215// Content returns the first text content part.
216func (m *Message) Content() TextContent {
217 for _, part := range m.Parts {
218 if c, ok := part.(TextContent); ok {
219 return c
220 }
221 }
222 return TextContent{}
223}
224
225// ReasoningContent returns the first reasoning content part.
226func (m *Message) ReasoningContent() ReasoningContent {
227 for _, part := range m.Parts {
228 if c, ok := part.(ReasoningContent); ok {
229 return c
230 }
231 }
232 return ReasoningContent{}
233}
234
235// ImageURLContent returns all image URL content parts.
236func (m *Message) ImageURLContent() []ImageURLContent {
237 imageURLContents := make([]ImageURLContent, 0)
238 for _, part := range m.Parts {
239 if c, ok := part.(ImageURLContent); ok {
240 imageURLContents = append(imageURLContents, c)
241 }
242 }
243 return imageURLContents
244}
245
246// BinaryContent returns all binary content parts.
247func (m *Message) BinaryContent() []BinaryContent {
248 binaryContents := make([]BinaryContent, 0)
249 for _, part := range m.Parts {
250 if c, ok := part.(BinaryContent); ok {
251 binaryContents = append(binaryContents, c)
252 }
253 }
254 return binaryContents
255}
256
257// ToolCalls returns all tool call parts.
258func (m *Message) ToolCalls() []ToolCall {
259 toolCalls := make([]ToolCall, 0)
260 for _, part := range m.Parts {
261 if c, ok := part.(ToolCall); ok {
262 toolCalls = append(toolCalls, c)
263 }
264 }
265 return toolCalls
266}
267
268// ToolResults returns all tool result parts.
269func (m *Message) ToolResults() []ToolResult {
270 toolResults := make([]ToolResult, 0)
271 for _, part := range m.Parts {
272 if c, ok := part.(ToolResult); ok {
273 toolResults = append(toolResults, c)
274 }
275 }
276 return toolResults
277}
278
279// IsFinished returns true if the message has a finish part.
280func (m *Message) IsFinished() bool {
281 for _, part := range m.Parts {
282 if _, ok := part.(Finish); ok {
283 return true
284 }
285 }
286 return false
287}
288
289// FinishPart returns the finish part if present.
290func (m *Message) FinishPart() *Finish {
291 for _, part := range m.Parts {
292 if c, ok := part.(Finish); ok {
293 return &c
294 }
295 }
296 return nil
297}
298
299// FinishReason returns the finish reason if present.
300func (m *Message) FinishReason() FinishReason {
301 for _, part := range m.Parts {
302 if c, ok := part.(Finish); ok {
303 return c.Reason
304 }
305 }
306 return ""
307}
308
309// IsThinking returns true if the message is currently in a thinking state.
310func (m *Message) IsThinking() bool {
311 return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
312}
313
314// AppendContent appends text to the text content part.
315func (m *Message) AppendContent(delta string) {
316 found := false
317 for i, part := range m.Parts {
318 if c, ok := part.(TextContent); ok {
319 m.Parts[i] = TextContent{Text: c.Text + delta}
320 found = true
321 }
322 }
323 if !found {
324 m.Parts = append(m.Parts, TextContent{Text: delta})
325 }
326}
327
328// AppendReasoningContent appends text to the reasoning content part.
329func (m *Message) AppendReasoningContent(delta string) {
330 found := false
331 for i, part := range m.Parts {
332 if c, ok := part.(ReasoningContent); ok {
333 m.Parts[i] = ReasoningContent{
334 Thinking: c.Thinking + delta,
335 Signature: c.Signature,
336 StartedAt: c.StartedAt,
337 FinishedAt: c.FinishedAt,
338 }
339 found = true
340 }
341 }
342 if !found {
343 m.Parts = append(m.Parts, ReasoningContent{
344 Thinking: delta,
345 StartedAt: time.Now().Unix(),
346 })
347 }
348}
349
350// AppendReasoningSignature appends a signature to the reasoning content part.
351func (m *Message) AppendReasoningSignature(signature string) {
352 for i, part := range m.Parts {
353 if c, ok := part.(ReasoningContent); ok {
354 m.Parts[i] = ReasoningContent{
355 Thinking: c.Thinking,
356 Signature: c.Signature + signature,
357 StartedAt: c.StartedAt,
358 FinishedAt: c.FinishedAt,
359 }
360 return
361 }
362 }
363 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
364}
365
366// FinishThinking marks the reasoning content as finished.
367func (m *Message) FinishThinking() {
368 for i, part := range m.Parts {
369 if c, ok := part.(ReasoningContent); ok {
370 if c.FinishedAt == 0 {
371 m.Parts[i] = ReasoningContent{
372 Thinking: c.Thinking,
373 Signature: c.Signature,
374 StartedAt: c.StartedAt,
375 FinishedAt: time.Now().Unix(),
376 }
377 }
378 return
379 }
380 }
381}
382
383// ThinkingDuration returns the duration of the thinking phase.
384func (m *Message) ThinkingDuration() time.Duration {
385 reasoning := m.ReasoningContent()
386 if reasoning.StartedAt == 0 {
387 return 0
388 }
389
390 endTime := reasoning.FinishedAt
391 if endTime == 0 {
392 endTime = time.Now().Unix()
393 }
394
395 return time.Duration(endTime-reasoning.StartedAt) * time.Second
396}
397
398// FinishToolCall marks a tool call as finished.
399func (m *Message) FinishToolCall(toolCallID string) {
400 for i, part := range m.Parts {
401 if c, ok := part.(ToolCall); ok {
402 if c.ID == toolCallID {
403 m.Parts[i] = ToolCall{
404 ID: c.ID,
405 Name: c.Name,
406 Input: c.Input,
407 Type: c.Type,
408 Finished: true,
409 }
410 return
411 }
412 }
413 }
414}
415
416// AppendToolCallInput appends input to a tool call.
417func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
418 for i, part := range m.Parts {
419 if c, ok := part.(ToolCall); ok {
420 if c.ID == toolCallID {
421 m.Parts[i] = ToolCall{
422 ID: c.ID,
423 Name: c.Name,
424 Input: c.Input + inputDelta,
425 Type: c.Type,
426 Finished: c.Finished,
427 }
428 return
429 }
430 }
431 }
432}
433
434// AddToolCall adds or updates a tool call.
435func (m *Message) AddToolCall(tc ToolCall) {
436 for i, part := range m.Parts {
437 if c, ok := part.(ToolCall); ok {
438 if c.ID == tc.ID {
439 m.Parts[i] = tc
440 return
441 }
442 }
443 }
444 m.Parts = append(m.Parts, tc)
445}
446
447// SetToolCalls replaces all tool call parts.
448func (m *Message) SetToolCalls(tc []ToolCall) {
449 parts := make([]ContentPart, 0)
450 for _, part := range m.Parts {
451 if _, ok := part.(ToolCall); ok {
452 continue
453 }
454 parts = append(parts, part)
455 }
456 m.Parts = parts
457 for _, toolCall := range tc {
458 m.Parts = append(m.Parts, toolCall)
459 }
460}
461
462// AddToolResult adds a tool result.
463func (m *Message) AddToolResult(tr ToolResult) {
464 m.Parts = append(m.Parts, tr)
465}
466
467// SetToolResults adds multiple tool results.
468func (m *Message) SetToolResults(tr []ToolResult) {
469 for _, toolResult := range tr {
470 m.Parts = append(m.Parts, toolResult)
471 }
472}
473
474// AddFinish adds a finish part to the message.
475func (m *Message) AddFinish(reason FinishReason, message, details string) {
476 for i, part := range m.Parts {
477 if _, ok := part.(Finish); ok {
478 m.Parts = slices.Delete(m.Parts, i, i+1)
479 break
480 }
481 }
482 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
483}
484
485// AddImageURL adds an image URL part to the message.
486func (m *Message) AddImageURL(url, detail string) {
487 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
488}
489
490// AddBinary adds a binary content part to the message.
491func (m *Message) AddBinary(mimeType string, data []byte) {
492 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
493}
494
// partType is the tag that MarshalParts writes into each serialized part's
// "type" field. UnmarshalParts reads it back to pick the concrete
// ContentPart. The string values are part of the JSON format: UnmarshalParts
// rejects unknown tags, so changing a value breaks decoding of previously
// encoded data.
type partType string

const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
506
// partWrapper is the envelope MarshalParts emits for each part. Type holds
// the tag, and Data holds the part itself, encoded according to its
// concrete type.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
511
512// MarshalParts marshals content parts to JSON.
513func MarshalParts(parts []ContentPart) ([]byte, error) {
514 wrappedParts := make([]partWrapper, len(parts))
515
516 for i, part := range parts {
517 var typ partType
518
519 switch part.(type) {
520 case ReasoningContent:
521 typ = reasoningType
522 case TextContent:
523 typ = textType
524 case ImageURLContent:
525 typ = imageURLType
526 case BinaryContent:
527 typ = binaryType
528 case ToolCall:
529 typ = toolCallType
530 case ToolResult:
531 typ = toolResultType
532 case Finish:
533 typ = finishType
534 default:
535 return nil, fmt.Errorf("unknown part type: %T", part)
536 }
537
538 wrappedParts[i] = partWrapper{
539 Type: typ,
540 Data: part,
541 }
542 }
543 return json.Marshal(wrappedParts)
544}
545
546// UnmarshalParts unmarshals content parts from JSON.
547func UnmarshalParts(data []byte) ([]ContentPart, error) {
548 temp := []json.RawMessage{}
549
550 if err := json.Unmarshal(data, &temp); err != nil {
551 return nil, err
552 }
553
554 parts := make([]ContentPart, 0)
555
556 for _, rawPart := range temp {
557 var wrapper struct {
558 Type partType `json:"type"`
559 Data json.RawMessage `json:"data"`
560 }
561
562 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
563 return nil, err
564 }
565
566 switch wrapper.Type {
567 case reasoningType:
568 part := ReasoningContent{}
569 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
570 return nil, err
571 }
572 parts = append(parts, part)
573 case textType:
574 part := TextContent{}
575 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
576 return nil, err
577 }
578 parts = append(parts, part)
579 case imageURLType:
580 part := ImageURLContent{}
581 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
582 return nil, err
583 }
584 parts = append(parts, part)
585 case binaryType:
586 part := BinaryContent{}
587 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
588 return nil, err
589 }
590 parts = append(parts, part)
591 case toolCallType:
592 part := ToolCall{}
593 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
594 return nil, err
595 }
596 parts = append(parts, part)
597 case toolResultType:
598 part := ToolResult{}
599 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
600 return nil, err
601 }
602 parts = append(parts, part)
603 case finishType:
604 part := Finish{}
605 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
606 return nil, err
607 }
608 parts = append(parts, part)
609 default:
610 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
611 }
612 }
613
614 return parts, nil
615}
616
// Attachment represents a file attachment.
//
// Attachment's MarshalJSON and UnmarshalJSON methods encode Content as a
// standard base64 string under "content". The other fields use their
// struct tags.
type Attachment struct {
	FilePath string `json:"file_path"`
	FileName string `json:"file_name"`
	MimeType string `json:"mime_type"`
	Content  []byte `json:"content"`
}
624
625// ToMessage converts a proto Attachment to a [message.Attachment].
626func (a Attachment) ToMessage() message.Attachment {
627 return message.Attachment{
628 FilePath: a.FilePath,
629 FileName: a.FileName,
630 MimeType: a.MimeType,
631 Content: a.Content,
632 }
633}
634
635// AttachmentFromMessage converts a [message.Attachment] to a proto
636// Attachment.
637func AttachmentFromMessage(a message.Attachment) Attachment {
638 return Attachment{
639 FilePath: a.FilePath,
640 FileName: a.FileName,
641 MimeType: a.MimeType,
642 Content: a.Content,
643 }
644}
645
646// AttachmentsToMessage converts a slice of proto Attachments to a slice
647// of [message.Attachment].
648func AttachmentsToMessage(as []Attachment) []message.Attachment {
649 out := make([]message.Attachment, len(as))
650 for i, a := range as {
651 out[i] = a.ToMessage()
652 }
653 return out
654}
655
656// AttachmentsFromMessage converts a slice of [message.Attachment] to a
657// slice of proto Attachments.
658func AttachmentsFromMessage(as []message.Attachment) []Attachment {
659 out := make([]Attachment, len(as))
660 for i, a := range as {
661 out[i] = AttachmentFromMessage(a)
662 }
663 return out
664}
665
666// MarshalJSON implements the [json.Marshaler] interface.
667func (a Attachment) MarshalJSON() ([]byte, error) {
668 type Alias Attachment
669 return json.Marshal(&struct {
670 Content string `json:"content"`
671 *Alias
672 }{
673 Content: base64.StdEncoding.EncodeToString(a.Content),
674 Alias: (*Alias)(&a),
675 })
676}
677
678// UnmarshalJSON implements the [json.Unmarshaler] interface.
679func (a *Attachment) UnmarshalJSON(data []byte) error {
680 type Alias Attachment
681 aux := &struct {
682 Content string `json:"content"`
683 *Alias
684 }{
685 Alias: (*Alias)(a),
686 }
687 if err := json.Unmarshal(data, &aux); err != nil {
688 return err
689 }
690 content, err := base64.StdEncoding.DecodeString(aux.Content)
691 if err != nil {
692 return err
693 }
694 a.Content = content
695 return nil
696}