1package message
2
3import (
4 "encoding/base64"
5 "slices"
6 "time"
7
8 "github.com/charmbracelet/crush/internal/fur/provider"
9)
10
// MessageRole identifies who authored a Message.
type MessageRole string

// Roles a Message can carry.
const (
	Assistant = MessageRole("assistant")
	User      = MessageRole("user")
	System    = MessageRole("system")
	Tool      = MessageRole("tool")
)
19
// FinishReason is the reason recorded in a message's Finish part.
type FinishReason string

// Recognised finish reasons.
const (
	FinishReasonEndTurn          = FinishReason("end_turn")
	FinishReasonMaxTokens        = FinishReason("max_tokens")
	FinishReasonToolUse          = FinishReason("tool_use")
	FinishReasonCanceled         = FinishReason("canceled")
	FinishReasonError            = FinishReason("error")
	FinishReasonPermissionDenied = FinishReason("permission_denied")

	// FinishReasonUnknown is a fallback value that is not expected to occur.
	FinishReasonUnknown = FinishReason("unknown")
)
33
// ContentPart is one piece of a Message's content (text, reasoning, image,
// binary data, tool call, tool result or finish marker). Because isPart is
// unexported, only types declared in this package can define it, which
// limits implementations to the part types below.
type ContentPart interface {
	isPart()
}
37
// ReasoningContent is a content part holding the model's reasoning
// ("thinking") text.
type ReasoningContent struct {
	Thinking string `json:"thinking"`
}

// String returns the reasoning text unchanged.
func (rc ReasoningContent) String() string { return rc.Thinking }

func (ReasoningContent) isPart() {}
46
// TextContent is a plain-text content part.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text unchanged.
func (c TextContent) String() string { return c.Text }

func (TextContent) isPart() {}
56
// ImageURLContent is a content part referencing an image by URL, with an
// optional Detail string that is omitted from JSON when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the URL; Detail is not included.
func (c ImageURLContent) String() string { return c.URL }

func (ImageURLContent) isPart() {}
67
68type BinaryContent struct {
69 Path string
70 MIMEType string
71 Data []byte
72}
73
74func (bc BinaryContent) String(p provider.InferenceProvider) string {
75 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
76 if p == provider.InferenceProviderOpenAI {
77 return "data:" + bc.MIMEType + ";base64," + base64Encoded
78 }
79 return base64Encoded
80}
81
82func (BinaryContent) isPart() {}
83
// ToolCall is a content part describing a tool invocation. It can be filled in
// incrementally: AppendToolCallInput appends to Input, FinishToolCall sets
// Finished to true, and AddToolCall replaces an existing part with the same ID.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"` // NOTE(review): presumably the serialized tool arguments — confirm format
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
93
// ToolResult is a content part carrying the output of a tool invocation.
// ToolCallID presumably matches the ID of the corresponding ToolCall — confirm
// with callers. IsError flags a result that represents a failure.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
103
// Finish is a content part marking a message as finished; IsFinished reports
// whether one is present and AddFinish replaces a previous one.
type Finish struct {
	Reason FinishReason `json:"reason"`
	// Time is Unix seconds when set by AddFinish (time.Now().Unix()).
	Time    int64  `json:"time"`
	Message string `json:"message,omitempty"`
	Details string `json:"details,omitempty"`
}

func (Finish) isPart() {}
112
// Message is one message in a session. Its content is stored as an ordered
// list of ContentPart values; the methods on *Message locate parts by their
// concrete type.
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	Parts     []ContentPart
	Model     string
	Provider  string
	CreatedAt int64 // NOTE(review): presumably a Unix timestamp — confirm units
	UpdatedAt int64 // NOTE(review): presumably a Unix timestamp — confirm units
}
123
124func (m *Message) Content() TextContent {
125 for _, part := range m.Parts {
126 if c, ok := part.(TextContent); ok {
127 return c
128 }
129 }
130 return TextContent{}
131}
132
133func (m *Message) ReasoningContent() ReasoningContent {
134 for _, part := range m.Parts {
135 if c, ok := part.(ReasoningContent); ok {
136 return c
137 }
138 }
139 return ReasoningContent{}
140}
141
142func (m *Message) ImageURLContent() []ImageURLContent {
143 imageURLContents := make([]ImageURLContent, 0)
144 for _, part := range m.Parts {
145 if c, ok := part.(ImageURLContent); ok {
146 imageURLContents = append(imageURLContents, c)
147 }
148 }
149 return imageURLContents
150}
151
152func (m *Message) BinaryContent() []BinaryContent {
153 binaryContents := make([]BinaryContent, 0)
154 for _, part := range m.Parts {
155 if c, ok := part.(BinaryContent); ok {
156 binaryContents = append(binaryContents, c)
157 }
158 }
159 return binaryContents
160}
161
162func (m *Message) ToolCalls() []ToolCall {
163 toolCalls := make([]ToolCall, 0)
164 for _, part := range m.Parts {
165 if c, ok := part.(ToolCall); ok {
166 toolCalls = append(toolCalls, c)
167 }
168 }
169 return toolCalls
170}
171
172func (m *Message) ToolResults() []ToolResult {
173 toolResults := make([]ToolResult, 0)
174 for _, part := range m.Parts {
175 if c, ok := part.(ToolResult); ok {
176 toolResults = append(toolResults, c)
177 }
178 }
179 return toolResults
180}
181
// IsFinished reports whether the message contains a Finish part.
func (m *Message) IsFinished() bool {
	return slices.ContainsFunc(m.Parts, func(part ContentPart) bool {
		_, isFinish := part.(Finish)
		return isFinish
	})
}
190
191func (m *Message) FinishPart() *Finish {
192 for _, part := range m.Parts {
193 if c, ok := part.(Finish); ok {
194 return &c
195 }
196 }
197 return nil
198}
199
200func (m *Message) FinishReason() FinishReason {
201 for _, part := range m.Parts {
202 if c, ok := part.(Finish); ok {
203 return c.Reason
204 }
205 }
206 return ""
207}
208
209func (m *Message) IsThinking() bool {
210 if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
211 return true
212 }
213 return false
214}
215
216func (m *Message) AppendContent(delta string) {
217 found := false
218 for i, part := range m.Parts {
219 if c, ok := part.(TextContent); ok {
220 m.Parts[i] = TextContent{Text: c.Text + delta}
221 found = true
222 }
223 }
224 if !found {
225 m.Parts = append(m.Parts, TextContent{Text: delta})
226 }
227}
228
229func (m *Message) AppendReasoningContent(delta string) {
230 found := false
231 for i, part := range m.Parts {
232 if c, ok := part.(ReasoningContent); ok {
233 m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta}
234 found = true
235 }
236 }
237 if !found {
238 m.Parts = append(m.Parts, ReasoningContent{Thinking: delta})
239 }
240}
241
242func (m *Message) FinishToolCall(toolCallID string) {
243 for i, part := range m.Parts {
244 if c, ok := part.(ToolCall); ok {
245 if c.ID == toolCallID {
246 m.Parts[i] = ToolCall{
247 ID: c.ID,
248 Name: c.Name,
249 Input: c.Input,
250 Type: c.Type,
251 Finished: true,
252 }
253 return
254 }
255 }
256 }
257}
258
259func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
260 for i, part := range m.Parts {
261 if c, ok := part.(ToolCall); ok {
262 if c.ID == toolCallID {
263 m.Parts[i] = ToolCall{
264 ID: c.ID,
265 Name: c.Name,
266 Input: c.Input + inputDelta,
267 Type: c.Type,
268 Finished: c.Finished,
269 }
270 return
271 }
272 }
273 }
274}
275
// AddToolCall replaces the first ToolCall part with the same ID as tc, keeping
// its position. If no such part exists, tc is appended to Parts.
func (m *Message) AddToolCall(tc ToolCall) {
	idx := slices.IndexFunc(m.Parts, func(part ContentPart) bool {
		existing, ok := part.(ToolCall)
		return ok && existing.ID == tc.ID
	})
	if idx >= 0 {
		m.Parts[idx] = tc
		return
	}
	m.Parts = append(m.Parts, tc)
}
287
288func (m *Message) SetToolCalls(tc []ToolCall) {
289 // remove any existing tool call part it could have multiple
290 parts := make([]ContentPart, 0)
291 for _, part := range m.Parts {
292 if _, ok := part.(ToolCall); ok {
293 continue
294 }
295 parts = append(parts, part)
296 }
297 m.Parts = parts
298 for _, toolCall := range tc {
299 m.Parts = append(m.Parts, toolCall)
300 }
301}
302
303func (m *Message) AddToolResult(tr ToolResult) {
304 m.Parts = append(m.Parts, tr)
305}
306
307func (m *Message) SetToolResults(tr []ToolResult) {
308 for _, toolResult := range tr {
309 m.Parts = append(m.Parts, toolResult)
310 }
311}
312
// AddFinish marks the message as finished. It removes every existing Finish
// part, then appends a new Finish carrying reason, message and details, with
// Time set to the current Unix time in seconds.
//
// The previous implementation removed only the first Finish part, which could
// leave stale markers behind despite its "remove any existing" intent.
func (m *Message) AddFinish(reason FinishReason, message, details string) {
	m.Parts = slices.DeleteFunc(m.Parts, func(part ContentPart) bool {
		_, isFinish := part.(Finish)
		return isFinish
	})
	m.Parts = append(m.Parts, Finish{
		Reason:  reason,
		Time:    time.Now().Unix(),
		Message: message,
		Details: details,
	})
}
323
324func (m *Message) AddImageURL(url, detail string) {
325 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
326}
327
328func (m *Message) AddBinary(mimeType string, data []byte) {
329 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
330}