1package message
2
3import (
4 "encoding/base64"
5 "slices"
6 "time"
7
8 "github.com/charmbracelet/catwalk/pkg/catwalk"
9)
10
// MessageRole identifies who authored a Message.
type MessageRole string

// Known message roles.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)

// MarshalText implements encoding.TextMarshaler by emitting the role's
// underlying string verbatim.
func (mr MessageRole) MarshalText() ([]byte, error) {
	raw := string(mr)
	return []byte(raw), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. The input is stored
// as-is; it is not validated against the known role constants.
func (mr *MessageRole) UnmarshalText(text []byte) error {
	decoded := MessageRole(string(text))
	*mr = decoded
	return nil
}
28
// FinishReason records why a Message finished.
type FinishReason string

// Known finish reasons.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// FinishReasonUnknown is a fallback value that is not expected to occur.
	FinishReasonUnknown FinishReason = "unknown"
)

// MarshalText implements encoding.TextMarshaler, emitting the reason's raw
// string value.
func (reason FinishReason) MarshalText() ([]byte, error) {
	raw := string(reason)
	return []byte(raw), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Any input is accepted
// verbatim, including values outside the known constants.
func (reason *FinishReason) UnmarshalText(text []byte) error {
	*reason = FinishReason(string(text))
	return nil
}
51
// ContentPart is one typed piece of a Message's content (text, reasoning,
// images, binary data, tool calls/results, or a finish marker).
//
// The unexported isPart marker method restricts implementations to the part
// types declared in this package (or types that embed one of them).
type ContentPart interface {
	isPart()
}
55
// ReasoningContent holds a model's reasoning ("thinking") text together with
// its signature and timing. StartedAt and FinishedAt are Unix timestamps in
// seconds; zero means "not recorded".
type ReasoningContent struct {
	Thinking   string `json:"thinking"`
	Signature  string `json:"signature"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`
}

// String returns the reasoning text.
func (rc ReasoningContent) String() string {
	thinking := rc.Thinking
	return thinking
}

func (ReasoningContent) isPart() {}
67
// TextContent is a plain-text part of a message.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text unchanged.
func (t TextContent) String() string {
	s := t.Text
	return s
}

func (TextContent) isPart() {}
77
// ImageURLContent references an image by URL. Detail is optional and is
// omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL.
func (img ImageURLContent) String() string {
	url := img.URL
	return url
}

func (ImageURLContent) isPart() {}
88
89type BinaryContent struct {
90 Path string
91 MIMEType string
92 Data []byte
93}
94
95func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
96 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
97 if p == catwalk.InferenceProviderOpenAI {
98 return "data:" + bc.MIMEType + ";base64," + base64Encoded
99 }
100 return base64Encoded
101}
102
103func (BinaryContent) isPart() {}
104
// ToolCall is a tool invocation carried as a message part. ID identifies the
// call for AddToolCall, FinishToolCall and AppendToolCallInput. Input can be
// built up incrementally via AppendToolCallInput, and Finished is set to true
// by FinishToolCall.
//
// NOTE(review): the format of Input (presumably serialized arguments) and the
// allowed values of Type are not visible here — confirm with callers.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
114
// ToolResult is the outcome of a tool call, linked to the originating
// ToolCall through ToolCallID.
//
// NOTE(review): IsError presumably flags a failed invocation, and the format
// of Metadata is not visible here — confirm with the producers of results.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
124
// Finish marks a message as complete and records why. When created through
// AddFinish, Time is the Unix time in seconds at which the part was added.
// Message and Details are optional and omitted from JSON when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

func (Finish) isPart() {}
133
// Message is a single message belonging to a session (SessionID). Its content
// is an ordered list of typed parts; the accessor methods on Message (Content,
// ToolCalls, FinishPart, ...) extract the parts of a given type.
//
// NOTE(review): Parts is a slice of an interface type, which encoding/json
// cannot unmarshal without custom handling — presumably (de)serialization of
// parts is done elsewhere; confirm. Units of CreatedAt/UpdatedAt are not
// visible here.
type Message struct {
	ID        string        `json:"id"`
	Role      MessageRole   `json:"role"`
	SessionID string        `json:"session_id"`
	Parts     []ContentPart `json:"parts"`
	Model     string        `json:"model"`
	Provider  string        `json:"provider"`
	CreatedAt int64         `json:"created_at"`
	UpdatedAt int64         `json:"updated_at"`
}
144
145func (m *Message) Content() TextContent {
146 for _, part := range m.Parts {
147 if c, ok := part.(TextContent); ok {
148 return c
149 }
150 }
151 return TextContent{}
152}
153
154func (m *Message) ReasoningContent() ReasoningContent {
155 for _, part := range m.Parts {
156 if c, ok := part.(ReasoningContent); ok {
157 return c
158 }
159 }
160 return ReasoningContent{}
161}
162
163func (m *Message) ImageURLContent() []ImageURLContent {
164 imageURLContents := make([]ImageURLContent, 0)
165 for _, part := range m.Parts {
166 if c, ok := part.(ImageURLContent); ok {
167 imageURLContents = append(imageURLContents, c)
168 }
169 }
170 return imageURLContents
171}
172
173func (m *Message) BinaryContent() []BinaryContent {
174 binaryContents := make([]BinaryContent, 0)
175 for _, part := range m.Parts {
176 if c, ok := part.(BinaryContent); ok {
177 binaryContents = append(binaryContents, c)
178 }
179 }
180 return binaryContents
181}
182
183func (m *Message) ToolCalls() []ToolCall {
184 toolCalls := make([]ToolCall, 0)
185 for _, part := range m.Parts {
186 if c, ok := part.(ToolCall); ok {
187 toolCalls = append(toolCalls, c)
188 }
189 }
190 return toolCalls
191}
192
193func (m *Message) ToolResults() []ToolResult {
194 toolResults := make([]ToolResult, 0)
195 for _, part := range m.Parts {
196 if c, ok := part.(ToolResult); ok {
197 toolResults = append(toolResults, c)
198 }
199 }
200 return toolResults
201}
202
203func (m *Message) IsFinished() bool {
204 for _, part := range m.Parts {
205 if _, ok := part.(Finish); ok {
206 return true
207 }
208 }
209 return false
210}
211
212func (m *Message) FinishPart() *Finish {
213 for _, part := range m.Parts {
214 if c, ok := part.(Finish); ok {
215 return &c
216 }
217 }
218 return nil
219}
220
221func (m *Message) FinishReason() FinishReason {
222 for _, part := range m.Parts {
223 if c, ok := part.(Finish); ok {
224 return c.Reason
225 }
226 }
227 return ""
228}
229
230func (m *Message) IsThinking() bool {
231 if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
232 return true
233 }
234 return false
235}
236
237func (m *Message) AppendContent(delta string) {
238 found := false
239 for i, part := range m.Parts {
240 if c, ok := part.(TextContent); ok {
241 m.Parts[i] = TextContent{Text: c.Text + delta}
242 found = true
243 }
244 }
245 if !found {
246 m.Parts = append(m.Parts, TextContent{Text: delta})
247 }
248}
249
250func (m *Message) AppendReasoningContent(delta string) {
251 found := false
252 for i, part := range m.Parts {
253 if c, ok := part.(ReasoningContent); ok {
254 m.Parts[i] = ReasoningContent{
255 Thinking: c.Thinking + delta,
256 Signature: c.Signature,
257 StartedAt: c.StartedAt,
258 FinishedAt: c.FinishedAt,
259 }
260 found = true
261 }
262 }
263 if !found {
264 m.Parts = append(m.Parts, ReasoningContent{
265 Thinking: delta,
266 StartedAt: time.Now().Unix(),
267 })
268 }
269}
270
271func (m *Message) AppendReasoningSignature(signature string) {
272 for i, part := range m.Parts {
273 if c, ok := part.(ReasoningContent); ok {
274 m.Parts[i] = ReasoningContent{
275 Thinking: c.Thinking,
276 Signature: c.Signature + signature,
277 StartedAt: c.StartedAt,
278 FinishedAt: c.FinishedAt,
279 }
280 return
281 }
282 }
283 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
284}
285
286func (m *Message) FinishThinking() {
287 for i, part := range m.Parts {
288 if c, ok := part.(ReasoningContent); ok {
289 if c.FinishedAt == 0 {
290 m.Parts[i] = ReasoningContent{
291 Thinking: c.Thinking,
292 Signature: c.Signature,
293 StartedAt: c.StartedAt,
294 FinishedAt: time.Now().Unix(),
295 }
296 }
297 return
298 }
299 }
300}
301
302func (m *Message) ThinkingDuration() time.Duration {
303 reasoning := m.ReasoningContent()
304 if reasoning.StartedAt == 0 {
305 return 0
306 }
307
308 endTime := reasoning.FinishedAt
309 if endTime == 0 {
310 endTime = time.Now().Unix()
311 }
312
313 return time.Duration(endTime-reasoning.StartedAt) * time.Second
314}
315
316func (m *Message) FinishToolCall(toolCallID string) {
317 for i, part := range m.Parts {
318 if c, ok := part.(ToolCall); ok {
319 if c.ID == toolCallID {
320 m.Parts[i] = ToolCall{
321 ID: c.ID,
322 Name: c.Name,
323 Input: c.Input,
324 Type: c.Type,
325 Finished: true,
326 }
327 return
328 }
329 }
330 }
331}
332
333func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
334 for i, part := range m.Parts {
335 if c, ok := part.(ToolCall); ok {
336 if c.ID == toolCallID {
337 m.Parts[i] = ToolCall{
338 ID: c.ID,
339 Name: c.Name,
340 Input: c.Input + inputDelta,
341 Type: c.Type,
342 Finished: c.Finished,
343 }
344 return
345 }
346 }
347 }
348}
349
350func (m *Message) AddToolCall(tc ToolCall) {
351 for i, part := range m.Parts {
352 if c, ok := part.(ToolCall); ok {
353 if c.ID == tc.ID {
354 m.Parts[i] = tc
355 return
356 }
357 }
358 }
359 m.Parts = append(m.Parts, tc)
360}
361
362func (m *Message) SetToolCalls(tc []ToolCall) {
363 // remove any existing tool call part it could have multiple
364 parts := make([]ContentPart, 0)
365 for _, part := range m.Parts {
366 if _, ok := part.(ToolCall); ok {
367 continue
368 }
369 parts = append(parts, part)
370 }
371 m.Parts = parts
372 for _, toolCall := range tc {
373 m.Parts = append(m.Parts, toolCall)
374 }
375}
376
377func (m *Message) AddToolResult(tr ToolResult) {
378 m.Parts = append(m.Parts, tr)
379}
380
381func (m *Message) SetToolResults(tr []ToolResult) {
382 for _, toolResult := range tr {
383 m.Parts = append(m.Parts, toolResult)
384 }
385}
386
387func (m *Message) AddFinish(reason FinishReason, message, details string) {
388 // remove any existing finish part
389 for i, part := range m.Parts {
390 if _, ok := part.(Finish); ok {
391 m.Parts = slices.Delete(m.Parts, i, i+1)
392 break
393 }
394 }
395 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
396}
397
398func (m *Message) AddImageURL(url, detail string) {
399 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
400}
401
402func (m *Message) AddBinary(mimeType string, data []byte) {
403 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
404}