1package message
  2
  3import (
  4	"encoding/base64"
  5	"slices"
  6	"time"
  7
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9)
 10
 11type MessageRole string
 12
 13const (
 14	Assistant MessageRole = "assistant"
 15	User      MessageRole = "user"
 16	System    MessageRole = "system"
 17	Tool      MessageRole = "tool"
 18)
 19
 20type FinishReason string
 21
 22const (
 23	FinishReasonEndTurn          FinishReason = "end_turn"
 24	FinishReasonMaxTokens        FinishReason = "max_tokens"
 25	FinishReasonToolUse          FinishReason = "tool_use"
 26	FinishReasonCanceled         FinishReason = "canceled"
 27	FinishReasonError            FinishReason = "error"
 28	FinishReasonPermissionDenied FinishReason = "permission_denied"
 29
 30	// Should never happen
 31	FinishReasonUnknown FinishReason = "unknown"
 32)
 33
 34type ContentPart interface {
 35	isPart()
 36}
 37
 38type ReasoningContent struct {
 39	Thinking string `json:"thinking"`
 40}
 41
 42func (tc ReasoningContent) String() string {
 43	return tc.Thinking
 44}
 45func (ReasoningContent) isPart() {}
 46
 47type TextContent struct {
 48	Text string `json:"text"`
 49}
 50
 51func (tc TextContent) String() string {
 52	return tc.Text
 53}
 54
 55func (TextContent) isPart() {}
 56
 57type ImageURLContent struct {
 58	URL    string `json:"url"`
 59	Detail string `json:"detail,omitempty"`
 60}
 61
 62func (iuc ImageURLContent) String() string {
 63	return iuc.URL
 64}
 65
 66func (ImageURLContent) isPart() {}
 67
 68type BinaryContent struct {
 69	Path     string
 70	MIMEType string
 71	Data     []byte
 72}
 73
 74func (bc BinaryContent) String(p provider.InferenceProvider) string {
 75	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 76	if p == provider.InferenceProviderOpenAI {
 77		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 78	}
 79	return base64Encoded
 80}
 81
 82func (BinaryContent) isPart() {}
 83
 84type ToolCall struct {
 85	ID       string `json:"id"`
 86	Name     string `json:"name"`
 87	Input    string `json:"input"`
 88	Type     string `json:"type"`
 89	Finished bool   `json:"finished"`
 90}
 91
 92func (ToolCall) isPart() {}
 93
 94type ToolResult struct {
 95	ToolCallID string `json:"tool_call_id"`
 96	Name       string `json:"name"`
 97	Content    string `json:"content"`
 98	Metadata   string `json:"metadata"`
 99	IsError    bool   `json:"is_error"`
100}
101
102func (ToolResult) isPart() {}
103
104type Finish struct {
105	Reason  FinishReason `json:"reason"`
106	Time    int64        `json:"time"`
107	Message string       `json:"message,omitempty"`
108	Details string       `json:"details,omitempty"`
109}
110
111func (Finish) isPart() {}
112
113type Message struct {
114	ID        string
115	Role      MessageRole
116	SessionID string
117	Parts     []ContentPart
118	Model     string
119	Provider  string
120	CreatedAt int64
121	UpdatedAt int64
122}
123
124func (m *Message) Content() TextContent {
125	for _, part := range m.Parts {
126		if c, ok := part.(TextContent); ok {
127			return c
128		}
129	}
130	return TextContent{}
131}
132
133func (m *Message) ReasoningContent() ReasoningContent {
134	for _, part := range m.Parts {
135		if c, ok := part.(ReasoningContent); ok {
136			return c
137		}
138	}
139	return ReasoningContent{}
140}
141
142func (m *Message) ImageURLContent() []ImageURLContent {
143	imageURLContents := make([]ImageURLContent, 0)
144	for _, part := range m.Parts {
145		if c, ok := part.(ImageURLContent); ok {
146			imageURLContents = append(imageURLContents, c)
147		}
148	}
149	return imageURLContents
150}
151
152func (m *Message) BinaryContent() []BinaryContent {
153	binaryContents := make([]BinaryContent, 0)
154	for _, part := range m.Parts {
155		if c, ok := part.(BinaryContent); ok {
156			binaryContents = append(binaryContents, c)
157		}
158	}
159	return binaryContents
160}
161
162func (m *Message) ToolCalls() []ToolCall {
163	toolCalls := make([]ToolCall, 0)
164	for _, part := range m.Parts {
165		if c, ok := part.(ToolCall); ok {
166			toolCalls = append(toolCalls, c)
167		}
168	}
169	return toolCalls
170}
171
172func (m *Message) ToolResults() []ToolResult {
173	toolResults := make([]ToolResult, 0)
174	for _, part := range m.Parts {
175		if c, ok := part.(ToolResult); ok {
176			toolResults = append(toolResults, c)
177		}
178	}
179	return toolResults
180}
181
182func (m *Message) IsFinished() bool {
183	for _, part := range m.Parts {
184		if _, ok := part.(Finish); ok {
185			return true
186		}
187	}
188	return false
189}
190
191func (m *Message) FinishPart() *Finish {
192	for _, part := range m.Parts {
193		if c, ok := part.(Finish); ok {
194			return &c
195		}
196	}
197	return nil
198}
199
200func (m *Message) FinishReason() FinishReason {
201	for _, part := range m.Parts {
202		if c, ok := part.(Finish); ok {
203			return c.Reason
204		}
205	}
206	return ""
207}
208
209func (m *Message) IsThinking() bool {
210	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
211		return true
212	}
213	return false
214}
215
216func (m *Message) AppendContent(delta string) {
217	found := false
218	for i, part := range m.Parts {
219		if c, ok := part.(TextContent); ok {
220			m.Parts[i] = TextContent{Text: c.Text + delta}
221			found = true
222		}
223	}
224	if !found {
225		m.Parts = append(m.Parts, TextContent{Text: delta})
226	}
227}
228
229func (m *Message) AppendReasoningContent(delta string) {
230	found := false
231	for i, part := range m.Parts {
232		if c, ok := part.(ReasoningContent); ok {
233			m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta}
234			found = true
235		}
236	}
237	if !found {
238		m.Parts = append(m.Parts, ReasoningContent{Thinking: delta})
239	}
240}
241
242func (m *Message) FinishToolCall(toolCallID string) {
243	for i, part := range m.Parts {
244		if c, ok := part.(ToolCall); ok {
245			if c.ID == toolCallID {
246				m.Parts[i] = ToolCall{
247					ID:       c.ID,
248					Name:     c.Name,
249					Input:    c.Input,
250					Type:     c.Type,
251					Finished: true,
252				}
253				return
254			}
255		}
256	}
257}
258
259func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
260	for i, part := range m.Parts {
261		if c, ok := part.(ToolCall); ok {
262			if c.ID == toolCallID {
263				m.Parts[i] = ToolCall{
264					ID:       c.ID,
265					Name:     c.Name,
266					Input:    c.Input + inputDelta,
267					Type:     c.Type,
268					Finished: c.Finished,
269				}
270				return
271			}
272		}
273	}
274}
275
276func (m *Message) AddToolCall(tc ToolCall) {
277	for i, part := range m.Parts {
278		if c, ok := part.(ToolCall); ok {
279			if c.ID == tc.ID {
280				m.Parts[i] = tc
281				return
282			}
283		}
284	}
285	m.Parts = append(m.Parts, tc)
286}
287
288func (m *Message) SetToolCalls(tc []ToolCall) {
289	// remove any existing tool call part it could have multiple
290	parts := make([]ContentPart, 0)
291	for _, part := range m.Parts {
292		if _, ok := part.(ToolCall); ok {
293			continue
294		}
295		parts = append(parts, part)
296	}
297	m.Parts = parts
298	for _, toolCall := range tc {
299		m.Parts = append(m.Parts, toolCall)
300	}
301}
302
303func (m *Message) AddToolResult(tr ToolResult) {
304	m.Parts = append(m.Parts, tr)
305}
306
307func (m *Message) SetToolResults(tr []ToolResult) {
308	for _, toolResult := range tr {
309		m.Parts = append(m.Parts, toolResult)
310	}
311}
312
313func (m *Message) AddFinish(reason FinishReason, message, details string) {
314	// remove any existing finish part
315	for i, part := range m.Parts {
316		if _, ok := part.(Finish); ok {
317			m.Parts = slices.Delete(m.Parts, i, i+1)
318			break
319		}
320	}
321	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
322}
323
324func (m *Message) AddImageURL(url, detail string) {
325	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
326}
327
328func (m *Message) AddBinary(mimeType string, data []byte) {
329	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
330}