1package message
2
3import (
4 "encoding/base64"
5 "slices"
6 "time"
7
8 "github.com/charmbracelet/crush/internal/llm/models"
9)
10
// MessageRole identifies which participant authored a Message.
type MessageRole string

// The roles a Message can carry. Each is a typed MessageRole constant.
const (
	Assistant = MessageRole("assistant")
	User      = MessageRole("user")
	System    = MessageRole("system")
	Tool      = MessageRole("tool")
)
19
// FinishReason records why a Message stopped being produced.
type FinishReason string

// Known finish reasons, stored as typed FinishReason constants.
const (
	FinishReasonEndTurn          = FinishReason("end_turn")
	FinishReasonMaxTokens        = FinishReason("max_tokens")
	FinishReasonToolUse          = FinishReason("tool_use")
	FinishReasonCanceled         = FinishReason("canceled")
	FinishReasonError            = FinishReason("error")
	FinishReasonPermissionDenied = FinishReason("permission_denied")

	// FinishReasonUnknown is a fallback value that should never be produced in practice.
	FinishReasonUnknown = FinishReason("unknown")
)
33
// ContentPart is one piece of a Message's content. The unexported isPart
// method means only types declared in this package can implement it.
type ContentPart interface {
	isPart()
}
37
// ReasoningContent is a content part holding reasoning ("thinking") text.
type ReasoningContent struct {
	Thinking string `json:"thinking"`
}

// String returns the reasoning text unchanged.
func (rc ReasoningContent) String() string {
	text := rc.Thinking
	return text
}

func (ReasoningContent) isPart() {}
46
// TextContent is a content part holding plain message text.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text unchanged.
func (t TextContent) String() string {
	text := t.Text
	return text
}

func (TextContent) isPart() {}
56
// ImageURLContent is a content part that references an image by URL.
// Detail is optional; its JSON tag omits it when empty.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// String returns the image URL; Detail is not included.
func (img ImageURLContent) String() string {
	link := img.URL
	return link
}

func (ImageURLContent) isPart() {}
67
68type BinaryContent struct {
69 Path string
70 MIMEType string
71 Data []byte
72}
73
74func (bc BinaryContent) String(provider models.InferenceProvider) string {
75 base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
76 if provider == models.ProviderOpenAI {
77 return "data:" + bc.MIMEType + ";base64," + base64Encoded
78 }
79 return base64Encoded
80}
81
82func (BinaryContent) isPart() {}
83
// ToolCall is a content part describing a tool invocation. ID is used by this
// package to locate the part when it is updated. Input is a string that can be
// extended incrementally. Finished flips to true once the call is marked done.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"`
	Type     string `json:"type"`
	Finished bool   `json:"finished"`
}

func (ToolCall) isPart() {}
93
// ToolResult is a content part holding the output of a tool invocation,
// linked to its call through ToolCallID. IsError flags a failed invocation.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}
103
104type Finish struct {
105 Reason FinishReason `json:"reason"`
106 Time int64 `json:"time"`
107}
108
109func (Finish) isPart() {}
110
111type Message struct {
112 ID string
113 Role MessageRole
114 SessionID string
115 Parts []ContentPart
116 Model models.ModelID
117 CreatedAt int64
118 UpdatedAt int64
119}
120
121func (m *Message) Content() TextContent {
122 for _, part := range m.Parts {
123 if c, ok := part.(TextContent); ok {
124 return c
125 }
126 }
127 return TextContent{}
128}
129
130func (m *Message) ReasoningContent() ReasoningContent {
131 for _, part := range m.Parts {
132 if c, ok := part.(ReasoningContent); ok {
133 return c
134 }
135 }
136 return ReasoningContent{}
137}
138
139func (m *Message) ImageURLContent() []ImageURLContent {
140 imageURLContents := make([]ImageURLContent, 0)
141 for _, part := range m.Parts {
142 if c, ok := part.(ImageURLContent); ok {
143 imageURLContents = append(imageURLContents, c)
144 }
145 }
146 return imageURLContents
147}
148
149func (m *Message) BinaryContent() []BinaryContent {
150 binaryContents := make([]BinaryContent, 0)
151 for _, part := range m.Parts {
152 if c, ok := part.(BinaryContent); ok {
153 binaryContents = append(binaryContents, c)
154 }
155 }
156 return binaryContents
157}
158
159func (m *Message) ToolCalls() []ToolCall {
160 toolCalls := make([]ToolCall, 0)
161 for _, part := range m.Parts {
162 if c, ok := part.(ToolCall); ok {
163 toolCalls = append(toolCalls, c)
164 }
165 }
166 return toolCalls
167}
168
169func (m *Message) ToolResults() []ToolResult {
170 toolResults := make([]ToolResult, 0)
171 for _, part := range m.Parts {
172 if c, ok := part.(ToolResult); ok {
173 toolResults = append(toolResults, c)
174 }
175 }
176 return toolResults
177}
178
179func (m *Message) IsFinished() bool {
180 for _, part := range m.Parts {
181 if _, ok := part.(Finish); ok {
182 return true
183 }
184 }
185 return false
186}
187
188func (m *Message) FinishPart() *Finish {
189 for _, part := range m.Parts {
190 if c, ok := part.(Finish); ok {
191 return &c
192 }
193 }
194 return nil
195}
196
197func (m *Message) FinishReason() FinishReason {
198 for _, part := range m.Parts {
199 if c, ok := part.(Finish); ok {
200 return c.Reason
201 }
202 }
203 return ""
204}
205
206func (m *Message) IsThinking() bool {
207 if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
208 return true
209 }
210 return false
211}
212
213func (m *Message) AppendContent(delta string) {
214 found := false
215 for i, part := range m.Parts {
216 if c, ok := part.(TextContent); ok {
217 m.Parts[i] = TextContent{Text: c.Text + delta}
218 found = true
219 }
220 }
221 if !found {
222 m.Parts = append(m.Parts, TextContent{Text: delta})
223 }
224}
225
226func (m *Message) AppendReasoningContent(delta string) {
227 found := false
228 for i, part := range m.Parts {
229 if c, ok := part.(ReasoningContent); ok {
230 m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta}
231 found = true
232 }
233 }
234 if !found {
235 m.Parts = append(m.Parts, ReasoningContent{Thinking: delta})
236 }
237}
238
239func (m *Message) FinishToolCall(toolCallID string) {
240 for i, part := range m.Parts {
241 if c, ok := part.(ToolCall); ok {
242 if c.ID == toolCallID {
243 m.Parts[i] = ToolCall{
244 ID: c.ID,
245 Name: c.Name,
246 Input: c.Input,
247 Type: c.Type,
248 Finished: true,
249 }
250 return
251 }
252 }
253 }
254}
255
256func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
257 for i, part := range m.Parts {
258 if c, ok := part.(ToolCall); ok {
259 if c.ID == toolCallID {
260 m.Parts[i] = ToolCall{
261 ID: c.ID,
262 Name: c.Name,
263 Input: c.Input + inputDelta,
264 Type: c.Type,
265 Finished: c.Finished,
266 }
267 return
268 }
269 }
270 }
271}
272
273func (m *Message) AddToolCall(tc ToolCall) {
274 for i, part := range m.Parts {
275 if c, ok := part.(ToolCall); ok {
276 if c.ID == tc.ID {
277 m.Parts[i] = tc
278 return
279 }
280 }
281 }
282 m.Parts = append(m.Parts, tc)
283}
284
285func (m *Message) SetToolCalls(tc []ToolCall) {
286 // remove any existing tool call part it could have multiple
287 parts := make([]ContentPart, 0)
288 for _, part := range m.Parts {
289 if _, ok := part.(ToolCall); ok {
290 continue
291 }
292 parts = append(parts, part)
293 }
294 m.Parts = parts
295 for _, toolCall := range tc {
296 m.Parts = append(m.Parts, toolCall)
297 }
298}
299
300func (m *Message) AddToolResult(tr ToolResult) {
301 m.Parts = append(m.Parts, tr)
302}
303
304func (m *Message) SetToolResults(tr []ToolResult) {
305 for _, toolResult := range tr {
306 m.Parts = append(m.Parts, toolResult)
307 }
308}
309
310func (m *Message) AddFinish(reason FinishReason) {
311 // remove any existing finish part
312 for i, part := range m.Parts {
313 if _, ok := part.(Finish); ok {
314 m.Parts = slices.Delete(m.Parts, i, i+1)
315 break
316 }
317 }
318 m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()})
319}
320
321func (m *Message) AddImageURL(url, detail string) {
322 m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
323}
324
325func (m *Message) AddBinary(mimeType string, data []byte) {
326 m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
327}