diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 30539e5dba5a5b7aa66c2d2ffc35336dbe3d9774..65655a21385baa5d98926ed62ba89ba0aac2c539 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -200,7 +200,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy var currentAssistant *message.Message var shouldSummarize bool result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{ - Prompt: call.Prompt, + Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments), Files: files, Messages: history, ProviderOptions: call.ProviderOptions, @@ -649,11 +649,11 @@ func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions { } func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) { + parts := []message.ContentPart{message.TextContent{Text: call.Prompt}} var attachmentParts []message.ContentPart for _, attachment := range call.Attachments { attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content}) } - parts := []message.ContentPart{message.TextContent{Text: call.Prompt}} parts = append(parts, attachmentParts...) msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{ Role: message.User, @@ -690,6 +690,9 @@ If not, please feel free to ignore. Again do not mention this message to the use var files []fantasy.FilePart for _, attachment := range attachments { + if attachment.IsText() { + continue + } files = append(files, fantasy.FilePart{ Filename: attachment.FileName, Data: attachment.Content, diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 7ce37758897da42cd8515c94c26b9e24add6ea1e..ef2bdfc9cd7671b43ba22ec8a02b77b7510e5518 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -123,7 +123,14 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, } if !model.CatwalkCfg.SupportsImages && attachments != nil { - attachments = nil + // filter out image attachments + filteredAttachments := make([]message.Attachment, 0, len(attachments)) + for _, att := range attachments { + if att.IsText() { + filteredAttachments = append(filteredAttachments, att) + } + } + attachments = filteredAttachments } providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider) diff --git a/internal/message/attachment.go b/internal/message/attachment.go index 6e89f001436ed120d52c08c05ade8c8a741cfb7a..0e3b70a8766c74d37399c1ba8c38fe19e74f871d 100644 --- a/internal/message/attachment.go +++ b/internal/message/attachment.go @@ -1,8 +1,13 @@ package message +import "strings" + type Attachment struct { FilePath string FileName string MimeType string Content []byte } + +func (a Attachment) IsText() bool { return strings.HasPrefix(a.MimeType, "text/") } +func (a Attachment) IsImage() bool { return strings.HasPrefix(a.MimeType, "image/") } diff --git a/internal/message/content.go b/internal/message/content.go index 7333f738c0aa685833c57cc97086e61928d3f51e..6c03d42aed05a7772f37d15dc782bf96c8b69685 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -3,6 +3,7 @@ package message import ( "encoding/base64" "errors" + "fmt" "slices" "strings" "time" @@ -435,16 +436,52 @@ func (m *Message) AddBinary(mimeType string, data []byte) { m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data}) } +func PromptWithTextAttachments(prompt string, attachments []Attachment) string { + addedAttachments := false + for _, content := range attachments { + if !content.IsText() { + continue + } + if !addedAttachments { + prompt += "\nThe files below have been attached by the user, consider them in your response\n" + addedAttachments = true + } + tag := `\n` + if content.FilePath != "" { + tag = fmt.Sprintf("\n", content.FilePath) + } + prompt += tag + prompt += "\n" + string(content.Content) + "\n\n" + } + return prompt +} + func (m *Message) ToAIMessage() []fantasy.Message { var messages []fantasy.Message switch m.Role { case User: var parts []fantasy.MessagePart text := strings.TrimSpace(m.Content().Text) + var textAttachments []Attachment + for _, content := range m.BinaryContent() { + if !strings.HasPrefix(content.MIMEType, "text/") { + continue + } + textAttachments = append(textAttachments, Attachment{ + FilePath: content.Path, + MimeType: content.MIMEType, + Content: content.Data, + }) + } + text = PromptWithTextAttachments(text, textAttachments) if text != "" { parts = append(parts, fantasy.TextPart{Text: text}) } for _, content := range m.BinaryContent() { + // skip text attachements + if strings.HasPrefix(content.MIMEType, "text/") { + continue + } parts = append(parts, fantasy.FilePart{ Filename: content.Path, Data: content.Data, diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index bd90d90a7ea26294ddd7e4149c14f6e7f32e1cb5..014d662ce59d1de84f16cd17057aa158c80384a7 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -2,6 +2,7 @@ package editor import ( "context" + "errors" "fmt" "math/rand" "net/http" @@ -29,6 +30,7 @@ import ( "github.com/charmbracelet/crush/internal/tui/components/dialogs/quit" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" + "github.com/charmbracelet/x/ansi" ) type Editor interface { @@ -84,10 +86,7 @@ var DeleteKeyMaps = DeleteAttachmentKeyMaps{ ), } -const ( - maxAttachments = 5 - maxFileResults = 25 -) +const maxFileResults = 25 type OpenEditorMsg struct { Text string @@ -145,14 +144,14 @@ func (m *editorCmp) send() tea.Cmd { return util.CmdHandler(dialogs.OpenDialogMsg{Model: quit.NewQuitDialog()}) } - m.textarea.Reset() attachments := m.attachments - m.attachments = nil if value == "" { return nil } + m.textarea.Reset() + m.attachments = nil // Change the placeholder when sending a new message. m.randomizePlaceholders() @@ -176,9 +175,6 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { case tea.WindowSizeMsg: return m, m.repositionCompletions case filepicker.FilePickedMsg: - if len(m.attachments) >= maxAttachments { - return m, util.ReportError(fmt.Errorf("cannot add more than %d images", maxAttachments)) - } m.attachments = append(m.attachments, msg.Attachment) return m, nil case completions.CompletionsOpenedMsg: @@ -206,6 +202,17 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.currentQuery = "" m.completionsStartIndex = 0 } + content, err := os.ReadFile(item.Path) + if err != nil { + // if it fails, let the LLM handle it later. + return m, nil + } + m.attachments = append(m.attachments, message.Attachment{ + FilePath: item.Path, + FileName: filepath.Base(item.Path), + MimeType: mimeOf(content), + Content: content, + }) } case commands.OpenExternalEditorMsg: @@ -217,39 +224,30 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.textarea.SetValue(msg.Text) m.textarea.MoveToEnd() case tea.PasteMsg: - path := strings.ReplaceAll(msg.Content, "\\ ", " ") - // try to get an image - path, err := filepath.Abs(strings.TrimSpace(path)) - if err != nil { + content, path, err := pasteToFile(msg) + if errors.Is(err, errNotAFile) { m.textarea, cmd = m.textarea.Update(msg) return m, cmd } - isAllowedType := false - for _, ext := range filepicker.AllowedTypes { - if strings.HasSuffix(path, ext) { - isAllowedType = true - break - } - } - if !isAllowedType { - m.textarea, cmd = m.textarea.Update(msg) - return m, cmd + if err != nil { + return m, util.ReportError(err) } - tooBig, _ := filepicker.IsFileTooBig(path, filepicker.MaxAttachmentSize) - if tooBig { - m.textarea, cmd = m.textarea.Update(msg) - return m, cmd + + if len(content) > maxAttachmentSize { + return m, util.ReportWarn("File is too big (>5mb)") } - content, err := os.ReadFile(path) - if err != nil { - m.textarea, cmd = m.textarea.Update(msg) - return m, cmd + mimeType := mimeOf(content) + attachment := message.Attachment{ + FilePath: path, + FileName: filepath.Base(path), + MimeType: mimeType, + Content: content, } - mimeBufferSize := min(512, len(content)) - mimeType := http.DetectContentType(content[:mimeBufferSize]) - fileName := filepath.Base(path) - attachment := message.Attachment{FilePath: path, FileName: fileName, MimeType: mimeType, Content: content} + if !attachment.IsText() && !attachment.IsImage() { + return m, util.ReportWarn("Invalid file content type: " + mimeType) + } + m.textarea.InsertString(attachment.FileName) return m, util.CmdHandler(filepicker.FilePickedMsg{ Attachment: attachment, }) @@ -427,18 +425,17 @@ func (m *editorCmp) View() string { m.textarea.Placeholder = "Yolo mode!" } if len(m.attachments) == 0 { - content := t.S().Base.Padding(1).Render( + return t.S().Base.Padding(1).Render( m.textarea.View(), ) - return content } - content := t.S().Base.Padding(0, 1, 1, 1).Render( - lipgloss.JoinVertical(lipgloss.Top, + return t.S().Base.Padding(0, 1, 1, 1).Render( + lipgloss.JoinVertical( + lipgloss.Top, m.attachmentsContent(), m.textarea.View(), ), ) - return content } func (m *editorCmp) SetSize(width, height int) tea.Cmd { @@ -456,24 +453,45 @@ func (m *editorCmp) GetSize() (int, int) { func (m *editorCmp) attachmentsContent() string { var styledAttachments []string t := styles.CurrentTheme() - attachmentStyles := t.S().Base. - MarginLeft(1). + attachmentStyle := t.S().Base. + Padding(0, 1). + MarginRight(1). Background(t.FgMuted). - Foreground(t.FgBase) + Foreground(t.FgBase). + Render + iconStyle := t.S().Base. + Foreground(t.BgSubtle). + Background(t.Green). + Padding(0, 1). + Bold(true). + Render + rmStyle := t.S().Base. + Padding(0, 1). + Bold(true). + Background(t.Red). + Foreground(t.FgBase). + Render for i, attachment := range m.attachments { - var filename string - if len(attachment.FileName) > 10 { - filename = fmt.Sprintf(" %s %s...", styles.DocumentIcon, attachment.FileName[0:7]) - } else { - filename = fmt.Sprintf(" %s %s", styles.DocumentIcon, attachment.FileName) + filename := ansi.Truncate(filepath.Base(attachment.FileName), 10, "...") + icon := styles.ImageIcon + if attachment.IsText() { + icon = styles.TextIcon } if m.deleteMode { - filename = fmt.Sprintf("%d%s", i, filename) + styledAttachments = append( + styledAttachments, + rmStyle(fmt.Sprintf("%d", i)), + attachmentStyle(filename), + ) + continue } - styledAttachments = append(styledAttachments, attachmentStyles.Render(filename)) + styledAttachments = append( + styledAttachments, + iconStyle(icon), + attachmentStyle(filename), + ) } - content := lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...) - return content + return lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...) } func (m *editorCmp) SetPosition(x, y int) tea.Cmd { @@ -597,3 +615,51 @@ func New(app *app.App) Editor { return e } + +var maxAttachmentSize = 5 * 1024 * 1024 // 5MB + +var errNotAFile = errors.New("not a file") + +func pasteToFile(msg tea.PasteMsg) ([]byte, string, error) { + content, path, err := filepathToFile(msg.Content) + if err == nil { + return content, path, err + } + + if strings.Count(msg.Content, "\n") > 2 { + return contentToFile([]byte(msg.Content)) + } + + return nil, "", errNotAFile +} + +func contentToFile(content []byte) ([]byte, string, error) { + f, err := os.CreateTemp("", "paste_*.txt") + if err != nil { + return nil, "", err + } + if _, err := f.Write(content); err != nil { + return nil, "", err + } + if err := f.Close(); err != nil { + return nil, "", err + } + return content, f.Name(), nil +} + +func filepathToFile(name string) ([]byte, string, error) { + path, err := filepath.Abs(strings.TrimSpace(strings.ReplaceAll(name, "\\", ""))) + if err != nil { + return nil, "", err + } + content, err := os.ReadFile(path) + if err != nil { + return nil, "", err + } + return content, path, nil +} + +func mimeOf(content []byte) string { + mimeBufferSize := min(512, len(content)) + return http.DetectContentType(content[:mimeBufferSize]) +} diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index 38012c235df4d455c1b826f6a5ff491783ea1f5e..1359823edb7a783cd23b600e1ddae3870f2a2107 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -227,19 +227,32 @@ func (m *messageCmp) renderUserMessage() string { m.toMarkdown(m.message.Content().String()), } - attachmentStyles := t.S().Text. - MarginLeft(1). - Background(t.BgSubtle) + attachmentStyle := t.S().Base. + Padding(0, 1). + MarginRight(1). + Background(t.FgMuted). + Foreground(t.FgBase). + Render + iconStyle := t.S().Base. + Foreground(t.BgSubtle). + Background(t.Green). + Padding(0, 1). + Bold(true). + Render attachments := make([]string, len(m.message.BinaryContent())) for i, attachment := range m.message.BinaryContent() { const maxFilenameWidth = 10 - filename := filepath.Base(attachment.Path) - attachments[i] = attachmentStyles.Render(fmt.Sprintf( - " %s %s ", - styles.DocumentIcon, - ansi.Truncate(filename, maxFilenameWidth, "..."), - )) + filename := ansi.Truncate(filepath.Base(attachment.Path), 10, "...") + icon := styles.ImageIcon + if strings.HasPrefix(attachment.MIMEType, "text/") { + icon = styles.TextIcon + } + attachments[i] = lipgloss.JoinHorizontal( + lipgloss.Left, + iconStyle(icon), + attachmentStyle(filename), + ) } if len(attachments) > 0 { diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 33a71f9e3b9d184c3fb18423a91dd0baf8c2a0b9..8c54b028f90326ac8cee1cacb0df2377528e4a2b 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -613,8 +613,11 @@ func (p *chatPage) View() string { pillsArea = pillsRow } - style := t.S().Base.MarginTop(1).PaddingLeft(3) - pillsArea = style.Render(pillsArea) + pillsArea = t.S().Base. + MaxWidth(p.width). + MarginTop(1). + PaddingLeft(3). + Render(pillsArea) } if p.compact { diff --git a/internal/tui/styles/icons.go b/internal/tui/styles/icons.go index dfb3cf0c27ccf4a90d84f256d40e1a9a87fc5aa3..0db13358a2f9812293c18497b71ba138484b8f17 100644 --- a/internal/tui/styles/icons.go +++ b/internal/tui/styles/icons.go @@ -10,7 +10,8 @@ const ( ArrowRightIcon string = "→" CenterSpinnerIcon string = "⋯" LoadingIcon string = "⟳" - DocumentIcon string = "🖼" + ImageIcon string = "■" + TextIcon string = "☰" ModelIcon string = "◇" // Tool call icons