diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 58db8f307f547e55c37aea425d365e7c4168f001..f19085ed47829a5e901325ac0741098ebcce48c1 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -6,7 +6,7 @@ use crate::thread::{ }; use crate::thread_store::ThreadStore; use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus}; -use crate::ui::{AgentNotification, AgentNotificationEvent, ContextPill}; +use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill}; use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting}; use collections::HashMap; use editor::{Editor, MultiBuffer}; @@ -487,14 +487,14 @@ impl ActiveThread { let updated_context_ids = refresh_task.await; this.update(cx, |this, cx| { - this.context_store.read_with(cx, |context_store, cx| { + this.context_store.read_with(cx, |context_store, _cx| { context_store .context() .iter() .filter(|context| { updated_context_ids.contains(&context.id()) }) - .flat_map(|context| context.snapshot(cx)) + .cloned() .collect() }) }) @@ -806,7 +806,7 @@ impl ActiveThread { let thread = self.thread.read(cx); // Get all the data we need from thread before we start using it in closures let checkpoint = thread.checkpoint_for_message(message_id); - let context = thread.context_for_message(message_id); + let context = thread.context_for_message(message_id).collect::>(); let tool_uses = thread.tool_uses_for_message(message_id, cx); // Don't render user messages that are just there for returning tool results. @@ -926,53 +926,50 @@ impl ActiveThread { .into_any_element(), }; - let message_content = - v_flex() - .gap_1p5() - .child( - if let Some(edit_message_editor) = edit_message_editor.clone() { - div() - .key_context("EditMessageEditor") - .on_action(cx.listener(Self::cancel_editing_message)) - .on_action(cx.listener(Self::confirm_editing_message)) - .min_h_6() - .child(edit_message_editor) - } else { - div() - .min_h_6() - .text_ui(cx) - .child(self.render_message_content(message_id, rendered_message, cx)) - }, - ) - .when_some(context, |parent, context| { - if !context.is_empty() { - parent.child(h_flex().flex_wrap().gap_1().children( - context.into_iter().map(|context| { - let context_id = context.id; - ContextPill::added(context, false, false, None).on_click(Rc::new( - cx.listener({ - let workspace = workspace.clone(); - let context_store = context_store.clone(); - move |_, _, window, cx| { - if let Some(workspace) = workspace.upgrade() { - open_context( - context_id, - context_store.clone(), - workspace, - window, - cx, - ); - cx.notify(); - } + let message_content = v_flex() + .gap_1p5() + .child( + if let Some(edit_message_editor) = edit_message_editor.clone() { + div() + .key_context("EditMessageEditor") + .on_action(cx.listener(Self::cancel_editing_message)) + .on_action(cx.listener(Self::confirm_editing_message)) + .min_h_6() + .child(edit_message_editor) + } else { + div() + .min_h_6() + .text_ui(cx) + .child(self.render_message_content(message_id, rendered_message, cx)) + }, + ) + .when(!context.is_empty(), |parent| { + parent.child( + h_flex() + .flex_wrap() + .gap_1() + .children(context.into_iter().map(|context| { + let context_id = context.id(); + ContextPill::added(AddedContext::new(context, cx), false, false, None) + .on_click(Rc::new(cx.listener({ + let workspace = workspace.clone(); + let context_store = context_store.clone(); + move |_, _, window, cx| { + if let Some(workspace) = workspace.upgrade() { + open_context( + context_id, + context_store.clone(), + workspace, + window, + cx, + ); + cx.notify(); } - }), - )) - }), - )) - } else { - parent - } - }); + } + }))) + })), + ) + }); let styled_message = match message.role { Role::User => v_flex() @@ -1974,7 +1971,7 @@ pub(crate) fn open_context( } } AssistantContext::Directory(directory_context) => { - let path = directory_context.path.clone(); + let path = directory_context.project_path.clone(); workspace.update(cx, |workspace, cx| { workspace.project().update(cx, |project, cx| { if let Some(entry) = project.entry_for_path(&path, cx) { diff --git a/crates/assistant2/src/buffer_codegen.rs b/crates/assistant2/src/buffer_codegen.rs index 5ba3e3d05e4d4fcd2f423444ce4f2e001d0038a5..e1ef0ecac8707334542c0b7297bb808f1a1e3bb3 100644 --- a/crates/assistant2/src/buffer_codegen.rs +++ b/crates/assistant2/src/buffer_codegen.rs @@ -414,7 +414,11 @@ impl CodegenAlternative { }; if let Some(context_store) = &self.context_store { - attach_context_to_message(&mut request_message, context_store.read(cx).snapshot(cx)); + attach_context_to_message( + &mut request_message, + context_store.read(cx).context().iter(), + cx, + ); } request_message.content.push(prompt.into()); diff --git a/crates/assistant2/src/context.rs b/crates/assistant2/src/context.rs index 64385352f82d6cb742a760b25aff6b12b91ece72..6d8583579d2d648751fa3ba0696019e90cd3d58f 100644 --- a/crates/assistant2/src/context.rs +++ b/crates/assistant2/src/context.rs @@ -1,8 +1,7 @@ -use std::ops::Range; +use std::{ops::Range, sync::Arc}; -use file_icons::FileIcons; use gpui::{App, Entity, SharedString}; -use language::Buffer; +use language::{Buffer, File}; use language_model::{LanguageModelRequestMessage, MessageContent}; use project::ProjectPath; use serde::{Deserialize, Serialize}; @@ -10,7 +9,7 @@ use text::{Anchor, BufferId}; use ui::IconName; use util::post_inc; -use crate::{context_store::buffer_path_log_err, thread::Thread}; +use crate::thread::Thread; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] pub struct ContextId(pub(crate) usize); @@ -21,19 +20,6 @@ impl ContextId { } } -/// Some context attached to a message in a thread. -#[derive(Debug, Clone)] -pub struct ContextSnapshot { - pub id: ContextId, - pub name: SharedString, - pub parent: Option, - pub tooltip: Option, - pub icon_path: Option, - pub kind: ContextKind, - /// Joining these strings separated by \n yields text for model. Not refreshed by `snapshot`. - pub text: Box<[SharedString]>, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ContextKind { File, @@ -55,7 +41,7 @@ impl ContextKind { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum AssistantContext { File(FileContext), Directory(DirectoryContext), @@ -68,7 +54,7 @@ impl AssistantContext { pub fn id(&self) -> ContextId { match self { Self::File(file) => file.id, - Self::Directory(directory) => directory.snapshot.id, + Self::Directory(directory) => directory.id, Self::Symbol(symbol) => symbol.id, Self::FetchedUrl(url) => url.id, Self::Thread(thread) => thread.id, @@ -76,26 +62,26 @@ impl AssistantContext { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FileContext { pub id: ContextId, pub context_buffer: ContextBuffer, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct DirectoryContext { - pub path: ProjectPath, + pub id: ContextId, + pub project_path: ProjectPath, pub context_buffers: Vec, - pub snapshot: ContextSnapshot, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SymbolContext { pub id: ContextId, pub context_symbol: ContextSymbol, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FetchedUrlContext { pub id: ContextId, pub url: SharedString, @@ -105,24 +91,45 @@ pub struct FetchedUrlContext { // TODO: Model holds onto the thread even if the thread is deleted. Can either handle this // explicitly or have a WeakModel and remove during snapshot. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ThreadContext { pub id: ContextId, pub thread: Entity, pub text: SharedString, } +impl ThreadContext { + pub fn summary(&self, cx: &App) -> SharedString { + self.thread + .read(cx) + .summary() + .unwrap_or("New thread".into()) + } +} + // TODO: Model holds onto the buffer even if the file is deleted and closed. Should remove // the context from the message editor in this case. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ContextBuffer { pub id: BufferId, pub buffer: Entity, + pub file: Arc, pub version: clock::Global, pub text: SharedString, } +impl std::fmt::Debug for ContextBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ContextBuffer") + .field("id", &self.id) + .field("buffer", &self.buffer) + .field("version", &self.version) + .field("text", &self.text) + .finish() + } +} + #[derive(Debug, Clone)] pub struct ContextSymbol { pub id: ContextSymbolId, @@ -141,145 +148,10 @@ pub struct ContextSymbolId { pub range: Range, } -impl AssistantContext { - pub fn snapshot(&self, cx: &App) -> Option { - match &self { - Self::File(file_context) => file_context.snapshot(cx), - Self::Directory(directory_context) => Some(directory_context.snapshot()), - Self::Symbol(symbol_context) => symbol_context.snapshot(cx), - Self::FetchedUrl(fetched_url_context) => Some(fetched_url_context.snapshot()), - Self::Thread(thread_context) => Some(thread_context.snapshot(cx)), - } - } -} - -impl FileContext { - pub fn snapshot(&self, cx: &App) -> Option { - let buffer = self.context_buffer.buffer.read(cx); - let path = buffer_path_log_err(buffer, cx)?; - let full_path: SharedString = path.to_string_lossy().into_owned().into(); - let name = match path.file_name() { - Some(name) => name.to_string_lossy().into_owned().into(), - None => full_path.clone(), - }; - let parent = path - .parent() - .and_then(|p| p.file_name()) - .map(|p| p.to_string_lossy().into_owned().into()); - - let icon_path = FileIcons::get_icon(&path, cx); - - Some(ContextSnapshot { - id: self.id, - name, - parent, - tooltip: Some(full_path), - icon_path, - kind: ContextKind::File, - text: Box::new([self.context_buffer.text.clone()]), - }) - } -} - -impl DirectoryContext { - pub fn new( - id: ContextId, - project_path: ProjectPath, - context_buffers: Vec, - ) -> DirectoryContext { - let full_path: SharedString = project_path.path.to_string_lossy().into_owned().into(); - - let name = match project_path.path.file_name() { - Some(name) => name.to_string_lossy().into_owned().into(), - None => full_path.clone(), - }; - - let parent = project_path - .path - .parent() - .and_then(|p| p.file_name()) - .map(|p| p.to_string_lossy().into_owned().into()); - - // TODO: include directory path in text? - let text = context_buffers - .iter() - .map(|b| b.text.clone()) - .collect::>() - .into(); - - DirectoryContext { - path: project_path, - context_buffers, - snapshot: ContextSnapshot { - id, - name, - parent, - tooltip: Some(full_path), - icon_path: None, - kind: ContextKind::Directory, - text, - }, - } - } - - pub fn snapshot(&self) -> ContextSnapshot { - self.snapshot.clone() - } -} - -impl SymbolContext { - pub fn snapshot(&self, cx: &App) -> Option { - let buffer = self.context_symbol.buffer.read(cx); - let name = self.context_symbol.id.name.clone(); - let path = buffer_path_log_err(buffer, cx)? - .to_string_lossy() - .into_owned() - .into(); - - Some(ContextSnapshot { - id: self.id, - name, - parent: Some(path), - tooltip: None, - icon_path: None, - kind: ContextKind::Symbol, - text: Box::new([self.context_symbol.text.clone()]), - }) - } -} - -impl FetchedUrlContext { - pub fn snapshot(&self) -> ContextSnapshot { - ContextSnapshot { - id: self.id, - name: self.url.clone(), - parent: None, - tooltip: None, - icon_path: None, - kind: ContextKind::FetchedUrl, - text: Box::new([self.text.clone()]), - } - } -} - -impl ThreadContext { - pub fn snapshot(&self, cx: &App) -> ContextSnapshot { - let thread = self.thread.read(cx); - ContextSnapshot { - id: self.id, - name: thread.summary().unwrap_or("New thread".into()), - parent: None, - tooltip: None, - icon_path: None, - kind: ContextKind::Thread, - text: Box::new([self.text.clone()]), - } - } -} - -pub fn attach_context_to_message( +pub fn attach_context_to_message<'a>( message: &mut LanguageModelRequestMessage, - contexts: impl Iterator, + contexts: impl Iterator, + cx: &App, ) { let mut file_context = Vec::new(); let mut directory_context = Vec::new(); @@ -287,91 +159,61 @@ pub fn attach_context_to_message( let mut fetch_context = Vec::new(); let mut thread_context = Vec::new(); - let mut capacity = 0; for context in contexts { - capacity += context.text.len(); - match context.kind { - ContextKind::File => file_context.push(context), - ContextKind::Directory => directory_context.push(context), - ContextKind::Symbol => symbol_context.push(context), - ContextKind::FetchedUrl => fetch_context.push(context), - ContextKind::Thread => thread_context.push(context), + match context { + AssistantContext::File(context) => file_context.push(context), + AssistantContext::Directory(context) => directory_context.push(context), + AssistantContext::Symbol(context) => symbol_context.push(context), + AssistantContext::FetchedUrl(context) => fetch_context.push(context), + AssistantContext::Thread(context) => thread_context.push(context), } } - if !file_context.is_empty() { - capacity += 1; - } - if !directory_context.is_empty() { - capacity += 1; - } - if !symbol_context.is_empty() { - capacity += 1; - } - if !fetch_context.is_empty() { - capacity += 1 + fetch_context.len(); - } - if !thread_context.is_empty() { - capacity += 1 + thread_context.len(); - } - if capacity == 0 { - return; - } - let mut context_chunks = Vec::with_capacity(capacity); + let mut context_chunks = Vec::new(); if !file_context.is_empty() { context_chunks.push("The following files are available:\n"); - for context in &file_context { - for chunk in &context.text { - context_chunks.push(&chunk); - } + for context in file_context { + context_chunks.push(&context.context_buffer.text); } } if !directory_context.is_empty() { context_chunks.push("The following directories are available:\n"); - for context in &directory_context { - for chunk in &context.text { - context_chunks.push(&chunk); + for context in directory_context { + for context_buffer in &context.context_buffers { + context_chunks.push(&context_buffer.text); } } } if !symbol_context.is_empty() { context_chunks.push("The following symbols are available:\n"); - for context in &symbol_context { - for chunk in &context.text { - context_chunks.push(&chunk); - } + for context in symbol_context { + context_chunks.push(&context.context_symbol.text); } } if !fetch_context.is_empty() { context_chunks.push("The following fetched results are available:\n"); for context in &fetch_context { - context_chunks.push(&context.name); - for chunk in &context.text { - context_chunks.push(&chunk); - } + context_chunks.push(&context.url); + context_chunks.push(&context.text); } } + // Need to own the SharedString for summary so that it can be referenced. + let mut thread_context_chunks = Vec::new(); if !thread_context.is_empty() { context_chunks.push("The following previous conversation threads are available:\n"); for context in &thread_context { - context_chunks.push(&context.name); - for chunk in &context.text { - context_chunks.push(&chunk); - } + thread_context_chunks.push(context.summary(cx)); + thread_context_chunks.push(context.text.clone()); } } - - debug_assert!( - context_chunks.len() == capacity, - "attach_context_message calculated capacity of {}, but length was {}", - capacity, - context_chunks.len() - ); + for chunk in &thread_context_chunks { + context_chunks.push(chunk); + } if !context_chunks.is_empty() { message diff --git a/crates/assistant2/src/context_store.rs b/crates/assistant2/src/context_store.rs index cf7c582f4c67a741e64464ebd46399ffd2bb936e..8fdc53e53acf453cc898d82fc7e52caadddf7cea 100644 --- a/crates/assistant2/src/context_store.rs +++ b/crates/assistant2/src/context_store.rs @@ -2,20 +2,20 @@ use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; -use anyhow::{Result, anyhow, bail}; +use anyhow::{Context as _, Result, anyhow}; use collections::{BTreeMap, HashMap, HashSet}; use futures::{self, Future, FutureExt, future}; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task, WeakEntity}; -use language::Buffer; +use language::{Buffer, File}; use project::{ProjectItem, ProjectPath, Worktree}; use rope::Rope; use text::{Anchor, BufferId, OffsetRangeExt}; -use util::maybe; +use util::{ResultExt, maybe}; use workspace::Workspace; use crate::context::{ - AssistantContext, ContextBuffer, ContextId, ContextSnapshot, ContextSymbol, ContextSymbolId, - DirectoryContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext, + AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext, + FetchedUrlContext, FileContext, SymbolContext, ThreadContext, }; use crate::context_strip::SuggestedContext; use crate::thread::{Thread, ThreadId}; @@ -50,12 +50,6 @@ impl ContextStore { } } - pub fn snapshot<'a>(&'a self, cx: &'a App) -> impl Iterator + 'a { - self.context() - .iter() - .flat_map(|context| context.snapshot(cx)) - } - pub fn context(&self) -> &Vec { &self.context } @@ -121,7 +115,7 @@ impl ContextStore { None, cx.to_async(), ) - })?; + })??; let text = text_task.await; @@ -144,13 +138,13 @@ impl ContextStore { let Some(file) = buffer.file() else { return Err(anyhow!("Buffer has no path.")); }; - Ok(collect_buffer_info_and_text( + collect_buffer_info_and_text( file.path().clone(), buffer_entity, buffer, None, cx.to_async(), - )) + ) })??; let text = text_task.await; @@ -166,8 +160,10 @@ impl ContextStore { fn insert_file(&mut self, context_buffer: ContextBuffer) { let id = self.next_context_id.post_inc(); self.files.insert(context_buffer.id, id); - self.context - .push(AssistantContext::File(FileContext { id, context_buffer })); + self.context.push(AssistantContext::File(FileContext { + id, + context_buffer: context_buffer, + })); } pub fn add_directory( @@ -231,15 +227,18 @@ impl ContextStore { // Skip all binary files and other non-UTF8 files if let Ok(buffer_entity) = buffer_entity { let buffer = buffer_entity.read(cx); - let (buffer_info, text_task) = collect_buffer_info_and_text( + if let Some((buffer_info, text_task)) = collect_buffer_info_and_text( path, buffer_entity, buffer, None, cx.to_async(), - ); - buffer_infos.push(buffer_info); - text_tasks.push(text_task); + ) + .log_err() + { + buffer_infos.push(buffer_info); + text_tasks.push(text_task); + } } } anyhow::Ok(()) @@ -253,7 +252,10 @@ impl ContextStore { .collect::>(); if context_buffers.is_empty() { - bail!("No text files found in {}", &project_path.path.display()); + return Err(anyhow!( + "No text files found in {}", + &project_path.path.display() + )); } this.update(cx, |this, _| { @@ -269,11 +271,11 @@ impl ContextStore { self.directories.insert(project_path.path.to_path_buf(), id); self.context - .push(AssistantContext::Directory(DirectoryContext::new( + .push(AssistantContext::Directory(DirectoryContext { id, project_path, context_buffers, - ))); + })); } pub fn add_symbol( @@ -314,13 +316,16 @@ impl ContextStore { } } - let (buffer_info, collect_content_task) = collect_buffer_info_and_text( + let (buffer_info, collect_content_task) = match collect_buffer_info_and_text( file.path().clone(), buffer, buffer_ref, Some(symbol_enclosing_range.clone()), cx.to_async(), - ); + ) { + Ok((buffer_info, collect_context_task)) => (buffer_info, collect_context_task), + Err(err) => return Task::ready(Err(err)), + }; cx.spawn(async move |this, cx| { let content = collect_content_task.await; @@ -568,6 +573,7 @@ pub enum FileInclusion { // ContextBuffer without text. struct BufferInfo { buffer_entity: Entity, + file: Arc, id: BufferId, version: clock::Global, } @@ -576,6 +582,7 @@ fn make_context_buffer(info: BufferInfo, text: SharedString) -> ContextBuffer { ContextBuffer { id: info.id, buffer: info.buffer_entity, + file: info.file, version: info.version, text, } @@ -604,10 +611,14 @@ fn collect_buffer_info_and_text( buffer: &Buffer, range: Option>, cx: AsyncApp, -) -> (BufferInfo, Task) { +) -> Result<(BufferInfo, Task)> { let buffer_info = BufferInfo { id: buffer.remote_id(), buffer_entity, + file: buffer + .file() + .context("buffer context must have a file")? + .clone(), version: buffer.version(), }; // Important to collect version at the same time as content so that staleness logic is correct. @@ -617,23 +628,26 @@ fn collect_buffer_info_and_text( buffer.as_rope().clone() }; let text_task = cx.background_spawn(async move { to_fenced_codeblock(&path, content) }); - (buffer_info, text_task) + Ok((buffer_info, text_task)) } pub fn buffer_path_log_err(buffer: &Buffer, cx: &App) -> Option> { if let Some(file) = buffer.file() { - let mut path = file.path().clone(); - - if path.as_os_str().is_empty() { - path = file.full_path(cx).into(); - } - Some(path) + Some(file_path(file, cx)) } else { log::error!("Buffer that had a path unexpectedly no longer has a path."); None } } +pub fn file_path(file: &Arc, cx: &App) -> Arc { + let mut path = file.path().clone(); + if path.as_os_str().is_empty() { + path = file.full_path(cx).into(); + } + return path; +} + fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString { let path_extension = path.extension().and_then(|ext| ext.to_str()); let path_string = path.to_string_lossy(); @@ -714,7 +728,7 @@ pub fn refresh_context_store_text( let buffer = buffer.read(cx); buffer_path_log_err(&buffer, cx).map_or(false, |path| { - path.starts_with(&directory_context.path.path) + path.starts_with(&directory_context.project_path.path) }) }); @@ -801,13 +815,17 @@ fn refresh_directory_text( let context_buffers = future::join_all(futures); - let id = directory_context.snapshot.id; - let path = directory_context.path.clone(); + let id = directory_context.id; + let project_path = directory_context.project_path.clone(); Some(cx.spawn(async move |cx| { let context_buffers = context_buffers.await; context_store .update(cx, |context_store, _| { - let new_directory_context = DirectoryContext::new(id, path, context_buffers); + let new_directory_context = DirectoryContext { + id, + project_path, + context_buffers, + }; context_store.replace_context(AssistantContext::Directory(new_directory_context)); }) .ok(); @@ -870,7 +888,8 @@ fn refresh_context_buffer( buffer, None, cx.to_async(), - ); + ) + .log_err()?; Some(text_task.map(move |text| make_context_buffer(buffer_info, text))) } else { None @@ -891,7 +910,8 @@ fn refresh_context_symbol( buffer, Some(context_symbol.enclosing_range.clone()), cx.to_async(), - ); + ) + .log_err()?; let name = context_symbol.id.name.clone(); let range = context_symbol.id.range.clone(); let enclosing_range = context_symbol.enclosing_range.clone(); diff --git a/crates/assistant2/src/context_strip.rs b/crates/assistant2/src/context_strip.rs index 223c340752340440fd8ee370f13d61e12a97eb3b..a2137ede0e91685ce501463dad52ab016deae7bc 100644 --- a/crates/assistant2/src/context_strip.rs +++ b/crates/assistant2/src/context_strip.rs @@ -17,7 +17,7 @@ use crate::context_picker::{ConfirmBehavior, ContextPicker}; use crate::context_store::ContextStore; use crate::thread::Thread; use crate::thread_store::ThreadStore; -use crate::ui::ContextPill; +use crate::ui::{AddedContext, ContextPill}; use crate::{ AcceptSuggestedContext, AssistantPanel, FocusDown, FocusLeft, FocusRight, FocusUp, RemoveAllContext, RemoveFocusedContext, ToggleContextPicker, @@ -363,19 +363,19 @@ impl Focusable for ContextStrip { impl Render for ContextStrip { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let context_store = self.context_store.read(cx); - let context = context_store - .context() - .iter() - .flat_map(|context| context.snapshot(cx)) - .collect::>(); + let context = context_store.context(); let context_picker = self.context_picker.clone(); let focus_handle = self.focus_handle.clone(); let suggested_context = self.suggested_context(cx); - let dupe_names = context + let added_contexts = context + .iter() + .map(|c| AddedContext::new(c, cx)) + .collect::>(); + let dupe_names = added_contexts .iter() - .map(|context| context.name.clone()) + .map(|c| c.name.clone()) .sorted() .tuple_windows() .filter(|(a, b)| a == b) @@ -461,34 +461,39 @@ impl Render for ContextStrip { ) } }) - .children(context.iter().enumerate().map(|(i, context)| { - let id = context.id; - ContextPill::added( - context.clone(), - dupe_names.contains(&context.name), - self.focused_index == Some(i), - Some({ - let id = context.id; - let context_store = self.context_store.clone(); - Rc::new(cx.listener(move |_this, _event, _window, cx| { - context_store.update(cx, |this, _cx| { - this.remove_context(id); - }); - cx.notify(); - })) + .children( + added_contexts + .into_iter() + .enumerate() + .map(|(i, added_context)| { + let name = added_context.name.clone(); + let id = added_context.id; + ContextPill::added( + added_context, + dupe_names.contains(&name), + self.focused_index == Some(i), + Some({ + let context_store = self.context_store.clone(); + Rc::new(cx.listener(move |_this, _event, _window, cx| { + context_store.update(cx, |this, _cx| { + this.remove_context(id); + }); + cx.notify(); + })) + }), + ) + .on_click({ + Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| { + if event.down.click_count > 1 { + this.open_context(id, window, cx); + } else { + this.focused_index = Some(i); + } + cx.notify(); + })) + }) }), - ) - .on_click(Rc::new(cx.listener( - move |this, event: &ClickEvent, window, cx| { - if event.down.click_count > 1 { - this.open_context(id, window, cx); - } else { - this.focused_index = Some(i); - } - cx.notify(); - }, - ))) - })) + ) .when_some(suggested_context, |el, suggested| { el.child( ContextPill::suggested( diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 5b4e11014e59ce811921ecbc7e07a072fa384d36..8cbcb580ae76b25ceff352001816551546766212 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -239,7 +239,7 @@ impl MessageEditor { .ok(); thread .update(cx, |thread, cx| { - let context = context_store.read(cx).snapshot(cx).collect::>(); + let context = context_store.read(cx).context().clone(); thread.action_log().update(cx, |action_log, cx| { action_log.clear_reviewed_changes(cx); }); diff --git a/crates/assistant2/src/terminal_inline_assistant.rs b/crates/assistant2/src/terminal_inline_assistant.rs index 54b771b0397d5177bcace405fec359b20b66a72e..a67e1ed3477e4e9d47653b241771730460f895cb 100644 --- a/crates/assistant2/src/terminal_inline_assistant.rs +++ b/crates/assistant2/src/terminal_inline_assistant.rs @@ -252,7 +252,8 @@ impl TerminalInlineAssistant { attach_context_to_message( &mut request_message, - assist.context_store.read(cx).snapshot(cx), + assist.context_store.read(cx).context().iter(), + cx, ); request_message.content.push(prompt.into()); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 3cf48e14a3b88b78556b408de30696de48c78773..ed2b22306bea3e69060fa4d8272cd8fe224ceab4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -29,7 +29,7 @@ use settings::Settings; use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc}; use uuid::Uuid; -use crate::context::{ContextId, ContextSnapshot, attach_context_to_message}; +use crate::context::{AssistantContext, ContextId, attach_context_to_message}; use crate::thread_store::{ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, SerializedToolUse, @@ -175,7 +175,7 @@ pub struct Thread { pending_summary: Task>, messages: Vec, next_message_id: MessageId, - context: BTreeMap, + context: BTreeMap, context_by_message: HashMap>, system_prompt_context: Option, checkpoints_by_message: HashMap, @@ -473,15 +473,15 @@ impl Thread { cx.notify(); } - pub fn context_for_message(&self, id: MessageId) -> Option> { - let context = self.context_by_message.get(&id)?; - Some( - context - .into_iter() - .filter_map(|context_id| self.context.get(&context_id)) - .cloned() - .collect::>(), - ) + pub fn context_for_message(&self, id: MessageId) -> impl Iterator { + self.context_by_message + .get(&id) + .into_iter() + .flat_map(|context| { + context + .iter() + .filter_map(|context_id| self.context.get(&context_id)) + }) } /// Returns whether all of the tool uses have finished running. @@ -513,15 +513,18 @@ impl Thread { pub fn insert_user_message( &mut self, text: impl Into, - context: Vec, + context: Vec, git_checkpoint: Option, cx: &mut Context, ) -> MessageId { let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx); - let context_ids = context.iter().map(|context| context.id).collect::>(); + let context_ids = context + .iter() + .map(|context| context.id()) + .collect::>(); self.context - .extend(context.into_iter().map(|context| (context.id, context))); + .extend(context.into_iter().map(|context| (context.id(), context))); self.context_by_message.insert(message_id, context_ids); if let Some(git_checkpoint) = git_checkpoint { self.pending_checkpoint = Some(ThreadCheckpoint { @@ -889,9 +892,8 @@ impl Thread { let referenced_context = referenced_context_ids .into_iter() - .filter_map(|context_id| self.context.get(context_id)) - .cloned(); - attach_context_to_message(&mut context_message, referenced_context); + .filter_map(|context_id| self.context.get(context_id)); + attach_context_to_message(&mut context_message, referenced_context, cx); request.messages.push(context_message); } @@ -1300,13 +1302,13 @@ impl Thread { pub fn attach_tool_results( &mut self, - updated_context: Vec, + updated_context: Vec, cx: &mut Context, ) { self.context.extend( updated_context .into_iter() - .map(|context| (context.id, context)), + .map(|context| (context.id(), context)), ); // Insert a user message to contain the tool results. diff --git a/crates/assistant2/src/ui/context_pill.rs b/crates/assistant2/src/ui/context_pill.rs index eeaf975192991bdaf23bb68b361e5aa60c6e43c5..dd39a48e2ec7345521602c807a1483272ea498ca 100644 --- a/crates/assistant2/src/ui/context_pill.rs +++ b/crates/assistant2/src/ui/context_pill.rs @@ -1,14 +1,15 @@ use std::rc::Rc; +use file_icons::FileIcons; use gpui::ClickEvent; use ui::{IconButtonShape, Tooltip, prelude::*}; -use crate::context::{ContextKind, ContextSnapshot}; +use crate::context::{AssistantContext, ContextId, ContextKind}; #[derive(IntoElement)] pub enum ContextPill { Added { - context: ContextSnapshot, + context: AddedContext, dupe_name: bool, focused: bool, on_click: Option>, @@ -25,7 +26,7 @@ pub enum ContextPill { impl ContextPill { pub fn added( - context: ContextSnapshot, + context: AddedContext, dupe_name: bool, focused: bool, on_remove: Option>, @@ -77,17 +78,21 @@ impl ContextPill { pub fn icon(&self) -> Icon { match self { - Self::Added { context, .. } => match &context.icon_path { - Some(icon_path) => Icon::from_path(icon_path), - None => Icon::new(context.kind.icon()), - }, Self::Suggested { icon_path: Some(icon_path), .. + } + | Self::Added { + context: + AddedContext { + icon_path: Some(icon_path), + .. + }, + .. } => Icon::from_path(icon_path), - Self::Suggested { - kind, - icon_path: None, + Self::Suggested { kind, .. } + | Self::Added { + context: AddedContext { kind, .. }, .. } => Icon::new(kind.icon()), } @@ -144,7 +149,7 @@ impl RenderOnce for ContextPill { element } }) - .when_some(context.tooltip.clone(), |element, tooltip| { + .when_some(context.tooltip.as_ref(), |element, tooltip| { element.tooltip(Tooltip::text(tooltip.clone())) }), ) @@ -219,3 +224,91 @@ impl RenderOnce for ContextPill { } } } + +pub struct AddedContext { + pub id: ContextId, + pub kind: ContextKind, + pub name: SharedString, + pub parent: Option, + pub tooltip: Option, + pub icon_path: Option, +} + +impl AddedContext { + pub fn new(context: &AssistantContext, cx: &App) -> AddedContext { + match context { + AssistantContext::File(file_context) => { + let full_path = file_context.context_buffer.file.full_path(cx); + let full_path_string: SharedString = + full_path.to_string_lossy().into_owned().into(); + let name = full_path + .file_name() + .map(|n| n.to_string_lossy().into_owned().into()) + .unwrap_or_else(|| full_path_string.clone()); + let parent = full_path + .parent() + .and_then(|p| p.file_name()) + .map(|n| n.to_string_lossy().into_owned().into()); + AddedContext { + id: file_context.id, + kind: ContextKind::File, + name, + parent, + tooltip: Some(full_path_string), + icon_path: FileIcons::get_icon(&full_path, cx), + } + } + + AssistantContext::Directory(directory_context) => { + // TODO: handle worktree disambiguation. Maybe by storing an `Arc` to also + // handle renames? + let full_path = &directory_context.project_path.path; + let full_path_string: SharedString = + full_path.to_string_lossy().into_owned().into(); + let name = full_path + .file_name() + .map(|n| n.to_string_lossy().into_owned().into()) + .unwrap_or_else(|| full_path_string.clone()); + let parent = full_path + .parent() + .and_then(|p| p.file_name()) + .map(|n| n.to_string_lossy().into_owned().into()); + AddedContext { + id: directory_context.id, + kind: ContextKind::Directory, + name, + parent, + tooltip: Some(full_path_string), + icon_path: None, + } + } + + AssistantContext::Symbol(symbol_context) => AddedContext { + id: symbol_context.id, + kind: ContextKind::Symbol, + name: symbol_context.context_symbol.id.name.clone(), + parent: None, + tooltip: None, + icon_path: None, + }, + + AssistantContext::FetchedUrl(fetched_url_context) => AddedContext { + id: fetched_url_context.id, + kind: ContextKind::FetchedUrl, + name: fetched_url_context.url.clone(), + parent: None, + tooltip: None, + icon_path: None, + }, + + AssistantContext::Thread(thread_context) => AddedContext { + id: thread_context.id, + kind: ContextKind::Thread, + name: thread_context.summary(cx), + parent: None, + tooltip: None, + icon_path: None, + }, + } + } +}