1package message
2
3import (
4 "encoding/base64"
5 "encoding/json"
6 "slices"
7 "time"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10)
11
// MessageRole is the Role of a Message.
type MessageRole string

// Known message roles.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)

// MarshalText implements [encoding.TextMarshaler]; the role is emitted as its
// underlying string, unchanged.
func (r MessageRole) MarshalText() ([]byte, error) {
	text := string(r)
	return []byte(text), nil
}

// UnmarshalText implements [encoding.TextUnmarshaler]. The text is stored
// verbatim: values other than the constants above are accepted, not rejected.
func (r *MessageRole) UnmarshalText(data []byte) error {
	role := MessageRole(string(data))
	*r = role
	return nil
}
29
// FinishReason is the reason recorded in a Finish part.
type FinishReason string

// Known finish reasons.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// FinishReasonUnknown is a fallback value; it is not expected to occur
	// in practice.
	FinishReasonUnknown FinishReason = "unknown"
)

// MarshalText implements [encoding.TextMarshaler] by emitting the reason's
// string value unchanged.
func (fr FinishReason) MarshalText() ([]byte, error) {
	raw := string(fr)
	return []byte(raw), nil
}

// UnmarshalText implements [encoding.TextUnmarshaler]. The text is stored
// verbatim; values outside the constants above are accepted, not rejected.
func (fr *FinishReason) UnmarshalText(data []byte) error {
	decoded := FinishReason(string(data))
	*fr = decoded
	return nil
}
52
// ContentPart is one piece of a Message's content (text, reasoning, images,
// binary data, tool calls/results, or a finish marker). The unexported isPart
// method means types outside this package cannot implement it directly.
type ContentPart interface {
	isPart()
}
56
// ReasoningContent is a ContentPart holding the model's reasoning text.
//
// StartedAt and FinishedAt are Unix timestamps in seconds (the Message
// helpers fill them from time.Now().Unix()); zero means "not recorded".
// Signature is accumulated separately from Thinking.
type ReasoningContent struct {
	Thinking   string `json:"thinking"`
	Signature  string `json:"signature"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`
}

// String returns the reasoning text.
func (rc ReasoningContent) String() string {
	thinking := rc.Thinking
	return thinking
}

func (ReasoningContent) isPart() {}
68
// TextContent is a ContentPart carrying plain message text.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text unchanged.
func (t TextContent) String() string {
	text := t.Text
	return text
}

func (TextContent) isPart() {}
78
// ImageURLContent is a ContentPart referencing an image by URL. Detail is an
// optional string passed along with the URL and is omitted from JSON when
// empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL.
func (img ImageURLContent) String() string {
	url := img.URL
	return url
}

func (ImageURLContent) isPart() {}
89
90type BinaryContent struct {
91 Path string
92 MIMEType string
93 Data []byte
94}
95
96func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
97 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
98 if p == catwalk.InferenceProviderOpenAI {
99 return "data:" + bc.MIMEType + ";base64," + base64Encoded
100 }
101 return base64Encoded
102}
103
104func (BinaryContent) isPart() {}
105
// ToolCall is a ContentPart describing a tool invocation.
//
// ID is the key that FinishToolCall, AppendToolCallInput and AddToolCall use
// to find a call. Input is text that AppendToolCallInput extends incrementally
// (NOTE(review): presumably JSON-encoded arguments — confirm). The meaning of
// Type is not visible in this file — confirm with producers. Finished is set
// to true by FinishToolCall.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
115
// ToolResult is a ContentPart carrying the output of a tool call.
//
// ToolCallID presumably matches the ID of the originating ToolCall (confirm
// against callers). IsError presumably marks a failed call. Content and
// Metadata are opaque strings here; NOTE(review): Metadata's encoding is not
// established in this file.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
125
// Finish is a ContentPart recording that a message ended and why.
//
// AddFinish sets Time to the Unix time in seconds at which it is called.
// Message and Details are optional free-form text, omitted from JSON when
// empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

func (Finish) isPart() {}
134
// Message is a single entry in a session: its Role, the SessionID it belongs
// to, and its ordered content Parts.
//
// Parts holds ContentPart interface values; Message's MarshalJSON and
// UnmarshalJSON encode and decode them via marshallParts/unmarshallParts.
// Model and Provider are free-form strings (presumably the model/provider that
// produced the message — confirm). NOTE(review): the units of CreatedAt and
// UpdatedAt are not established in this file (other timestamps here are Unix
// seconds) — confirm.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
145
146// MarshalJSON implements the [json.Marshaler] interface.
147func (m Message) MarshalJSON() ([]byte, error) {
148 // We need to handle the Parts specially since they're ContentPart interfaces
149 // which can't be directly marshaled by the standard JSON package.
150 parts, err := marshallParts(m.Parts)
151 if err != nil {
152 return nil, err
153 }
154
155 // Create an alias to avoid infinite recursion
156 type Alias Message
157 return json.Marshal(&struct {
158 Parts json.RawMessage `json:"parts"`
159 *Alias
160 }{
161 Parts: json.RawMessage(parts),
162 Alias: (*Alias)(&m),
163 })
164}
165
166// UnmarshalJSON implements the [json.Unmarshaler] interface.
167func (m *Message) UnmarshalJSON(data []byte) error {
168 // Create an alias to avoid infinite recursion
169 type Alias Message
170 aux := &struct {
171 Parts json.RawMessage `json:"parts"`
172 *Alias
173 }{
174 Alias: (*Alias)(m),
175 }
176
177 if err := json.Unmarshal(data, &aux); err != nil {
178 return err
179 }
180
181 // Unmarshal the parts using our custom function
182 parts, err := unmarshallParts([]byte(aux.Parts))
183 if err != nil {
184 return err
185 }
186
187 m.Parts = parts
188 return nil
189}
190
191func (m *Message) Content() TextContent {
192 for _, part := range m.Parts {
193 if c, ok := part.(TextContent); ok {
194 return c
195 }
196 }
197 return TextContent{}
198}
199
200func (m *Message) ReasoningContent() ReasoningContent {
201 for _, part := range m.Parts {
202 if c, ok := part.(ReasoningContent); ok {
203 return c
204 }
205 }
206 return ReasoningContent{}
207}
208
209func (m *Message) ImageURLContent() []ImageURLContent {
210 imageURLContents := make([]ImageURLContent, 0)
211 for _, part := range m.Parts {
212 if c, ok := part.(ImageURLContent); ok {
213 imageURLContents = append(imageURLContents, c)
214 }
215 }
216 return imageURLContents
217}
218
219func (m *Message) BinaryContent() []BinaryContent {
220 binaryContents := make([]BinaryContent, 0)
221 for _, part := range m.Parts {
222 if c, ok := part.(BinaryContent); ok {
223 binaryContents = append(binaryContents, c)
224 }
225 }
226 return binaryContents
227}
228
229func (m *Message) ToolCalls() []ToolCall {
230 toolCalls := make([]ToolCall, 0)
231 for _, part := range m.Parts {
232 if c, ok := part.(ToolCall); ok {
233 toolCalls = append(toolCalls, c)
234 }
235 }
236 return toolCalls
237}
238
239func (m *Message) ToolResults() []ToolResult {
240 toolResults := make([]ToolResult, 0)
241 for _, part := range m.Parts {
242 if c, ok := part.(ToolResult); ok {
243 toolResults = append(toolResults, c)
244 }
245 }
246 return toolResults
247}
248
249func (m *Message) IsFinished() bool {
250 for _, part := range m.Parts {
251 if _, ok := part.(Finish); ok {
252 return true
253 }
254 }
255 return false
256}
257
258func (m *Message) FinishPart() *Finish {
259 for _, part := range m.Parts {
260 if c, ok := part.(Finish); ok {
261 return &c
262 }
263 }
264 return nil
265}
266
267func (m *Message) FinishReason() FinishReason {
268 for _, part := range m.Parts {
269 if c, ok := part.(Finish); ok {
270 return c.Reason
271 }
272 }
273 return ""
274}
275
276func (m *Message) IsThinking() bool {
277 if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
278 return true
279 }
280 return false
281}
282
283func (m *Message) AppendContent(delta string) {
284 found := false
285 for i, part := range m.Parts {
286 if c, ok := part.(TextContent); ok {
287 m.Parts[i] = TextContent{Text: c.Text + delta}
288 found = true
289 }
290 }
291 if !found {
292 m.Parts = append(m.Parts, TextContent{Text: delta})
293 }
294}
295
296func (m *Message) AppendReasoningContent(delta string) {
297 found := false
298 for i, part := range m.Parts {
299 if c, ok := part.(ReasoningContent); ok {
300 m.Parts[i] = ReasoningContent{
301 Thinking: c.Thinking + delta,
302 Signature: c.Signature,
303 StartedAt: c.StartedAt,
304 FinishedAt: c.FinishedAt,
305 }
306 found = true
307 }
308 }
309 if !found {
310 m.Parts = append(m.Parts, ReasoningContent{
311 Thinking: delta,
312 StartedAt: time.Now().Unix(),
313 })
314 }
315}
316
317func (m *Message) AppendReasoningSignature(signature string) {
318 for i, part := range m.Parts {
319 if c, ok := part.(ReasoningContent); ok {
320 m.Parts[i] = ReasoningContent{
321 Thinking: c.Thinking,
322 Signature: c.Signature + signature,
323 StartedAt: c.StartedAt,
324 FinishedAt: c.FinishedAt,
325 }
326 return
327 }
328 }
329 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
330}
331
332func (m *Message) FinishThinking() {
333 for i, part := range m.Parts {
334 if c, ok := part.(ReasoningContent); ok {
335 if c.FinishedAt == 0 {
336 m.Parts[i] = ReasoningContent{
337 Thinking: c.Thinking,
338 Signature: c.Signature,
339 StartedAt: c.StartedAt,
340 FinishedAt: time.Now().Unix(),
341 }
342 }
343 return
344 }
345 }
346}
347
348func (m *Message) ThinkingDuration() time.Duration {
349 reasoning := m.ReasoningContent()
350 if reasoning.StartedAt == 0 {
351 return 0
352 }
353
354 endTime := reasoning.FinishedAt
355 if endTime == 0 {
356 endTime = time.Now().Unix()
357 }
358
359 return time.Duration(endTime-reasoning.StartedAt) * time.Second
360}
361
362func (m *Message) FinishToolCall(toolCallID string) {
363 for i, part := range m.Parts {
364 if c, ok := part.(ToolCall); ok {
365 if c.ID == toolCallID {
366 m.Parts[i] = ToolCall{
367 ID: c.ID,
368 Name: c.Name,
369 Input: c.Input,
370 Type: c.Type,
371 Finished: true,
372 }
373 return
374 }
375 }
376 }
377}
378
379func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
380 for i, part := range m.Parts {
381 if c, ok := part.(ToolCall); ok {
382 if c.ID == toolCallID {
383 m.Parts[i] = ToolCall{
384 ID: c.ID,
385 Name: c.Name,
386 Input: c.Input + inputDelta,
387 Type: c.Type,
388 Finished: c.Finished,
389 }
390 return
391 }
392 }
393 }
394}
395
396func (m *Message) AddToolCall(tc ToolCall) {
397 for i, part := range m.Parts {
398 if c, ok := part.(ToolCall); ok {
399 if c.ID == tc.ID {
400 m.Parts[i] = tc
401 return
402 }
403 }
404 }
405 m.Parts = append(m.Parts, tc)
406}
407
408func (m *Message) SetToolCalls(tc []ToolCall) {
409 // remove any existing tool call part it could have multiple
410 parts := make([]ContentPart, 0)
411 for _, part := range m.Parts {
412 if _, ok := part.(ToolCall); ok {
413 continue
414 }
415 parts = append(parts, part)
416 }
417 m.Parts = parts
418 for _, toolCall := range tc {
419 m.Parts = append(m.Parts, toolCall)
420 }
421}
422
423func (m *Message) AddToolResult(tr ToolResult) {
424 m.Parts = append(m.Parts, tr)
425}
426
427func (m *Message) SetToolResults(tr []ToolResult) {
428 for _, toolResult := range tr {
429 m.Parts = append(m.Parts, toolResult)
430 }
431}
432
433func (m *Message) AddFinish(reason FinishReason, message, details string) {
434 // remove any existing finish part
435 for i, part := range m.Parts {
436 if _, ok := part.(Finish); ok {
437 m.Parts = slices.Delete(m.Parts, i, i+1)
438 break
439 }
440 }
441 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
442}
443
444func (m *Message) AddImageURL(url, detail string) {
445 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
446}
447
448func (m *Message) AddBinary(mimeType string, data []byte) {
449 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
450}