1package message
2
3import (
4 "encoding/base64"
5 "slices"
6 "time"
7
8 "github.com/charmbracelet/crush/internal/fur/provider"
9)
10
// MessageRole identifies the role a Message is attributed to
// (see Message.Role).
type MessageRole string

// Known message roles.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)
19
// FinishReason is the reason recorded in a Finish part (Finish.Reason,
// JSON key "reason").
type FinishReason string

// Known finish reasons. The string values are what appear when a Finish
// part is encoded to JSON.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// FinishReasonUnknown is a fallback value that should never occur in
	// practice.
	FinishReasonUnknown FinishReason = "unknown"
)
33
// ContentPart is one element of a Message's Parts slice.
//
// The unexported isPart method seals the interface: only types declared in
// this package can implement it. The implementations visible in this file
// are ReasoningContent, TextContent, ImageURLContent, BinaryContent,
// ToolCall, ToolResult and Finish.
type ContentPart interface {
	isPart()
}
37
// ReasoningContent is a message part holding reasoning ("thinking") text.
type ReasoningContent struct {
	Thinking string `json:"thinking"`
}

// String returns the reasoning text unchanged.
func (r ReasoningContent) String() string { return r.Thinking }

func (ReasoningContent) isPart() {}
46
// TextContent is a message part holding plain text.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text unchanged.
func (t TextContent) String() string { return t.Text }

func (TextContent) isPart() {}
56
// ImageURLContent is a message part that references an image by URL, with
// an optional Detail string (omitted from JSON when empty).
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL; Detail is not included.
func (img ImageURLContent) String() string { return img.URL }

func (ImageURLContent) isPart() {}
67
68type BinaryContent struct {
69 Path string
70 MIMEType string
71 Data []byte
72}
73
74func (bc BinaryContent) String(p provider.InferenceProvider) string {
75 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
76 if p == provider.InferenceProviderOpenAI {
77 return "data:" + bc.MIMEType + ";base64," + base64Encoded
78 }
79 return base64Encoded
80}
81
82func (BinaryContent) isPart() {}
83
// ToolCall is a message part describing a tool invocation.
//
// Input is accumulated incrementally by AppendToolCallInput (NOTE(review):
// presumably JSON-encoded arguments — confirm with the producer), and
// Finished is set to true by FinishToolCall.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
93
// ToolResult is a message part carrying the output of a tool call.
//
// ToolCallID presumably matches the ID of the corresponding ToolCall —
// confirm with callers. Metadata is an opaque string whose format is not
// defined here. IsError marks the result as an error.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
103
// Finish is a message part marking the message as finished. A message is
// reported as finished by IsFinished as soon as it contains a Finish part.
//
// Reason records why it finished. Time is a Unix timestamp in seconds;
// AddFinish fills it from time.Now().Unix().
type Finish struct {
	Reason FinishReason `json:"reason"`
	Time   int64        `json:"time"`
}

func (Finish) isPart() {}
110
// Message is a single message within a session. Its content is an ordered
// list of ContentPart values (text, reasoning, images, binary data, tool
// calls, tool results and a finish marker) that the methods on *Message
// read and update.
//
// CreatedAt and UpdatedAt are int64 timestamps; NOTE(review): presumably
// Unix time — confirm units with the code that sets them.
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	Parts     []ContentPart
	Model     string
	Provider  string
	CreatedAt int64
	UpdatedAt int64
}
121
122func (m *Message) Content() TextContent {
123 for _, part := range m.Parts {
124 if c, ok := part.(TextContent); ok {
125 return c
126 }
127 }
128 return TextContent{}
129}
130
131func (m *Message) ReasoningContent() ReasoningContent {
132 for _, part := range m.Parts {
133 if c, ok := part.(ReasoningContent); ok {
134 return c
135 }
136 }
137 return ReasoningContent{}
138}
139
140func (m *Message) ImageURLContent() []ImageURLContent {
141 imageURLContents := make([]ImageURLContent, 0)
142 for _, part := range m.Parts {
143 if c, ok := part.(ImageURLContent); ok {
144 imageURLContents = append(imageURLContents, c)
145 }
146 }
147 return imageURLContents
148}
149
150func (m *Message) BinaryContent() []BinaryContent {
151 binaryContents := make([]BinaryContent, 0)
152 for _, part := range m.Parts {
153 if c, ok := part.(BinaryContent); ok {
154 binaryContents = append(binaryContents, c)
155 }
156 }
157 return binaryContents
158}
159
160func (m *Message) ToolCalls() []ToolCall {
161 toolCalls := make([]ToolCall, 0)
162 for _, part := range m.Parts {
163 if c, ok := part.(ToolCall); ok {
164 toolCalls = append(toolCalls, c)
165 }
166 }
167 return toolCalls
168}
169
170func (m *Message) ToolResults() []ToolResult {
171 toolResults := make([]ToolResult, 0)
172 for _, part := range m.Parts {
173 if c, ok := part.(ToolResult); ok {
174 toolResults = append(toolResults, c)
175 }
176 }
177 return toolResults
178}
179
180func (m *Message) IsFinished() bool {
181 for _, part := range m.Parts {
182 if _, ok := part.(Finish); ok {
183 return true
184 }
185 }
186 return false
187}
188
189func (m *Message) FinishPart() *Finish {
190 for _, part := range m.Parts {
191 if c, ok := part.(Finish); ok {
192 return &c
193 }
194 }
195 return nil
196}
197
198func (m *Message) FinishReason() FinishReason {
199 for _, part := range m.Parts {
200 if c, ok := part.(Finish); ok {
201 return c.Reason
202 }
203 }
204 return ""
205}
206
207func (m *Message) IsThinking() bool {
208 if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
209 return true
210 }
211 return false
212}
213
214func (m *Message) AppendContent(delta string) {
215 found := false
216 for i, part := range m.Parts {
217 if c, ok := part.(TextContent); ok {
218 m.Parts[i] = TextContent{Text: c.Text + delta}
219 found = true
220 }
221 }
222 if !found {
223 m.Parts = append(m.Parts, TextContent{Text: delta})
224 }
225}
226
227func (m *Message) AppendReasoningContent(delta string) {
228 found := false
229 for i, part := range m.Parts {
230 if c, ok := part.(ReasoningContent); ok {
231 m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta}
232 found = true
233 }
234 }
235 if !found {
236 m.Parts = append(m.Parts, ReasoningContent{Thinking: delta})
237 }
238}
239
240func (m *Message) FinishToolCall(toolCallID string) {
241 for i, part := range m.Parts {
242 if c, ok := part.(ToolCall); ok {
243 if c.ID == toolCallID {
244 m.Parts[i] = ToolCall{
245 ID: c.ID,
246 Name: c.Name,
247 Input: c.Input,
248 Type: c.Type,
249 Finished: true,
250 }
251 return
252 }
253 }
254 }
255}
256
257func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
258 for i, part := range m.Parts {
259 if c, ok := part.(ToolCall); ok {
260 if c.ID == toolCallID {
261 m.Parts[i] = ToolCall{
262 ID: c.ID,
263 Name: c.Name,
264 Input: c.Input + inputDelta,
265 Type: c.Type,
266 Finished: c.Finished,
267 }
268 return
269 }
270 }
271 }
272}
273
274func (m *Message) AddToolCall(tc ToolCall) {
275 for i, part := range m.Parts {
276 if c, ok := part.(ToolCall); ok {
277 if c.ID == tc.ID {
278 m.Parts[i] = tc
279 return
280 }
281 }
282 }
283 m.Parts = append(m.Parts, tc)
284}
285
286func (m *Message) SetToolCalls(tc []ToolCall) {
287 // remove any existing tool call part it could have multiple
288 parts := make([]ContentPart, 0)
289 for _, part := range m.Parts {
290 if _, ok := part.(ToolCall); ok {
291 continue
292 }
293 parts = append(parts, part)
294 }
295 m.Parts = parts
296 for _, toolCall := range tc {
297 m.Parts = append(m.Parts, toolCall)
298 }
299}
300
301func (m *Message) AddToolResult(tr ToolResult) {
302 m.Parts = append(m.Parts, tr)
303}
304
305func (m *Message) SetToolResults(tr []ToolResult) {
306 for _, toolResult := range tr {
307 m.Parts = append(m.Parts, toolResult)
308 }
309}
310
311func (m *Message) AddFinish(reason FinishReason) {
312 // remove any existing finish part
313 for i, part := range m.Parts {
314 if _, ok := part.(Finish); ok {
315 m.Parts = slices.Delete(m.Parts, i, i+1)
316 break
317 }
318 }
319 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()})
320}
321
322func (m *Message) AddImageURL(url, detail string) {
323 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
324}
325
326func (m *Message) AddBinary(mimeType string, data []byte) {
327 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
328}