1package message
2
3import (
4 "encoding/base64"
5 "slices"
6 "time"
7
8 "github.com/charmbracelet/catwalk/pkg/catwalk"
9)
10
// MessageRole identifies the author of a Message.
type MessageRole string

// The message roles known to this package.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)
19
// FinishReason is stored in a Finish part (see AddFinish) to describe how a
// message ended.
type FinishReason string

// The finish reasons known to this package.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// FinishReasonUnknown should never happen.
	FinishReasonUnknown FinishReason = "unknown"
)
33
// ContentPart is one element of Message.Parts. The unexported isPart method
// marks the types in this package that may be stored as message parts.
type ContentPart interface {
	isPart()
}
37
// ReasoningContent is a ContentPart holding a model's reasoning ("thinking")
// text. StartedAt and FinishedAt are Unix timestamps in seconds:
// AppendReasoningContent sets StartedAt when it creates the part, and
// FinishThinking sets FinishedAt. A zero FinishedAt means the reasoning is
// still in progress (see ThinkingDuration). Signature accumulates whatever
// is passed to AppendReasoningSignature.
type ReasoningContent struct {
	Thinking   string `json:"thinking"`
	Signature  string `json:"signature"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`
}

// String returns the reasoning text.
func (tc ReasoningContent) String() string {
	return tc.Thinking
}

// isPart marks ReasoningContent as a ContentPart.
func (ReasoningContent) isPart() {}
49
// TextContent is a ContentPart holding plain message text.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text.
func (tc TextContent) String() string {
	return tc.Text
}

// isPart marks TextContent as a ContentPart.
func (TextContent) isPart() {}
59
// ImageURLContent is a ContentPart that references an image by URL. Detail is
// optional and omitted from JSON when empty; its accepted values are not
// defined here (presumably a provider-specific detail level — TODO confirm).
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL.
func (iuc ImageURLContent) String() string {
	return iuc.URL
}

// isPart marks ImageURLContent as a ContentPart.
func (ImageURLContent) isPart() {}
70
// BinaryContent is a ContentPart carrying raw bytes (Data) labelled with a
// MIME type. AddBinary does not set Path; presumably it names the source file
// when there is one — TODO confirm. Unlike the other parts, these fields have
// no json tags, so if the struct is JSON-encoded they appear under their Go
// names.
type BinaryContent struct {
	Path     string
	MIMEType string
	Data     []byte
}
76
77func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
78 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
79 if p == catwalk.InferenceProviderOpenAI {
80 return "data:" + bc.MIMEType + ";base64," + base64Encoded
81 }
82 return base64Encoded
83}
84
85func (BinaryContent) isPart() {}
86
// ToolCall is a ContentPart describing a tool invocation. Input accumulates
// text passed to AppendToolCallInput, and Finished is set by FinishToolCall.
// Calls are matched by ID in those methods and in AddToolCall.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
96
// ToolResult is a ContentPart carrying the output of a tool call.
// ToolCallID presumably matches the ID of the corresponding ToolCall — TODO
// confirm. IsError flags the result as an error.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
106
// Finish is a ContentPart whose presence marks a message as finished (see
// IsFinished). AddFinish sets Time to the current Unix time in seconds;
// Message and Details are optional and omitted from JSON when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
115
// Retry records one retry attempt. AddRetry sets Timestamp to the current
// Unix time in seconds; the unit of RetryAfter is not defined here — TODO
// confirm with callers. Retry itself is not a ContentPart.
type Retry struct {
	Error      string `json:"error"`
	RetryAfter int64  `json:"retry_after"`
	Timestamp  int64  `json:"timestamp"`
}

// RetryContent is a ContentPart holding a message's retry history (Retries,
// appended by AddRetry) and whether a retry is in progress (Retrying, set by
// SetRetrying).
type RetryContent struct {
	Retries  []Retry `json:"retries"`
	Retrying bool    `json:"retrying"`
}

// isPart marks RetryContent as a ContentPart.
func (RetryContent) isPart() {}
128
// Message is a single conversation message whose content is an ordered list
// of parts. Most accessor methods look parts up by their concrete type and
// return or modify the first part of the requested type. The units of
// CreatedAt/UpdatedAt are not established here — TODO confirm.
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	Parts     []ContentPart
	Model     string
	Provider  string
	CreatedAt int64
	UpdatedAt int64
}
139
140func (m *Message) Content() TextContent {
141 for _, part := range m.Parts {
142 if c, ok := part.(TextContent); ok {
143 return c
144 }
145 }
146 return TextContent{}
147}
148
149func (m *Message) ReasoningContent() ReasoningContent {
150 for _, part := range m.Parts {
151 if c, ok := part.(ReasoningContent); ok {
152 return c
153 }
154 }
155 return ReasoningContent{}
156}
157
158func (m *Message) ImageURLContent() []ImageURLContent {
159 imageURLContents := make([]ImageURLContent, 0)
160 for _, part := range m.Parts {
161 if c, ok := part.(ImageURLContent); ok {
162 imageURLContents = append(imageURLContents, c)
163 }
164 }
165 return imageURLContents
166}
167
168func (m *Message) BinaryContent() []BinaryContent {
169 binaryContents := make([]BinaryContent, 0)
170 for _, part := range m.Parts {
171 if c, ok := part.(BinaryContent); ok {
172 binaryContents = append(binaryContents, c)
173 }
174 }
175 return binaryContents
176}
177
178func (m *Message) ToolCalls() []ToolCall {
179 toolCalls := make([]ToolCall, 0)
180 for _, part := range m.Parts {
181 if c, ok := part.(ToolCall); ok {
182 toolCalls = append(toolCalls, c)
183 }
184 }
185 return toolCalls
186}
187
188func (m *Message) ToolResults() []ToolResult {
189 toolResults := make([]ToolResult, 0)
190 for _, part := range m.Parts {
191 if c, ok := part.(ToolResult); ok {
192 toolResults = append(toolResults, c)
193 }
194 }
195 return toolResults
196}
197
198func (m *Message) IsFinished() bool {
199 for _, part := range m.Parts {
200 if _, ok := part.(Finish); ok {
201 return true
202 }
203 }
204 return false
205}
206
207func (m *Message) FinishPart() *Finish {
208 for _, part := range m.Parts {
209 if c, ok := part.(Finish); ok {
210 return &c
211 }
212 }
213 return nil
214}
215
216func (m *Message) FinishReason() FinishReason {
217 for _, part := range m.Parts {
218 if c, ok := part.(Finish); ok {
219 return c.Reason
220 }
221 }
222 return ""
223}
224
225func (m *Message) IsThinking() bool {
226 if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
227 return true
228 }
229 return false
230}
231
232func (m *Message) AppendContent(delta string) {
233 found := false
234 for i, part := range m.Parts {
235 if c, ok := part.(TextContent); ok {
236 m.Parts[i] = TextContent{Text: c.Text + delta}
237 found = true
238 }
239 }
240 if !found {
241 m.Parts = append(m.Parts, TextContent{Text: delta})
242 }
243}
244
245func (m *Message) AppendReasoningContent(delta string) {
246 found := false
247 for i, part := range m.Parts {
248 if c, ok := part.(ReasoningContent); ok {
249 m.Parts[i] = ReasoningContent{
250 Thinking: c.Thinking + delta,
251 Signature: c.Signature,
252 StartedAt: c.StartedAt,
253 FinishedAt: c.FinishedAt,
254 }
255 found = true
256 }
257 }
258 if !found {
259 m.Parts = append(m.Parts, ReasoningContent{
260 Thinking: delta,
261 StartedAt: time.Now().Unix(),
262 })
263 }
264}
265
266func (m *Message) AppendReasoningSignature(signature string) {
267 for i, part := range m.Parts {
268 if c, ok := part.(ReasoningContent); ok {
269 m.Parts[i] = ReasoningContent{
270 Thinking: c.Thinking,
271 Signature: c.Signature + signature,
272 StartedAt: c.StartedAt,
273 FinishedAt: c.FinishedAt,
274 }
275 return
276 }
277 }
278 m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
279}
280
281func (m *Message) FinishThinking() {
282 for i, part := range m.Parts {
283 if c, ok := part.(ReasoningContent); ok {
284 if c.FinishedAt == 0 {
285 m.Parts[i] = ReasoningContent{
286 Thinking: c.Thinking,
287 Signature: c.Signature,
288 StartedAt: c.StartedAt,
289 FinishedAt: time.Now().Unix(),
290 }
291 }
292 return
293 }
294 }
295}
296
297func (m *Message) ThinkingDuration() time.Duration {
298 reasoning := m.ReasoningContent()
299 if reasoning.StartedAt == 0 {
300 return 0
301 }
302
303 endTime := reasoning.FinishedAt
304 if endTime == 0 {
305 endTime = time.Now().Unix()
306 }
307
308 return time.Duration(endTime-reasoning.StartedAt) * time.Second
309}
310
311func (m *Message) FinishToolCall(toolCallID string) {
312 for i, part := range m.Parts {
313 if c, ok := part.(ToolCall); ok {
314 if c.ID == toolCallID {
315 m.Parts[i] = ToolCall{
316 ID: c.ID,
317 Name: c.Name,
318 Input: c.Input,
319 Type: c.Type,
320 Finished: true,
321 }
322 return
323 }
324 }
325 }
326}
327
328func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
329 for i, part := range m.Parts {
330 if c, ok := part.(ToolCall); ok {
331 if c.ID == toolCallID {
332 m.Parts[i] = ToolCall{
333 ID: c.ID,
334 Name: c.Name,
335 Input: c.Input + inputDelta,
336 Type: c.Type,
337 Finished: c.Finished,
338 }
339 return
340 }
341 }
342 }
343}
344
345func (m *Message) AddToolCall(tc ToolCall) {
346 for i, part := range m.Parts {
347 if c, ok := part.(ToolCall); ok {
348 if c.ID == tc.ID {
349 m.Parts[i] = tc
350 return
351 }
352 }
353 }
354 m.Parts = append(m.Parts, tc)
355}
356
357func (m *Message) SetToolCalls(tc []ToolCall) {
358 // remove any existing tool call part it could have multiple
359 parts := make([]ContentPart, 0)
360 for _, part := range m.Parts {
361 if _, ok := part.(ToolCall); ok {
362 continue
363 }
364 parts = append(parts, part)
365 }
366 m.Parts = parts
367 for _, toolCall := range tc {
368 m.Parts = append(m.Parts, toolCall)
369 }
370}
371
372func (m *Message) AddToolResult(tr ToolResult) {
373 m.Parts = append(m.Parts, tr)
374}
375
376func (m *Message) SetToolResults(tr []ToolResult) {
377 for _, toolResult := range tr {
378 m.Parts = append(m.Parts, toolResult)
379 }
380}
381
382func (m *Message) AddFinish(reason FinishReason, message, details string) {
383 // remove any existing finish part
384 for i, part := range m.Parts {
385 if _, ok := part.(Finish); ok {
386 m.Parts = slices.Delete(m.Parts, i, i+1)
387 break
388 }
389 }
390 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
391}
392
393func (m *Message) AddImageURL(url, detail string) {
394 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
395}
396
397func (m *Message) AddBinary(mimeType string, data []byte) {
398 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
399}
400
401func (m *Message) RetryContent() *RetryContent {
402 for _, part := range m.Parts {
403 if c, ok := part.(RetryContent); ok {
404 return &c
405 }
406 }
407 return nil
408}
409
410func (m *Message) AddRetry(error string, retryAfter int64) {
411 retry := Retry{
412 Error: error,
413 RetryAfter: retryAfter,
414 Timestamp: time.Now().Unix(),
415 }
416
417 found := false
418 for i, part := range m.Parts {
419 if c, ok := part.(RetryContent); ok {
420 m.Parts[i] = RetryContent{
421 Retries: append(c.Retries, retry),
422 Retrying: c.Retrying,
423 }
424 found = true
425 break
426 }
427 }
428 if !found {
429 m.Parts = append(m.Parts, RetryContent{
430 Retries: []Retry{retry},
431 Retrying: false,
432 })
433 }
434}
435
436func (m *Message) SetRetrying(retrying bool) {
437 found := false
438 for i, part := range m.Parts {
439 if c, ok := part.(RetryContent); ok {
440 m.Parts[i] = RetryContent{
441 Retries: c.Retries,
442 Retrying: retrying,
443 }
444 found = true
445 break
446 }
447 }
448 if !found && retrying {
449 m.Parts = append(m.Parts, RetryContent{
450 Retries: []Retry{},
451 Retrying: retrying,
452 })
453 }
454}
455
456func (m *Message) IsRetrying() bool {
457 if retryContent := m.RetryContent(); retryContent != nil {
458 return retryContent.Retrying
459 }
460 return false
461}
462
463func (m *Message) GetRetries() []Retry {
464 if retryContent := m.RetryContent(); retryContent != nil {
465 return retryContent.Retries
466 }
467 return []Retry{}
468}