1package proto
2
3import (
4 "encoding/base64"
5 "encoding/json"
6 "fmt"
7 "slices"
8 "time"
9
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11)
12
13type CreateMessageParams struct {
14 Role MessageRole `json:"role"`
15 Parts []ContentPart `json:"parts"`
16 Model string `json:"model"`
17 Provider string `json:"provider,omitempty"`
18}
19
20type Message struct {
21 ID string `json:"id"`
22 Role MessageRole `json:"role"`
23 SessionID string `json:"session_id"`
24 Parts []ContentPart `json:"parts"`
25 Model string `json:"model"`
26 Provider string `json:"provider"`
27 CreatedAt int64 `json:"created_at"`
28 UpdatedAt int64 `json:"updated_at"`
29}
// MessageRole names the role attached to a message.
type MessageRole string

// Role values defined by this package.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)

// MarshalText returns the role's string value as bytes. It never fails.
func (r MessageRole) MarshalText() ([]byte, error) {
	text := []byte(r)
	return text, nil
}

// UnmarshalText stores data verbatim as the role. No validation against the
// declared role constants is performed, and the returned error is always nil.
func (r *MessageRole) UnmarshalText(data []byte) error {
	role := MessageRole(data)
	*r = role
	return nil
}
47
// FinishReason is the string recorded in a Finish part to say why a message
// ended.
type FinishReason string

// Finish reasons defined by this package.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// Should never happen
	FinishReasonUnknown FinishReason = "unknown"
)

// MarshalText returns the reason's string value as bytes. It never fails.
func (fr FinishReason) MarshalText() ([]byte, error) {
	text := []byte(fr)
	return text, nil
}

// UnmarshalText stores data verbatim as the reason. No validation against
// the declared constants is performed, and the returned error is always nil.
func (fr *FinishReason) UnmarshalText(data []byte) error {
	reason := FinishReason(data)
	*fr = reason
	return nil
}
70
// ContentPart is implemented by every kind of content a Message can hold.
// Its only method is unexported, so types outside this package cannot
// implement it directly.
type ContentPart interface {
	// isPart is a marker method with no behavior.
	isPart()
}
74
// ReasoningContent is a content part holding reasoning ("thinking") text and
// its associated signature.
type ReasoningContent struct {
	// Thinking is the reasoning text.
	Thinking string `json:"thinking"`
	// Signature is serialized under the "signature" key.
	Signature string `json:"signature"`
	// StartedAt is a Unix-seconds timestamp as written by
	// Message.AppendReasoningContent; omitted from JSON when zero.
	StartedAt int64 `json:"started_at,omitempty"`
	// FinishedAt is a Unix-seconds timestamp as written by
	// Message.FinishThinking; omitted from JSON when zero.
	FinishedAt int64 `json:"finished_at,omitempty"`
}

// String returns the reasoning text.
func (rc ReasoningContent) String() string {
	text := rc.Thinking
	return text
}

// isPart marks ReasoningContent as a ContentPart.
func (ReasoningContent) isPart() {}
86
// TextContent is a content part holding plain text.
type TextContent struct {
	// Text is serialized under the "text" key.
	Text string `json:"text"`
}

// String returns the text of the part.
func (tc TextContent) String() string {
	text := tc.Text
	return text
}

// isPart marks TextContent as a ContentPart.
func (TextContent) isPart() {}
96
// ImageURLContent is a content part that references an image by URL.
type ImageURLContent struct {
	// URL is serialized under the "url" key.
	URL string `json:"url"`
	// Detail is serialized under "detail" and left out when empty.
	Detail string `json:"detail,omitempty"`
}

// String returns the URL of the part.
func (iuc ImageURLContent) String() string {
	link := iuc.URL
	return link
}

// isPart marks ImageURLContent as a ContentPart.
func (ImageURLContent) isPart() {}
107
108type BinaryContent struct {
109 Path string
110 MIMEType string
111 Data []byte
112}
113
114func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
115 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
116 if p == catwalk.InferenceProviderOpenAI {
117 return "data:" + bc.MIMEType + ";base64," + base64Encoded
118 }
119 return base64Encoded
120}
121
122func (BinaryContent) isPart() {}
123
// ToolCall is a content part describing one tool invocation.
type ToolCall struct {
	// ID identifies the call; Message methods match tool calls on it.
	ID string `json:"id"`
	// Name is serialized under the "name" key.
	Name string `json:"name"`
	// Input is the call input as a string; Message.AppendToolCallInput
	// builds it up by concatenating deltas.
	Input string `json:"input"`
	// Type is serialized under the "type" key.
	Type string `json:"type"`
	// Finished is set to true by Message.FinishToolCall.
	Finished bool `json:"finished"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
133
// ToolResult is a content part carrying the outcome of a tool call.
type ToolResult struct {
	// ToolCallID is serialized under the "tool_call_id" key; presumably it
	// matches ToolCall.ID (not enforced in this file).
	ToolCallID string `json:"tool_call_id"`
	// Name is serialized under the "name" key.
	Name string `json:"name"`
	// Content is serialized under the "content" key.
	Content string `json:"content"`
	// Metadata is serialized under the "metadata" key.
	Metadata string `json:"metadata"`
	// IsError is serialized under the "is_error" key.
	IsError bool `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
143
144type Finish struct {
145 Reason FinishReason `json:"reason"`
146 Time int64 `json:"time"`
147 Message string `json:"message,omitempty"`
148 Details string `json:"details,omitempty"`
149}
150
151func (Finish) isPart() {}
152
153// MarshalJSON implements the [json.Marshaler] interface.
154func (m Message) MarshalJSON() ([]byte, error) {
155 // We need to handle the Parts specially since they're ContentPart interfaces
156 // which can't be directly marshaled by the standard JSON package.
157 parts, err := MarshallParts(m.Parts)
158 if err != nil {
159 return nil, err
160 }
161
162 // Create an alias to avoid infinite recursion
163 type Alias Message
164 return json.Marshal(&struct {
165 Parts json.RawMessage `json:"parts"`
166 *Alias
167 }{
168 Parts: json.RawMessage(parts),
169 Alias: (*Alias)(&m),
170 })
171}
172
173// UnmarshalJSON implements the [json.Unmarshaler] interface.
174func (m *Message) UnmarshalJSON(data []byte) error {
175 // Create an alias to avoid infinite recursion
176 type Alias Message
177 aux := &struct {
178 Parts json.RawMessage `json:"parts"`
179 *Alias
180 }{
181 Alias: (*Alias)(m),
182 }
183
184 if err := json.Unmarshal(data, &aux); err != nil {
185 return err
186 }
187
188 // Unmarshal the parts using our custom function
189 parts, err := UnmarshallParts([]byte(aux.Parts))
190 if err != nil {
191 return err
192 }
193
194 m.Parts = parts
195 return nil
196}
197
198func (m *Message) Content() TextContent {
199 for _, part := range m.Parts {
200 if c, ok := part.(TextContent); ok {
201 return c
202 }
203 }
204 return TextContent{}
205}
206
207func (m *Message) ReasoningContent() ReasoningContent {
208 for _, part := range m.Parts {
209 if c, ok := part.(ReasoningContent); ok {
210 return c
211 }
212 }
213 return ReasoningContent{}
214}
215
216func (m *Message) ImageURLContent() []ImageURLContent {
217 imageURLContents := make([]ImageURLContent, 0)
218 for _, part := range m.Parts {
219 if c, ok := part.(ImageURLContent); ok {
220 imageURLContents = append(imageURLContents, c)
221 }
222 }
223 return imageURLContents
224}
225
226func (m *Message) BinaryContent() []BinaryContent {
227 binaryContents := make([]BinaryContent, 0)
228 for _, part := range m.Parts {
229 if c, ok := part.(BinaryContent); ok {
230 binaryContents = append(binaryContents, c)
231 }
232 }
233 return binaryContents
234}
235
236func (m *Message) ToolCalls() []ToolCall {
237 toolCalls := make([]ToolCall, 0)
238 for _, part := range m.Parts {
239 if c, ok := part.(ToolCall); ok {
240 toolCalls = append(toolCalls, c)
241 }
242 }
243 return toolCalls
244}
245
246func (m *Message) ToolResults() []ToolResult {
247 toolResults := make([]ToolResult, 0)
248 for _, part := range m.Parts {
249 if c, ok := part.(ToolResult); ok {
250 toolResults = append(toolResults, c)
251 }
252 }
253 return toolResults
254}
255
256func (m *Message) IsFinished() bool {
257 for _, part := range m.Parts {
258 if _, ok := part.(Finish); ok {
259 return true
260 }
261 }
262 return false
263}
264
265func (m *Message) FinishPart() *Finish {
266 for _, part := range m.Parts {
267 if c, ok := part.(Finish); ok {
268 return &c
269 }
270 }
271 return nil
272}
273
274func (m *Message) FinishReason() FinishReason {
275 for _, part := range m.Parts {
276 if c, ok := part.(Finish); ok {
277 return c.Reason
278 }
279 }
280 return ""
281}
282
283func (m *Message) IsThinking() bool {
284 if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
285 return true
286 }
287 return false
288}
289
290func (m *Message) AppendContent(delta string) {
291 found := false
292 for i, part := range m.Parts {
293 if c, ok := part.(TextContent); ok {
294 m.Parts[i] = TextContent{Text: c.Text + delta}
295 found = true
296 }
297 }
298 if !found {
299 m.Parts = append(m.Parts, TextContent{Text: delta})
300 }
301}
302
303func (m *Message) AppendReasoningContent(delta string) {
304 found := false
305 for i, part := range m.Parts {
306 if c, ok := part.(ReasoningContent); ok {
307 m.Parts[i] = ReasoningContent{
308 Thinking: c.Thinking + delta,
309 Signature: c.Signature,
310 StartedAt: c.StartedAt,
311 FinishedAt: c.FinishedAt,
312 }
313 found = true
314 }
315 }
316 if !found {
317 m.Parts = append(m.Parts, ReasoningContent{
318 Thinking: delta,
319 StartedAt: time.Now().Unix(),
320 })
321 }
322}
323
324func (m *Message) AppendReasoningSignature(signature string) {
325 for i, part := range m.Parts {
326 if c, ok := part.(ReasoningContent); ok {
327 m.Parts[i] = ReasoningContent{
328 Thinking: c.Thinking,
329 Signature: c.Signature + signature,
330 StartedAt: c.StartedAt,
331 FinishedAt: c.FinishedAt,
332 }
333 return
334 }
335 }
336 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
337}
338
339func (m *Message) FinishThinking() {
340 for i, part := range m.Parts {
341 if c, ok := part.(ReasoningContent); ok {
342 if c.FinishedAt == 0 {
343 m.Parts[i] = ReasoningContent{
344 Thinking: c.Thinking,
345 Signature: c.Signature,
346 StartedAt: c.StartedAt,
347 FinishedAt: time.Now().Unix(),
348 }
349 }
350 return
351 }
352 }
353}
354
355func (m *Message) ThinkingDuration() time.Duration {
356 reasoning := m.ReasoningContent()
357 if reasoning.StartedAt == 0 {
358 return 0
359 }
360
361 endTime := reasoning.FinishedAt
362 if endTime == 0 {
363 endTime = time.Now().Unix()
364 }
365
366 return time.Duration(endTime-reasoning.StartedAt) * time.Second
367}
368
369func (m *Message) FinishToolCall(toolCallID string) {
370 for i, part := range m.Parts {
371 if c, ok := part.(ToolCall); ok {
372 if c.ID == toolCallID {
373 m.Parts[i] = ToolCall{
374 ID: c.ID,
375 Name: c.Name,
376 Input: c.Input,
377 Type: c.Type,
378 Finished: true,
379 }
380 return
381 }
382 }
383 }
384}
385
386func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
387 for i, part := range m.Parts {
388 if c, ok := part.(ToolCall); ok {
389 if c.ID == toolCallID {
390 m.Parts[i] = ToolCall{
391 ID: c.ID,
392 Name: c.Name,
393 Input: c.Input + inputDelta,
394 Type: c.Type,
395 Finished: c.Finished,
396 }
397 return
398 }
399 }
400 }
401}
402
403func (m *Message) AddToolCall(tc ToolCall) {
404 for i, part := range m.Parts {
405 if c, ok := part.(ToolCall); ok {
406 if c.ID == tc.ID {
407 m.Parts[i] = tc
408 return
409 }
410 }
411 }
412 m.Parts = append(m.Parts, tc)
413}
414
415func (m *Message) SetToolCalls(tc []ToolCall) {
416 // remove any existing tool call part it could have multiple
417 parts := make([]ContentPart, 0)
418 for _, part := range m.Parts {
419 if _, ok := part.(ToolCall); ok {
420 continue
421 }
422 parts = append(parts, part)
423 }
424 m.Parts = parts
425 for _, toolCall := range tc {
426 m.Parts = append(m.Parts, toolCall)
427 }
428}
429
430func (m *Message) AddToolResult(tr ToolResult) {
431 m.Parts = append(m.Parts, tr)
432}
433
434func (m *Message) SetToolResults(tr []ToolResult) {
435 for _, toolResult := range tr {
436 m.Parts = append(m.Parts, toolResult)
437 }
438}
439
440func (m *Message) AddFinish(reason FinishReason, message, details string) {
441 // remove any existing finish part
442 for i, part := range m.Parts {
443 if _, ok := part.(Finish); ok {
444 m.Parts = slices.Delete(m.Parts, i, i+1)
445 break
446 }
447 }
448 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
449}
450
451func (m *Message) AddImageURL(url, detail string) {
452 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
453}
454
455func (m *Message) AddBinary(mimeType string, data []byte) {
456 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
457}
458
459type partType string
460
461const (
462 reasoningType partType = "reasoning"
463 textType partType = "text"
464 imageURLType partType = "image_url"
465 binaryType partType = "binary"
466 toolCallType partType = "tool_call"
467 toolResultType partType = "tool_result"
468 finishType partType = "finish"
469)
470
471type partWrapper struct {
472 Type partType `json:"type"`
473 Data ContentPart `json:"data"`
474}
475
476func MarshallParts(parts []ContentPart) ([]byte, error) {
477 wrappedParts := make([]partWrapper, len(parts))
478
479 for i, part := range parts {
480 var typ partType
481
482 switch part.(type) {
483 case ReasoningContent:
484 typ = reasoningType
485 case TextContent:
486 typ = textType
487 case ImageURLContent:
488 typ = imageURLType
489 case BinaryContent:
490 typ = binaryType
491 case ToolCall:
492 typ = toolCallType
493 case ToolResult:
494 typ = toolResultType
495 case Finish:
496 typ = finishType
497 default:
498 return nil, fmt.Errorf("unknown part type: %T", part)
499 }
500
501 wrappedParts[i] = partWrapper{
502 Type: typ,
503 Data: part,
504 }
505 }
506 return json.Marshal(wrappedParts)
507}
508
509func UnmarshallParts(data []byte) ([]ContentPart, error) {
510 temp := []json.RawMessage{}
511
512 if err := json.Unmarshal(data, &temp); err != nil {
513 return nil, err
514 }
515
516 parts := make([]ContentPart, 0)
517
518 for _, rawPart := range temp {
519 var wrapper struct {
520 Type partType `json:"type"`
521 Data json.RawMessage `json:"data"`
522 }
523
524 if err := json.Unmarshal(rawPart, &wrapper); err != nil {
525 return nil, err
526 }
527
528 switch wrapper.Type {
529 case reasoningType:
530 part := ReasoningContent{}
531 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
532 return nil, err
533 }
534 parts = append(parts, part)
535 case textType:
536 part := TextContent{}
537 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
538 return nil, err
539 }
540 parts = append(parts, part)
541 case imageURLType:
542 part := ImageURLContent{}
543 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
544 return nil, err
545 }
546 case binaryType:
547 part := BinaryContent{}
548 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
549 return nil, err
550 }
551 parts = append(parts, part)
552 case toolCallType:
553 part := ToolCall{}
554 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
555 return nil, err
556 }
557 parts = append(parts, part)
558 case toolResultType:
559 part := ToolResult{}
560 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
561 return nil, err
562 }
563 parts = append(parts, part)
564 case finishType:
565 part := Finish{}
566 if err := json.Unmarshal(wrapper.Data, &part); err != nil {
567 return nil, err
568 }
569 parts = append(parts, part)
570 default:
571 return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
572 }
573 }
574
575 return parts, nil
576}
577
// Attachment is a file payload with its path, name, MIME type and raw
// bytes. In JSON the bytes travel as a standard base64 string under the
// "content" key.
type Attachment struct {
	// FilePath is serialized under the "file_path" key.
	FilePath string `json:"file_path"`
	// FileName is serialized under the "file_name" key.
	FileName string `json:"file_name"`
	// MimeType is serialized under the "mime_type" key.
	MimeType string `json:"mime_type"`
	// Content is the raw file data.
	Content []byte `json:"content"`
}

// MarshalJSON implements the [json.Marshaler] interface.
//
// Content is written as a base64 string (an empty string when Content is
// nil or empty), ahead of the remaining fields.
func (a Attachment) MarshalJSON() ([]byte, error) {
	// Alias has Attachment's fields but none of its methods, so
	// json.Marshal does not call back into this function.
	type Alias Attachment
	return json.Marshal(&struct {
		// Content shadows the embedded Alias.Content because it is less
		// nested.
		Content string `json:"content"`
		*Alias
	}{
		Content: base64.StdEncoding.EncodeToString(a.Content),
		Alias:   (*Alias)(&a),
	})
}

// UnmarshalJSON implements the [json.Unmarshaler] interface.
//
// The "content" string is decoded from standard base64 into Content; a
// missing or empty "content" yields an empty, non-nil Content. A JSON null
// is a no-op and leaves a untouched, following the encoding/json convention
// for Unmarshalers. An error is returned when data is not valid JSON for an
// Attachment or when "content" is not valid base64; the other fields may
// already have been written to a in the latter case.
func (a *Attachment) UnmarshalJSON(data []byte) error {
	// Alias has Attachment's fields but none of its methods, so
	// json.Unmarshal does not call back into this function.
	type Alias Attachment
	aux := &struct {
		Content string `json:"content"`
		*Alias
	}{
		Alias: (*Alias)(a),
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	// Decoding JSON null through &aux resets aux itself to nil; without this
	// guard the aux.Content access below would panic.
	if aux == nil {
		return nil
	}
	content, err := base64.StdEncoding.DecodeString(aux.Content)
	if err != nil {
		return err
	}
	a.Content = content
	return nil
}