content.go

  1package message
  2
  3import (
  4	"encoding/base64"
  5	"slices"
  6	"time"
  7
  8	"github.com/charmbracelet/crush/internal/llm/models"
  9)
 10
// MessageRole identifies who authored a Message.
type MessageRole string

// The roles a Message's Role field may take.
const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)
 19
// FinishReason explains why generation of a message stopped. It is carried by
// a Finish part (see Finish.Reason).
type FinishReason string

// Known finish reasons.
const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// FinishReasonUnknown is a fallback value; it is not expected to occur in
	// practice.
	FinishReasonUnknown FinishReason = "unknown"
)
 33
// ContentPart is one piece of a Message's content. The unexported isPart
// marker method means only types declared in this package (or types that
// embed them) can satisfy the interface.
type ContentPart interface {
	isPart()
}
 37
 38type ReasoningContent struct {
 39	Thinking string `json:"thinking"`
 40}
 41
 42func (tc ReasoningContent) String() string {
 43	return tc.Thinking
 44}
 45func (ReasoningContent) isPart() {}
 46
 47type TextContent struct {
 48	Text string `json:"text"`
 49}
 50
 51func (tc TextContent) String() string {
 52	return tc.Text
 53}
 54
 55func (TextContent) isPart() {}
 56
 57type ImageURLContent struct {
 58	URL    string `json:"url"`
 59	Detail string `json:"detail,omitempty"`
 60}
 61
 62func (iuc ImageURLContent) String() string {
 63	return iuc.URL
 64}
 65
 66func (ImageURLContent) isPart() {}
 67
 68type BinaryContent struct {
 69	Path     string
 70	MIMEType string
 71	Data     []byte
 72}
 73
 74func (bc BinaryContent) String(provider models.InferenceProvider) string {
 75	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 76	if provider == models.ProviderOpenAI {
 77		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 78	}
 79	return base64Encoded
 80}
 81
 82func (BinaryContent) isPart() {}
 83
// ToolCall is a ContentPart describing a tool invocation attached to a
// message.
type ToolCall struct {
	// ID identifies the call; the Message helpers look calls up by this value.
	ID string `json:"id"`
	// Name of the tool being invoked.
	Name string `json:"name"`
	// Input is the raw argument text; AppendToolCallInput grows it
	// incrementally.
	Input string `json:"input"`
	// Type of the call. NOTE(review): allowed values are not established in
	// this file — confirm against the providers that set it.
	Type string `json:"type"`
	// Finished is set to true by FinishToolCall.
	Finished bool `json:"finished"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
 93
// ToolResult is a ContentPart carrying the output of a tool invocation.
type ToolResult struct {
	// ToolCallID presumably matches the ToolCall.ID this result answers —
	// TODO confirm against callers.
	ToolCallID string `json:"tool_call_id"`
	// Name of the tool that produced the result.
	Name string `json:"name"`
	// Content is the tool's output text.
	Content string `json:"content"`
	// Metadata is an opaque string; its format is not defined in this file.
	Metadata string `json:"metadata"`
	// IsError reports whether the tool invocation failed.
	IsError bool `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
103
// Finish is a ContentPart marking a message as complete.
type Finish struct {
	// Reason explains why generation stopped.
	Reason FinishReason `json:"reason"`
	// Time is when the message finished; AddFinish stores time.Now().Unix()
	// (seconds since the Unix epoch).
	Time int64 `json:"time"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
110
// Message is a single entry in a session, made up of an ordered list of
// content parts (text, reasoning, images, binary data, tool calls/results and
// a finish marker).
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	// Parts holds the message content in order; use the accessor and mutator
	// methods on Message rather than editing it directly where possible.
	Parts []ContentPart
	Model models.ModelID
	// CreatedAt and UpdatedAt are timestamps. NOTE(review): the unit is not
	// established here — presumably Unix seconds, confirm with the store.
	CreatedAt int64
	UpdatedAt int64
}
120
121func (m *Message) Content() TextContent {
122	for _, part := range m.Parts {
123		if c, ok := part.(TextContent); ok {
124			return c
125		}
126	}
127	return TextContent{}
128}
129
130func (m *Message) ReasoningContent() ReasoningContent {
131	for _, part := range m.Parts {
132		if c, ok := part.(ReasoningContent); ok {
133			return c
134		}
135	}
136	return ReasoningContent{}
137}
138
139func (m *Message) ImageURLContent() []ImageURLContent {
140	imageURLContents := make([]ImageURLContent, 0)
141	for _, part := range m.Parts {
142		if c, ok := part.(ImageURLContent); ok {
143			imageURLContents = append(imageURLContents, c)
144		}
145	}
146	return imageURLContents
147}
148
149func (m *Message) BinaryContent() []BinaryContent {
150	binaryContents := make([]BinaryContent, 0)
151	for _, part := range m.Parts {
152		if c, ok := part.(BinaryContent); ok {
153			binaryContents = append(binaryContents, c)
154		}
155	}
156	return binaryContents
157}
158
159func (m *Message) ToolCalls() []ToolCall {
160	toolCalls := make([]ToolCall, 0)
161	for _, part := range m.Parts {
162		if c, ok := part.(ToolCall); ok {
163			toolCalls = append(toolCalls, c)
164		}
165	}
166	return toolCalls
167}
168
169func (m *Message) ToolResults() []ToolResult {
170	toolResults := make([]ToolResult, 0)
171	for _, part := range m.Parts {
172		if c, ok := part.(ToolResult); ok {
173			toolResults = append(toolResults, c)
174		}
175	}
176	return toolResults
177}
178
179func (m *Message) IsFinished() bool {
180	for _, part := range m.Parts {
181		if _, ok := part.(Finish); ok {
182			return true
183		}
184	}
185	return false
186}
187
188func (m *Message) FinishPart() *Finish {
189	for _, part := range m.Parts {
190		if c, ok := part.(Finish); ok {
191			return &c
192		}
193	}
194	return nil
195}
196
197func (m *Message) FinishReason() FinishReason {
198	for _, part := range m.Parts {
199		if c, ok := part.(Finish); ok {
200			return c.Reason
201		}
202	}
203	return ""
204}
205
206func (m *Message) IsThinking() bool {
207	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
208		return true
209	}
210	return false
211}
212
213func (m *Message) AppendContent(delta string) {
214	found := false
215	for i, part := range m.Parts {
216		if c, ok := part.(TextContent); ok {
217			m.Parts[i] = TextContent{Text: c.Text + delta}
218			found = true
219		}
220	}
221	if !found {
222		m.Parts = append(m.Parts, TextContent{Text: delta})
223	}
224}
225
226func (m *Message) AppendReasoningContent(delta string) {
227	found := false
228	for i, part := range m.Parts {
229		if c, ok := part.(ReasoningContent); ok {
230			m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta}
231			found = true
232		}
233	}
234	if !found {
235		m.Parts = append(m.Parts, ReasoningContent{Thinking: delta})
236	}
237}
238
239func (m *Message) FinishToolCall(toolCallID string) {
240	for i, part := range m.Parts {
241		if c, ok := part.(ToolCall); ok {
242			if c.ID == toolCallID {
243				m.Parts[i] = ToolCall{
244					ID:       c.ID,
245					Name:     c.Name,
246					Input:    c.Input,
247					Type:     c.Type,
248					Finished: true,
249				}
250				return
251			}
252		}
253	}
254}
255
256func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
257	for i, part := range m.Parts {
258		if c, ok := part.(ToolCall); ok {
259			if c.ID == toolCallID {
260				m.Parts[i] = ToolCall{
261					ID:       c.ID,
262					Name:     c.Name,
263					Input:    c.Input + inputDelta,
264					Type:     c.Type,
265					Finished: c.Finished,
266				}
267				return
268			}
269		}
270	}
271}
272
273func (m *Message) AddToolCall(tc ToolCall) {
274	for i, part := range m.Parts {
275		if c, ok := part.(ToolCall); ok {
276			if c.ID == tc.ID {
277				m.Parts[i] = tc
278				return
279			}
280		}
281	}
282	m.Parts = append(m.Parts, tc)
283}
284
285func (m *Message) SetToolCalls(tc []ToolCall) {
286	// remove any existing tool call part it could have multiple
287	parts := make([]ContentPart, 0)
288	for _, part := range m.Parts {
289		if _, ok := part.(ToolCall); ok {
290			continue
291		}
292		parts = append(parts, part)
293	}
294	m.Parts = parts
295	for _, toolCall := range tc {
296		m.Parts = append(m.Parts, toolCall)
297	}
298}
299
300func (m *Message) AddToolResult(tr ToolResult) {
301	m.Parts = append(m.Parts, tr)
302}
303
304func (m *Message) SetToolResults(tr []ToolResult) {
305	for _, toolResult := range tr {
306		m.Parts = append(m.Parts, toolResult)
307	}
308}
309
310func (m *Message) AddFinish(reason FinishReason) {
311	// remove any existing finish part
312	for i, part := range m.Parts {
313		if _, ok := part.(Finish); ok {
314			m.Parts = slices.Delete(m.Parts, i, i+1)
315			break
316		}
317	}
318	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()})
319}
320
321func (m *Message) AddImageURL(url, detail string) {
322	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
323}
324
325func (m *Message) AddBinary(mimeType string, data []byte) {
326	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
327}