1package message
  2
  3import (
  4	"encoding/base64"
  5	"slices"
  6	"time"
  7
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9	"github.com/charmbracelet/crush/internal/ai"
 10)
 11
 12type MessageRole string
 13
 14const (
 15	Assistant MessageRole = "assistant"
 16	User      MessageRole = "user"
 17	Tool      MessageRole = "tool"
 18)
 19
 20type FinishReason string
 21
 22const (
 23	FinishReasonEndTurn          FinishReason = "end_turn"
 24	FinishReasonMaxTokens        FinishReason = "max_tokens"
 25	FinishReasonToolUse          FinishReason = "tool_use"
 26	FinishReasonCanceled         FinishReason = "canceled"
 27	FinishReasonError            FinishReason = "error"
 28	FinishReasonPermissionDenied FinishReason = "permission_denied"
 29
 30	// Should never happen
 31	FinishReasonUnknown FinishReason = "unknown"
 32)
 33
 34type ContentPart interface {
 35	isPart()
 36}
 37
 38type ReasoningContent struct {
 39	Thinking   string `json:"thinking"`
 40	Signature  string `json:"signature"`
 41	StartedAt  int64  `json:"started_at,omitempty"`
 42	FinishedAt int64  `json:"finished_at,omitempty"`
 43}
 44
 45func (tc ReasoningContent) String() string {
 46	return tc.Thinking
 47}
 48func (ReasoningContent) isPart() {}
 49
 50type TextContent struct {
 51	Text string `json:"text"`
 52}
 53
 54func (tc TextContent) String() string {
 55	return tc.Text
 56}
 57
 58func (TextContent) isPart() {}
 59
 60type ImageURLContent struct {
 61	URL    string `json:"url"`
 62	Detail string `json:"detail,omitempty"`
 63}
 64
 65func (iuc ImageURLContent) String() string {
 66	return iuc.URL
 67}
 68
 69func (ImageURLContent) isPart() {}
 70
 71type BinaryContent struct {
 72	Path     string
 73	MIMEType string
 74	Data     []byte
 75}
 76
 77// TODO: remove this from the content
 78func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
 79	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
 80	if p == catwalk.InferenceProviderOpenAI {
 81		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 82	}
 83	return base64Encoded
 84}
 85
 86func (BinaryContent) isPart() {}
 87
// ToolCall is a tool invocation stored as a message part.
//
// Input is extended incrementally by AppendToolCallInput, and Finished is
// set by FinishToolCall. ToAIMessage forwards ID, Name, Input and
// ProviderExecuted to ai.ToolCallPart; Type and Finished are not forwarded.
type ToolCall struct {
	ID               string `json:"id"`
	Name             string `json:"name"`
	Input            string `json:"input"`
	ProviderExecuted bool   `json:"provider_executed"`
	Type             string `json:"type"`
	Finished         bool   `json:"finished"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
 98
// ToolResult is the outcome of a tool call, matched to it by ToolCallID.
//
// ToAIMessage converts it to one of three outputs, checked in this order:
// an error output from Content when IsError is set, a media output from
// Data and MIMEType when Data is non-empty, and otherwise a text output from
// Content. Name and Metadata are not forwarded by ToAIMessage. The encoding
// of Data is not visible here (presumably base64 — TODO confirm).
//
// TODO: make the tool result content be either string, or media content
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Data       string `json:"data"`
	MIMEType   string `json:"mime_type"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
111
// Finish is the part that marks a message as finished; IsFinished reports
// whether one is present. AddFinish replaces an existing Finish part and
// sets Time to the current Unix time in seconds. Message and Details are
// optional and omitted from the JSON encoding when empty.
type Finish struct {
	Reason  FinishReason `json:"reason"`
	Time    int64        `json:"time"`
	Message string       `json:"message,omitempty"`
	Details string       `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
120
// Message is a single message belonging to the session identified by
// SessionID. Its content is an ordered, heterogeneous list of ContentPart
// values; the accessor methods on *Message select parts by concrete type.
//
// NOTE(review): Model/Provider presumably record which model produced the
// message, and CreatedAt/UpdatedAt are presumably Unix timestamps — the
// units are not visible here; confirm before relying on them.
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	Parts     []ContentPart
	Model     string
	Provider  string
	CreatedAt int64
	UpdatedAt int64
}
131
132func (m *Message) Content() TextContent {
133	for _, part := range m.Parts {
134		if c, ok := part.(TextContent); ok {
135			return c
136		}
137	}
138	return TextContent{}
139}
140
141func (m *Message) ReasoningContent() ReasoningContent {
142	for _, part := range m.Parts {
143		if c, ok := part.(ReasoningContent); ok {
144			return c
145		}
146	}
147	return ReasoningContent{}
148}
149
150func (m *Message) ImageURLContent() []ImageURLContent {
151	imageURLContents := make([]ImageURLContent, 0)
152	for _, part := range m.Parts {
153		if c, ok := part.(ImageURLContent); ok {
154			imageURLContents = append(imageURLContents, c)
155		}
156	}
157	return imageURLContents
158}
159
160func (m *Message) BinaryContent() []BinaryContent {
161	binaryContents := make([]BinaryContent, 0)
162	for _, part := range m.Parts {
163		if c, ok := part.(BinaryContent); ok {
164			binaryContents = append(binaryContents, c)
165		}
166	}
167	return binaryContents
168}
169
170func (m *Message) ToolCalls() []ToolCall {
171	toolCalls := make([]ToolCall, 0)
172	for _, part := range m.Parts {
173		if c, ok := part.(ToolCall); ok {
174			toolCalls = append(toolCalls, c)
175		}
176	}
177	return toolCalls
178}
179
180func (m *Message) ToolResults() []ToolResult {
181	toolResults := make([]ToolResult, 0)
182	for _, part := range m.Parts {
183		if c, ok := part.(ToolResult); ok {
184			toolResults = append(toolResults, c)
185		}
186	}
187	return toolResults
188}
189
190func (m *Message) IsFinished() bool {
191	for _, part := range m.Parts {
192		if _, ok := part.(Finish); ok {
193			return true
194		}
195	}
196	return false
197}
198
199func (m *Message) FinishPart() *Finish {
200	for _, part := range m.Parts {
201		if c, ok := part.(Finish); ok {
202			return &c
203		}
204	}
205	return nil
206}
207
208func (m *Message) FinishReason() FinishReason {
209	for _, part := range m.Parts {
210		if c, ok := part.(Finish); ok {
211			return c.Reason
212		}
213	}
214	return ""
215}
216
217func (m *Message) IsThinking() bool {
218	if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
219		return true
220	}
221	return false
222}
223
224func (m *Message) AppendContent(delta string) {
225	found := false
226	for i, part := range m.Parts {
227		if c, ok := part.(TextContent); ok {
228			m.Parts[i] = TextContent{Text: c.Text + delta}
229			found = true
230		}
231	}
232	if !found {
233		m.Parts = append(m.Parts, TextContent{Text: delta})
234	}
235}
236
237func (m *Message) AppendReasoningContent(delta string) {
238	found := false
239	for i, part := range m.Parts {
240		if c, ok := part.(ReasoningContent); ok {
241			m.Parts[i] = ReasoningContent{
242				Thinking:   c.Thinking + delta,
243				Signature:  c.Signature,
244				StartedAt:  c.StartedAt,
245				FinishedAt: c.FinishedAt,
246			}
247			found = true
248		}
249	}
250	if !found {
251		m.Parts = append(m.Parts, ReasoningContent{
252			Thinking:  delta,
253			StartedAt: time.Now().Unix(),
254		})
255	}
256}
257
258func (m *Message) AppendReasoningSignature(signature string) {
259	for i, part := range m.Parts {
260		if c, ok := part.(ReasoningContent); ok {
261			m.Parts[i] = ReasoningContent{
262				Thinking:   c.Thinking,
263				Signature:  c.Signature + signature,
264				StartedAt:  c.StartedAt,
265				FinishedAt: c.FinishedAt,
266			}
267			return
268		}
269	}
270	m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
271}
272
273func (m *Message) FinishThinking() {
274	for i, part := range m.Parts {
275		if c, ok := part.(ReasoningContent); ok {
276			if c.FinishedAt == 0 {
277				m.Parts[i] = ReasoningContent{
278					Thinking:   c.Thinking,
279					Signature:  c.Signature,
280					StartedAt:  c.StartedAt,
281					FinishedAt: time.Now().Unix(),
282				}
283			}
284			return
285		}
286	}
287}
288
289func (m *Message) ThinkingDuration() time.Duration {
290	reasoning := m.ReasoningContent()
291	if reasoning.StartedAt == 0 {
292		return 0
293	}
294
295	endTime := reasoning.FinishedAt
296	if endTime == 0 {
297		endTime = time.Now().Unix()
298	}
299
300	return time.Duration(endTime-reasoning.StartedAt) * time.Second
301}
302
303func (m *Message) FinishToolCall(toolCallID string) {
304	for i, part := range m.Parts {
305		if c, ok := part.(ToolCall); ok {
306			if c.ID == toolCallID {
307				m.Parts[i] = ToolCall{
308					ID:       c.ID,
309					Name:     c.Name,
310					Input:    c.Input,
311					Type:     c.Type,
312					Finished: true,
313				}
314				return
315			}
316		}
317	}
318}
319
320func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
321	for i, part := range m.Parts {
322		if c, ok := part.(ToolCall); ok {
323			if c.ID == toolCallID {
324				m.Parts[i] = ToolCall{
325					ID:       c.ID,
326					Name:     c.Name,
327					Input:    c.Input + inputDelta,
328					Type:     c.Type,
329					Finished: c.Finished,
330				}
331				return
332			}
333		}
334	}
335}
336
337func (m *Message) AddToolCall(tc ToolCall) {
338	for i, part := range m.Parts {
339		if c, ok := part.(ToolCall); ok {
340			if c.ID == tc.ID {
341				m.Parts[i] = tc
342				return
343			}
344		}
345	}
346	m.Parts = append(m.Parts, tc)
347}
348
349func (m *Message) SetToolCalls(tc []ToolCall) {
350	// remove any existing tool call part it could have multiple
351	parts := make([]ContentPart, 0)
352	for _, part := range m.Parts {
353		if _, ok := part.(ToolCall); ok {
354			continue
355		}
356		parts = append(parts, part)
357	}
358	m.Parts = parts
359	for _, toolCall := range tc {
360		m.Parts = append(m.Parts, toolCall)
361	}
362}
363
364func (m *Message) AddToolResult(tr ToolResult) {
365	m.Parts = append(m.Parts, tr)
366}
367
368func (m *Message) SetToolResults(tr []ToolResult) {
369	for _, toolResult := range tr {
370		m.Parts = append(m.Parts, toolResult)
371	}
372}
373
374func (m *Message) AddFinish(reason FinishReason, message, details string) {
375	// remove any existing finish part
376	for i, part := range m.Parts {
377		if _, ok := part.(Finish); ok {
378			m.Parts = slices.Delete(m.Parts, i, i+1)
379			break
380		}
381	}
382	m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
383}
384
385func (m *Message) AddImageURL(url, detail string) {
386	m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
387}
388
389func (m *Message) AddBinary(mimeType string, data []byte) {
390	m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
391}
392
393func (m *Message) ToAIMessage() ai.Message {
394	var role ai.MessageRole
395	var parts []ai.MessagePart
396	switch m.Role {
397	case User:
398		role = ai.MessageRoleUser
399		if m.Content().Text != "" {
400			parts = append(parts, ai.TextPart{Text: m.Content().Text})
401		}
402		for _, content := range m.BinaryContent() {
403			parts = append(parts, ai.FilePart{
404				Filename:  content.Path,
405				Data:      content.Data,
406				MediaType: content.MIMEType,
407			})
408		}
409	case Assistant:
410		role = ai.MessageRoleAssistant
411		if m.Content().Text != "" {
412			parts = append(parts, ai.TextPart{Text: m.Content().Text})
413		}
414		for _, call := range m.ToolCalls() {
415			parts = append(parts, ai.ToolCallPart{
416				ToolCallID:       call.ID,
417				ToolName:         call.Name,
418				Input:            call.Input,
419				ProviderExecuted: call.ProviderExecuted,
420			})
421		}
422	case Tool:
423		role = ai.MessageRoleTool
424		for _, result := range m.ToolResults() {
425			var content ai.ToolResultOutputContent
426			if result.IsError {
427				content = ai.ToolResultOutputContentError{
428					Error: result.Content,
429				}
430			} else if result.Data != "" {
431				content = ai.ToolResultOutputContentMedia{
432					Data:      result.Data,
433					MediaType: result.MIMEType,
434				}
435			} else {
436				content = ai.ToolResultOutputContentText{
437					Text: result.Content,
438				}
439			}
440			parts = append(parts, ai.ToolResultPart{
441				ToolCallID: result.ToolCallID,
442				Output:     content,
443			})
444		}
445	}
446
447	msg := ai.Message{
448		Role:    role,
449		Content: parts,
450	}
451
452	return msg
453}