Detailed changes
@@ -61,7 +61,6 @@ dependencies = [
"buffer_diff",
"chrono",
"client",
- "clock",
"collections",
"command_palette_hooks",
"component",
@@ -99,6 +98,7 @@ dependencies = [
"prompt_store",
"proto",
"rand 0.8.5",
+ "ref-cast",
"release_channel",
"rope",
"rules_library",
@@ -11716,6 +11716,26 @@ dependencies = [
"thiserror 2.0.12",
]
+[[package]]
+name = "ref-cast"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
+dependencies = [
+ "ref-cast-impl",
+]
+
+[[package]]
+name = "ref-cast-impl"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
+]
+
[[package]]
name = "refineable"
version = "0.1.0"
@@ -500,6 +500,7 @@ prost-types = "0.9"
pulldown-cmark = { version = "0.12.0", default-features = false }
quote = "1.0.9"
rand = "0.8.5"
+ref-cast = "1.0.24"
rayon = "1.8"
regex = "1.5"
repair_json = "0.1.0"
@@ -1,2 +1,7 @@
allow-private-module-inception = true
avoid-breaking-exported-api = false
+ignore-interior-mutability = [
+ # Suppresses clippy::mutable_key_type, which is a false positive as the Eq
+ # and Hash impls do not use fields with interior mutability.
+ "agent::context::AgentContextKey"
+]
@@ -28,7 +28,6 @@ async-watch.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
-clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
@@ -65,6 +64,7 @@ project.workspace = true
rules_library.workspace = true
prompt_store.workspace = true
proto.workspace = true
+ref-cast.workspace = true
release_channel.workspace = true
rope.workspace = true
schemars.workspace = true
@@ -1,4 +1,4 @@
-use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string};
+use crate::context::{AgentContext, RULES_ICON};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
@@ -25,8 +25,8 @@ use gpui::{
};
use language::{Buffer, LanguageRegistry};
use language_model::{
- LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, RequestUsage, Role,
- StopReason,
+ LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
+ RequestUsage, Role, StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@@ -47,13 +47,10 @@ use util::markdown::MarkdownString;
use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
-use crate::context_store::ContextStore;
-
pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
- context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
@@ -717,7 +714,6 @@ impl ActiveThread {
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>,
- context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
@@ -740,7 +736,6 @@ impl ActiveThread {
language_registry,
thread_store,
thread: thread.clone(),
- context_store,
workspace,
save_thread_task: None,
messages: Vec::new(),
@@ -780,10 +775,6 @@ impl ActiveThread {
this
}
- pub fn context_store(&self) -> &Entity<ContextStore> {
- &self.context_store
- }
-
pub fn thread(&self) -> &Entity<Thread> {
&self.thread
}
@@ -1273,26 +1264,36 @@ impl ActiveThread {
}
let token_count = if let Some(task) = cx.update(|cx| {
- let context = thread.read(cx).context_for_message(message_id);
- let new_context = thread.read(cx).filter_new_context(context);
- let context_text =
- format_context_as_string(new_context, cx).unwrap_or(String::new());
+ let Some(message) = thread.read(cx).message(message_id) else {
+ log::error!("Message that was being edited no longer exists");
+ return None;
+ };
let message_text = editor.read(cx).text(cx);
- let content = context_text + &message_text;
-
- if content.is_empty() {
+ if message_text.is_empty() && message.loaded_context.is_empty() {
return None;
}
+ let mut request_message = LanguageModelRequestMessage {
+ role: language_model::Role::User,
+ content: Vec::new(),
+ cache: false,
+ };
+
+ message
+ .loaded_context
+ .add_to_request_message(&mut request_message);
+
+ if !message_text.is_empty() {
+ request_message
+ .content
+ .push(MessageContent::Text(message_text));
+ }
+
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
- messages: vec![LanguageModelRequestMessage {
- role: language_model::Role::User,
- content: vec![content.into()],
- cache: false,
- }],
+ messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,
@@ -1487,13 +1488,21 @@ impl ActiveThread {
return Empty.into_any();
};
- let context_store = self.context_store.clone();
let workspace = self.workspace.clone();
let thread = self.thread.read(cx);
+ let prompt_store = self.thread_store.read(cx).prompt_store().as_ref();
// Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
- let context = thread.context_for_message(message_id).collect::<Vec<_>>();
+ let added_context = if let Some(workspace) = workspace.upgrade() {
+ let project = workspace.read(cx).project().read(cx);
+ thread
+ .context_for_message(message_id)
+ .flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
+ .collect::<Vec<_>>()
+ } else {
+ return Empty.into_any();
+ };
let tool_uses = thread.tool_uses_for_message(message_id, cx);
let has_tool_uses = !tool_uses.is_empty();
@@ -1641,90 +1650,78 @@ impl ActiveThread {
};
let message_is_empty = message.should_display_content();
- let has_content = !message_is_empty || !context.is_empty();
+ let has_content = !message_is_empty || !added_context.is_empty();
- let message_content =
- has_content.then(|| {
- v_flex()
- .gap_1()
- .when(!message_is_empty, |parent| {
- parent.child(
- if let Some(edit_message_editor) = edit_message_editor.clone() {
- let settings = ThemeSettings::get_global(cx);
- let font_size = TextSize::Small.rems(cx);
- let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
-
- let text_style = TextStyle {
- color: cx.theme().colors().text,
- font_family: settings.buffer_font.family.clone(),
- font_fallbacks: settings.buffer_font.fallbacks.clone(),
- font_features: settings.buffer_font.features.clone(),
- font_size: font_size.into(),
- line_height: line_height.into(),
- ..Default::default()
- };
-
- div()
- .key_context("EditMessageEditor")
- .on_action(cx.listener(Self::cancel_editing_message))
- .on_action(cx.listener(Self::confirm_editing_message))
- .min_h_6()
- .child(EditorElement::new(
- &edit_message_editor,
- EditorStyle {
- background: colors.editor_background,
- local_player: cx.theme().players().local(),
- text: text_style,
- syntax: cx.theme().syntax().clone(),
- ..Default::default()
- },
- ))
- .into_any()
- } else {
- div()
- .min_h_6()
- .child(self.render_message_content(
- message_id,
- rendered_message,
- has_tool_uses,
- workspace.clone(),
- window,
- cx,
- ))
- .into_any()
- },
- )
- })
- .when(!context.is_empty(), |parent| {
- parent.child(h_flex().flex_wrap().gap_1().children(
- context.into_iter().map(|context| {
- let context_id = context.id();
- ContextPill::added(
- AddedContext::new(context, cx),
- false,
- false,
- None,
- )
- .on_click(Rc::new(cx.listener({
+ let message_content = has_content.then(|| {
+ v_flex()
+ .gap_1()
+ .when(!message_is_empty, |parent| {
+ parent.child(
+ if let Some(edit_message_editor) = edit_message_editor.clone() {
+ let settings = ThemeSettings::get_global(cx);
+ let font_size = TextSize::Small.rems(cx);
+ let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
+
+ let text_style = TextStyle {
+ color: cx.theme().colors().text,
+ font_family: settings.buffer_font.family.clone(),
+ font_fallbacks: settings.buffer_font.fallbacks.clone(),
+ font_features: settings.buffer_font.features.clone(),
+ font_size: font_size.into(),
+ line_height: line_height.into(),
+ ..Default::default()
+ };
+
+ div()
+ .key_context("EditMessageEditor")
+ .on_action(cx.listener(Self::cancel_editing_message))
+ .on_action(cx.listener(Self::confirm_editing_message))
+ .min_h_6()
+ .child(EditorElement::new(
+ &edit_message_editor,
+ EditorStyle {
+ background: colors.editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ syntax: cx.theme().syntax().clone(),
+ ..Default::default()
+ },
+ ))
+ .into_any()
+ } else {
+ div()
+ .min_h_6()
+ .child(self.render_message_content(
+ message_id,
+ rendered_message,
+ has_tool_uses,
+ workspace.clone(),
+ window,
+ cx,
+ ))
+ .into_any()
+ },
+ )
+ })
+ .when(!added_context.is_empty(), |parent| {
+ parent.child(h_flex().flex_wrap().gap_1().children(
+ added_context.into_iter().map(|added_context| {
+ let context = added_context.context.clone();
+ ContextPill::added(added_context, false, false, None).on_click(Rc::new(
+ cx.listener({
let workspace = workspace.clone();
- let context_store = context_store.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
- open_context(
- context_id,
- context_store.clone(),
- workspace,
- window,
- cx,
- );
+ open_context(&context, workspace, window, cx);
cx.notify();
}
}
- })))
- }),
- ))
- })
- });
+ }),
+ ))
+ }),
+ ))
+ })
+ });
let styled_message = match message.role {
Role::User => v_flex()
@@ -3173,20 +3170,14 @@ impl Render for ActiveThread {
}
pub(crate) fn open_context(
- id: ContextId,
- context_store: Entity<ContextStore>,
+ context: &AgentContext,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut App,
) {
- let Some(context) = context_store.read(cx).context_for_id(id) else {
- return;
- };
-
match context {
- AssistantContext::File(file_context) => {
- if let Some(project_path) = file_context.context_buffer.buffer.read(cx).project_path(cx)
- {
+ AgentContext::File(file_context) => {
+ if let Some(project_path) = file_context.project_path(cx) {
workspace.update(cx, |workspace, cx| {
workspace
.open_path(project_path, None, true, window, cx)
@@ -3194,7 +3185,8 @@ pub(crate) fn open_context(
});
}
}
- AssistantContext::Directory(directory_context) => {
+
+ AgentContext::Directory(directory_context) => {
let entry_id = directory_context.entry_id;
workspace.update(cx, |workspace, cx| {
workspace.project().update(cx, |_project, cx| {
@@ -3202,61 +3194,51 @@ pub(crate) fn open_context(
})
})
}
- AssistantContext::Symbol(symbol_context) => {
- if let Some(project_path) = symbol_context
- .context_symbol
- .buffer
- .read(cx)
- .project_path(cx)
- {
- let snapshot = symbol_context.context_symbol.buffer.read(cx).snapshot();
- let target_position = symbol_context
- .context_symbol
- .id
- .range
- .start
- .to_point(&snapshot);
+ AgentContext::Symbol(symbol_context) => {
+ let buffer = symbol_context.buffer.read(cx);
+ if let Some(project_path) = buffer.project_path(cx) {
+ let snapshot = buffer.snapshot();
+ let target_position = symbol_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
}
}
- AssistantContext::Selection(selection_context) => {
- if let Some(project_path) = selection_context
- .context_buffer
- .buffer
- .read(cx)
- .project_path(cx)
- {
- let snapshot = selection_context.context_buffer.buffer.read(cx).snapshot();
+
+ AgentContext::Selection(selection_context) => {
+ let buffer = selection_context.buffer.read(cx);
+ if let Some(project_path) = buffer.project_path(cx) {
+ let snapshot = buffer.snapshot();
let target_position = selection_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
}
}
- AssistantContext::FetchedUrl(fetched_url_context) => {
+
+ AgentContext::FetchedUrl(fetched_url_context) => {
cx.open_url(&fetched_url_context.url);
}
- AssistantContext::Thread(thread_context) => {
- let thread_id = thread_context.thread.read(cx).id().clone();
- workspace.update(cx, |workspace, cx| {
- if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
- panel.update(cx, |panel, cx| {
- panel
- .open_thread(&thread_id, window, cx)
- .detach_and_log_err(cx)
- });
- }
- })
- }
- AssistantContext::Rules(rules_context) => window.dispatch_action(
+
+ AgentContext::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
+ if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
+ panel.update(cx, |panel, cx| {
+ let thread_id = thread_context.thread.read(cx).id().clone();
+ panel
+ .open_thread(&thread_id, window, cx)
+ .detach_and_log_err(cx)
+ });
+ }
+ }),
+
+ AgentContext::Rules(rules_context) => window.dispatch_action(
Box::new(OpenRulesLibrary {
prompt_to_select: Some(rules_context.prompt_id.0),
}),
cx,
),
- AssistantContext::Image(_) => {}
+
+ AgentContext::Image(_) => {}
}
}
@@ -962,11 +962,13 @@ mod tests {
})
.unwrap();
+ let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
+ prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
@@ -39,6 +39,7 @@ use thread::ThreadId;
pub use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
+pub use crate::context::{ContextLoadResult, LoadedContext};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
@@ -24,7 +24,7 @@ use language::LanguageRegistry;
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use project::Project;
-use prompt_store::{PromptBuilder, PromptId, UserPromptId};
+use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
use proto::Plan;
use rules_library::{RulesLibrary, open_rules_library};
use settings::{Settings, update_settings_file};
@@ -189,6 +189,7 @@ pub struct AssistantPanel {
message_editor: Entity<MessageEditor>,
_active_thread_subscriptions: Vec<Subscription>,
context_store: Entity<assistant_context_editor::ContextStore>,
+ prompt_store: Option<Entity<PromptStore>>,
configuration: Option<Entity<AssistantConfiguration>>,
configuration_subscription: Option<Subscription>,
local_timezone: UtcOffset,
@@ -205,14 +206,25 @@ impl AssistantPanel {
pub fn load(
workspace: WeakEntity<Workspace>,
prompt_builder: Arc<PromptBuilder>,
- cx: AsyncWindowContext,
+ mut cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
+ let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
cx.spawn(async move |cx| {
+ let prompt_store = match prompt_store {
+ Ok(prompt_store) => prompt_store.await.ok(),
+ Err(_) => None,
+ };
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
- ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
+ ThreadStore::load(
+ project,
+ tools.clone(),
+ prompt_store.clone(),
+ prompt_builder.clone(),
+ cx,
+ )
})?
.await?;
@@ -230,7 +242,16 @@ impl AssistantPanel {
.await?;
workspace.update_in(cx, |workspace, window, cx| {
- cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
+ cx.new(|cx| {
+ Self::new(
+ workspace,
+ thread_store,
+ context_store,
+ prompt_store,
+ window,
+ cx,
+ )
+ })
})
})
}
@@ -239,6 +260,7 @@ impl AssistantPanel {
workspace: &Workspace,
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
+ prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -262,6 +284,7 @@ impl AssistantPanel {
fs.clone(),
workspace.clone(),
message_editor_context_store.clone(),
+ prompt_store.clone(),
thread_store.downgrade(),
thread.clone(),
window,
@@ -293,7 +316,6 @@ impl AssistantPanel {
thread.clone(),
thread_store.clone(),
language_registry.clone(),
- message_editor_context_store.clone(),
workspace.clone(),
window,
cx,
@@ -322,6 +344,7 @@ impl AssistantPanel {
message_editor_subscription,
],
context_store,
+ prompt_store,
configuration: None,
configuration_subscription: None,
local_timezone: UtcOffset::from_whole_seconds(
@@ -355,6 +378,10 @@ impl AssistantPanel {
self.local_timezone
}
+ pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
+ &self.prompt_store
+ }
+
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
&self.thread_store
}
@@ -411,7 +438,6 @@ impl AssistantPanel {
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
- message_editor_context_store.clone(),
self.workspace.clone(),
window,
cx,
@@ -430,6 +456,7 @@ impl AssistantPanel {
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
+ self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
window,
@@ -500,9 +527,9 @@ impl AssistantPanel {
None,
))
}),
- action.prompt_to_select.map(|uuid| PromptId::User {
- uuid: UserPromptId(uuid),
- }),
+ action
+ .prompt_to_select
+ .map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);
@@ -598,7 +625,6 @@ impl AssistantPanel {
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
- message_editor_context_store.clone(),
this.workspace.clone(),
window,
cx,
@@ -617,6 +643,7 @@ impl AssistantPanel {
this.fs.clone(),
this.workspace.clone(),
message_editor_context_store,
+ this.prompt_store.clone(),
this.thread_store.downgrade(),
thread,
window,
@@ -1876,11 +1903,14 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
else {
return;
};
+ let prompt_store = None;
+ let thread_store = None;
assistant.assist(
&prompt_editor,
self.workspace.clone(),
project,
- None,
+ prompt_store,
+ thread_store,
window,
cx,
)
@@ -1959,8 +1989,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
// being updated.
cx.defer_in(window, move |panel, window, cx| {
if panel.has_active_thread() {
- panel.thread.update(cx, |thread, cx| {
- thread.context_store().update(cx, |store, cx| {
+ panel.message_editor.update(cx, |message_editor, cx| {
+ message_editor.context_store().update(cx, |store, cx| {
let buffer = buffer.read(cx);
let selection_ranges = selection_ranges
.into_iter()
@@ -1977,9 +2007,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
.collect::<Vec<_>>();
for (buffer, range) in selection_ranges {
- store
- .add_selection(buffer, range, cx)
- .detach_and_log_err(cx);
+ store.add_selection(buffer, range, cx);
}
})
})
@@ -1,6 +1,6 @@
-use crate::context::attach_context_to_message;
-use crate::context_store::ContextStore;
+use crate::context::ContextLoadResult;
use crate::inline_prompt_editor::CodegenStatus;
+use crate::{context::load_context, context_store::ContextStore};
use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
@@ -8,7 +8,7 @@ use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
-use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
+use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
@@ -16,7 +16,9 @@ use language_model::{
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
+use project::Project;
use prompt_store::PromptBuilder;
+use prompt_store::PromptStore;
use rope::Rope;
use smol::future::FutureExt;
use std::{
@@ -41,6 +43,8 @@ pub struct BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
+ project: WeakEntity<Project>,
+ prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
pub is_insertion: bool,
@@ -52,6 +56,8 @@ impl BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
+ project: WeakEntity<Project>,
+ prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -62,6 +68,8 @@ impl BufferCodegen {
range.clone(),
false,
Some(context_store.clone()),
+ project.clone(),
+ prompt_store.clone(),
Some(telemetry.clone()),
builder.clone(),
cx,
@@ -77,6 +85,8 @@ impl BufferCodegen {
range,
initial_transaction_id,
context_store,
+ project,
+ prompt_store,
telemetry,
builder,
};
@@ -155,6 +165,8 @@ impl BufferCodegen {
self.range.clone(),
false,
Some(self.context_store.clone()),
+ self.project.clone(),
+ self.prompt_store.clone(),
Some(self.telemetry.clone()),
self.builder.clone(),
cx,
@@ -231,13 +243,14 @@ pub struct CodegenAlternative {
generation: Task<()>,
diff: Diff,
context_store: Option<Entity<ContextStore>>,
+ project: WeakEntity<Project>,
+ prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
active: bool,
edits: Vec<(Range<Anchor>, String)>,
line_operations: Vec<LineOperation>,
- request: Option<LanguageModelRequest>,
elapsed_time: Option<f64>,
completion: Option<String>,
pub message_id: Option<String>,
@@ -251,6 +264,8 @@ impl CodegenAlternative {
range: Range<Anchor>,
active: bool,
context_store: Option<Entity<ContextStore>>,
+ project: WeakEntity<Project>,
+ prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -292,6 +307,8 @@ impl CodegenAlternative {
generation: Task::ready(()),
diff: Diff::default(),
context_store,
+ project,
+ prompt_store,
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
@@ -299,7 +316,6 @@ impl CodegenAlternative {
edits: Vec::new(),
line_operations: Vec::new(),
range,
- request: None,
elapsed_time: None,
completion: None,
}
@@ -368,16 +384,18 @@ impl CodegenAlternative {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(user_prompt, cx)?;
- self.request = Some(request.clone());
-
- cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
+ cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
Ok(())
}
- fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
+ fn build_request(
+ &self,
+ user_prompt: String,
+ cx: &mut App,
+ ) -> Result<Task<LanguageModelRequest>> {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@@ -410,30 +428,44 @@ impl CodegenAlternative {
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
- let mut request_message = LanguageModelRequestMessage {
- role: Role::User,
- content: Vec::new(),
- cache: false,
- };
+ let context_task = self.context_store.as_ref().map(|context_store| {
+ if let Some(project) = self.project.upgrade() {
+ let context = context_store
+ .read(cx)
+ .context()
+ .cloned()
+ .collect::<Vec<_>>();
+ load_context(context, &project, &self.prompt_store, cx)
+ } else {
+ Task::ready(ContextLoadResult::default())
+ }
+ });
- if let Some(context_store) = &self.context_store {
- attach_context_to_message(
- &mut request_message,
- context_store.read(cx).context().iter(),
- cx,
- );
- }
+ Ok(cx.spawn(async move |_cx| {
+ let mut request_message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: Vec::new(),
+ cache: false,
+ };
- request_message.content.push(prompt.into());
+ if let Some(context_task) = context_task {
+ context_task
+ .await
+ .loaded_context
+ .add_to_request_message(&mut request_message);
+ }
- Ok(LanguageModelRequest {
- thread_id: None,
- prompt_id: None,
- tools: Vec::new(),
- stop: Vec::new(),
- temperature: None,
- messages: vec![request_message],
- })
+ request_message.content.push(prompt.into());
+
+ LanguageModelRequest {
+ thread_id: None,
+ prompt_id: None,
+ tools: Vec::new(),
+ stop: Vec::new(),
+ temperature: None,
+ messages: vec![request_message],
+ }
+ }))
}
pub fn handle_stream(
@@ -1038,6 +1070,7 @@ impl Diff {
#[cfg(test)]
mod tests {
use super::*;
+ use fs::FakeFs;
use futures::{
Stream,
stream::{self},
@@ -1080,12 +1113,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
+ project.downgrade(),
+ None,
None,
prompt_builder,
cx,
@@ -1144,12 +1181,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
+ project.downgrade(),
+ None,
None,
prompt_builder,
cx,
@@ -1211,12 +1252,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
+ project.downgrade(),
+ None,
None,
prompt_builder,
cx,
@@ -1278,12 +1323,16 @@ mod tests {
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
+ project.downgrade(),
+ None,
None,
prompt_builder,
cx,
@@ -1333,12 +1382,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
false,
None,
+ project.downgrade(),
+ None,
None,
prompt_builder,
cx,
@@ -1,34 +1,25 @@
-use std::{
- ops::Range,
- path::{Path, PathBuf},
- sync::Arc,
-};
+use std::hash::{Hash, Hasher};
+use std::usize;
+use std::{ops::Range, path::Path, sync::Arc};
+use collections::HashSet;
+use futures::future;
use futures::{FutureExt, future::Shared};
-use gpui::{App, Entity, SharedString, Task};
+use gpui::{App, AppContext as _, Entity, SharedString, Task};
use language::Buffer;
-use language_model::{LanguageModelImage, LanguageModelRequestMessage};
-use project::{ProjectEntryId, ProjectPath, Worktree};
-use prompt_store::UserPromptId;
-use rope::Point;
-use serde::{Deserialize, Serialize};
-use text::{Anchor, BufferId};
-use ui::IconName;
-use util::post_inc;
+use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
+use project::{Project, ProjectEntryId, ProjectPath, Worktree};
+use prompt_store::{PromptStore, UserPromptId};
+use ref_cast::RefCast;
+use rope::{Point, Rope};
+use text::{Anchor, OffsetRangeExt as _};
+use ui::{ElementId, IconName};
+use util::{ResultExt as _, post_inc};
use crate::thread::Thread;
pub const RULES_ICON: IconName = IconName::Context;
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
-pub struct ContextId(pub(crate) usize);
-
-impl ContextId {
- pub fn post_inc(&mut self) -> Self {
- Self(post_inc(&mut self.0))
- }
-}
-
pub enum ContextKind {
File,
Directory,
@@ -55,307 +46,761 @@ impl ContextKind {
}
}
+/// Handle for context that can be added to a user message.
+///
+/// This uses IDs that are stable enough for tracking renames and identifying when context has
+/// already been added to the thread. To use this in a set, wrap it in `AgentContextKey` to opt in
+/// to `PartialEq` and `Hash` impls that use the subset of the fields used for this stable identity.
#[derive(Debug, Clone)]
-pub enum AssistantContext {
+pub enum AgentContext {
File(FileContext),
Directory(DirectoryContext),
Symbol(SymbolContext),
+ Selection(SelectionContext),
FetchedUrl(FetchedUrlContext),
Thread(ThreadContext),
- Selection(SelectionContext),
Rules(RulesContext),
Image(ImageContext),
}
-impl AssistantContext {
- pub fn id(&self) -> ContextId {
+impl AgentContext {
+ fn id(&self) -> ContextId {
match self {
- Self::File(file) => file.id,
- Self::Directory(directory) => directory.id,
- Self::Symbol(symbol) => symbol.id,
- Self::FetchedUrl(url) => url.id,
- Self::Thread(thread) => thread.id,
- Self::Selection(selection) => selection.id,
- Self::Rules(rules) => rules.id,
- Self::Image(image) => image.id,
+ Self::File(context) => context.context_id,
+ Self::Directory(context) => context.context_id,
+ Self::Symbol(context) => context.context_id,
+ Self::Selection(context) => context.context_id,
+ Self::FetchedUrl(context) => context.context_id,
+ Self::Thread(context) => context.context_id,
+ Self::Rules(context) => context.context_id,
+ Self::Image(context) => context.context_id,
}
}
+
+ pub fn element_id(&self, name: SharedString) -> ElementId {
+ ElementId::NamedInteger(name, self.id().0)
+ }
}
+/// ID created at time of context add, for use in ElementId. This is not the stable identity of a
+/// context, instead that's handled by the `PartialEq` and `Hash` impls of `AgentContextKey`.
+#[derive(Debug, Copy, Clone)]
+pub struct ContextId(usize);
+
+impl ContextId {
+ pub fn zero() -> Self {
+ ContextId(0)
+ }
+
+ fn for_lookup() -> Self {
+ ContextId(usize::MAX)
+ }
+
+ pub fn post_inc(&mut self) -> Self {
+ Self(post_inc(&mut self.0))
+ }
+}
+
+/// File context provides the entire contents of a file.
+///
+/// This holds an `Entity<Buffer>` so that file path renames affect its display and so that it can
+/// be opened even if the file has been deleted. An alternative might be to use `ProjectEntryId`,
+/// but then when deleted there is no path info or ability to open.
#[derive(Debug, Clone)]
pub struct FileContext {
- pub id: ContextId,
- pub context_buffer: ContextBuffer,
+ pub buffer: Entity<Buffer>,
+ pub context_id: ContextId,
}
+impl FileContext {
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.buffer == other.buffer
+ }
+
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.buffer.hash(state)
+ }
+
+ pub fn project_path(&self, cx: &App) -> Option<ProjectPath> {
+ let file = self.buffer.read(cx).file()?;
+ Some(ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ })
+ }
+
+ fn load(&self, cx: &App) -> Option<Task<(String, Entity<Buffer>)>> {
+ let buffer_ref = self.buffer.read(cx);
+ let Some(file) = buffer_ref.file() else {
+ log::error!("file context missing path");
+ return None;
+ };
+ let full_path = file.full_path(cx);
+ let rope = buffer_ref.as_rope().clone();
+ let buffer = self.buffer.clone();
+ Some(
+ cx.background_spawn(
+ async move { (to_fenced_codeblock(&full_path, rope, None), buffer) },
+ ),
+ )
+ }
+}
+
+/// Directory contents provides the entire contents of text files in a directory.
+///
+/// This has a `ProjectEntryId` so that it follows renames.
#[derive(Debug, Clone)]
pub struct DirectoryContext {
- pub id: ContextId,
- pub worktree: Entity<Worktree>,
pub entry_id: ProjectEntryId,
- pub last_path: Arc<Path>,
- /// Buffers of the files within the directory.
- pub context_buffers: Vec<ContextBuffer>,
+ pub context_id: ContextId,
}
impl DirectoryContext {
- pub fn entry<'a>(&self, cx: &'a App) -> Option<&'a project::Entry> {
- self.worktree.read(cx).entry_for_id(self.entry_id)
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.entry_id == other.entry_id
}
- pub fn project_path(&self, cx: &App) -> Option<ProjectPath> {
- let worktree = self.worktree.read(cx);
- worktree
- .entry_for_id(self.entry_id)
- .map(|entry| ProjectPath {
- worktree_id: worktree.id(),
- path: entry.path.clone(),
- })
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.entry_id.hash(state)
+ }
+
+ fn load(
+ &self,
+ project: Entity<Project>,
+ cx: &mut App,
+ ) -> Option<Task<Vec<(String, Entity<Buffer>)>>> {
+ let worktree = project.read(cx).worktree_for_entry(self.entry_id, cx)?;
+ let worktree_ref = worktree.read(cx);
+ let entry = worktree_ref.entry_for_id(self.entry_id)?;
+ if entry.is_file() {
+ log::error!("DirectoryContext unexpectedly refers to a file.");
+ return None;
+ }
+
+ let file_paths = collect_files_in_path(worktree_ref, entry.path.as_ref());
+ let texts_future = future::join_all(file_paths.into_iter().map(|path| {
+ load_file_path_text_as_fenced_codeblock(project.clone(), worktree.clone(), path, cx)
+ }));
+
+ Some(cx.background_spawn(async move {
+ texts_future.await.into_iter().flatten().collect::<Vec<_>>()
+ }))
}
}
#[derive(Debug, Clone)]
pub struct SymbolContext {
- pub id: ContextId,
- pub context_symbol: ContextSymbol,
+ pub buffer: Entity<Buffer>,
+ pub symbol: SharedString,
+ pub range: Range<Anchor>,
+ /// The range that fully contain the symbol. e.g. for function symbol, this will include not
+ /// only the signature, but also the body. Not used by `PartialEq` or `Hash` for `AgentContextKey`.
+ pub enclosing_range: Range<Anchor>,
+ pub context_id: ContextId,
+}
+
+impl SymbolContext {
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.buffer == other.buffer && self.symbol == other.symbol && self.range == other.range
+ }
+
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.buffer.hash(state);
+ self.symbol.hash(state);
+ self.range.hash(state);
+ }
+
+ fn load(&self, cx: &App) -> Option<Task<(String, Entity<Buffer>)>> {
+ let buffer_ref = self.buffer.read(cx);
+ let Some(file) = buffer_ref.file() else {
+ log::error!("symbol context's file has no path");
+ return None;
+ };
+ let full_path = file.full_path(cx);
+ let rope = buffer_ref
+ .text_for_range(self.enclosing_range.clone())
+ .collect::<Rope>();
+ let line_range = self.enclosing_range.to_point(&buffer_ref.snapshot());
+ let buffer = self.buffer.clone();
+ Some(cx.background_spawn(async move {
+ (
+ to_fenced_codeblock(&full_path, rope, Some(line_range)),
+ buffer,
+ )
+ }))
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SelectionContext {
+ pub buffer: Entity<Buffer>,
+ pub range: Range<Anchor>,
+ pub context_id: ContextId,
+}
+
+impl SelectionContext {
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.buffer == other.buffer && self.range == other.range
+ }
+
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.buffer.hash(state);
+ self.range.hash(state);
+ }
+
+ fn load(&self, cx: &App) -> Option<Task<(String, Entity<Buffer>)>> {
+ let buffer_ref = self.buffer.read(cx);
+ let Some(file) = buffer_ref.file() else {
+ log::error!("selection context's file has no path");
+ return None;
+ };
+ let full_path = file.full_path(cx);
+ let rope = buffer_ref
+ .text_for_range(self.range.clone())
+ .collect::<Rope>();
+ let line_range = self.range.to_point(&buffer_ref.snapshot());
+ let buffer = self.buffer.clone();
+ Some(cx.background_spawn(async move {
+ (
+ to_fenced_codeblock(&full_path, rope, Some(line_range)),
+ buffer,
+ )
+ }))
+ }
}
#[derive(Debug, Clone)]
pub struct FetchedUrlContext {
- pub id: ContextId,
pub url: SharedString,
+ /// Text contents of the fetched url. Unlike other context types, the contents of this gets
+ /// populated when added rather than when sending the message. Not used by `PartialEq` or `Hash`
+ /// for `AgentContextKey`.
pub text: SharedString,
+ pub context_id: ContextId,
+}
+
+impl FetchedUrlContext {
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.url == other.url
+ }
+
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.url.hash(state);
+ }
+
+ pub fn lookup_key(url: SharedString) -> AgentContextKey {
+ AgentContextKey(AgentContext::FetchedUrl(FetchedUrlContext {
+ url,
+ text: "".into(),
+ context_id: ContextId::for_lookup(),
+ }))
+ }
}
#[derive(Debug, Clone)]
pub struct ThreadContext {
- pub id: ContextId,
- // TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
- // a WeakEntity and handle removal from the UI when it has dropped.
pub thread: Entity<Thread>,
- pub text: SharedString,
+ pub context_id: ContextId,
}
impl ThreadContext {
- pub fn summary(&self, cx: &App) -> SharedString {
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.thread == other.thread
+ }
+
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.thread.hash(state)
+ }
+
+ pub fn name(&self, cx: &App) -> SharedString {
self.thread
.read(cx)
.summary()
- .unwrap_or("New thread".into())
+ .unwrap_or_else(|| "New thread".into())
+ }
+
+ pub fn load(&self, cx: &App) -> String {
+ let name = self.name(cx);
+ let contents = self.thread.read(cx).latest_detailed_summary_or_text();
+ let mut text = String::new();
+ text.push_str(&name);
+ text.push('\n');
+ text.push_str(&contents.trim());
+ text.push('\n');
+ text
}
}
#[derive(Debug, Clone)]
-pub struct ImageContext {
- pub id: ContextId,
- pub original_image: Arc<gpui::Image>,
- pub image_task: Shared<Task<Option<LanguageModelImage>>>,
+pub struct RulesContext {
+ pub prompt_id: UserPromptId,
+ pub context_id: ContextId,
}
-impl ImageContext {
- pub fn image(&self) -> Option<LanguageModelImage> {
- self.image_task.clone().now_or_never().flatten()
+impl RulesContext {
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.prompt_id == other.prompt_id
}
- pub fn is_loading(&self) -> bool {
- self.image_task.clone().now_or_never().is_none()
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.prompt_id.hash(state)
}
- pub fn is_error(&self) -> bool {
- self.image_task
- .clone()
- .now_or_never()
- .map(|result| result.is_none())
- .unwrap_or(false)
+ pub fn lookup_key(prompt_id: UserPromptId) -> AgentContextKey {
+ AgentContextKey(AgentContext::Rules(RulesContext {
+ prompt_id,
+ context_id: ContextId::for_lookup(),
+ }))
+ }
+
+ pub fn load(
+ &self,
+ prompt_store: &Option<Entity<PromptStore>>,
+ cx: &App,
+ ) -> Task<Option<String>> {
+ let Some(prompt_store) = prompt_store.as_ref() else {
+ return Task::ready(None);
+ };
+ let prompt_store = prompt_store.read(cx);
+ let prompt_id = self.prompt_id.into();
+ let Some(metadata) = prompt_store.metadata(prompt_id) else {
+ return Task::ready(None);
+ };
+ let contents_task = prompt_store.load(prompt_id, cx);
+ cx.background_spawn(async move {
+ let contents = contents_task.await.ok()?;
+ let mut text = String::new();
+ if let Some(title) = metadata.title {
+ text.push_str("Rules title: ");
+ text.push_str(&title);
+ text.push('\n');
+ }
+ text.push_str("``````\n");
+ text.push_str(contents.trim());
+ text.push_str("\n``````\n");
+ Some(text)
+ })
}
}
-#[derive(Clone)]
-pub struct ContextBuffer {
- pub id: BufferId,
- // TODO: Entity<Buffer> holds onto the buffer even if the buffer is deleted. Should probably be
- // a WeakEntity and handle removal from the UI when it has dropped.
- pub buffer: Entity<Buffer>,
- pub last_full_path: Arc<Path>,
- pub version: clock::Global,
- pub text: SharedString,
+#[derive(Debug, Clone)]
+pub struct ImageContext {
+ pub original_image: Arc<gpui::Image>,
+ // TODO: handle this elsewhere and remove `ignore-interior-mutability` opt-out in clippy.toml
+ // needed due to a false positive of `clippy::mutable_key_type`.
+ pub image_task: Shared<Task<Option<LanguageModelImage>>>,
+ pub context_id: ContextId,
}
-impl ContextBuffer {
- pub fn full_path(&self, cx: &App) -> PathBuf {
- let file = self.buffer.read(cx).file();
- // Note that in practice file can't be `None` because it is present when this is created and
- // there's no way for buffers to go from having a file to not.
- file.map_or(self.last_full_path.to_path_buf(), |file| file.full_path(cx))
- }
+pub enum ImageStatus {
+ Loading,
+ Error,
+ Ready,
}
-impl std::fmt::Debug for ContextBuffer {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("ContextBuffer")
- .field("id", &self.id)
- .field("buffer", &self.buffer)
- .field("version", &self.version)
- .field("text", &self.text)
- .finish()
+impl ImageContext {
+ pub fn eq_for_key(&self, other: &Self) -> bool {
+ self.original_image.id == other.original_image.id
}
-}
-#[derive(Debug, Clone)]
-pub struct ContextSymbol {
- pub id: ContextSymbolId,
- pub buffer: Entity<Buffer>,
- pub buffer_version: clock::Global,
- /// The range that the symbol encloses, e.g. for function symbol, this will
- /// include not only the signature, but also the body
- pub enclosing_range: Range<Anchor>,
- pub text: SharedString,
+ pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
+ self.original_image.id.hash(state);
+ }
+
+ pub fn image(&self) -> Option<LanguageModelImage> {
+ self.image_task.clone().now_or_never().flatten()
+ }
+
+ pub fn status(&self) -> ImageStatus {
+ match self.image_task.clone().now_or_never() {
+ None => ImageStatus::Loading,
+ Some(None) => ImageStatus::Error,
+ Some(Some(_)) => ImageStatus::Ready,
+ }
+ }
}
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub struct ContextSymbolId {
- pub path: ProjectPath,
- pub name: SharedString,
- pub range: Range<Anchor>,
+#[derive(Debug, Clone, Default)]
+pub struct ContextLoadResult {
+ pub loaded_context: LoadedContext,
+ pub referenced_buffers: HashSet<Entity<Buffer>>,
}
-#[derive(Debug, Clone)]
-pub struct SelectionContext {
- pub id: ContextId,
- pub range: Range<Anchor>,
- pub line_range: Range<Point>,
- pub context_buffer: ContextBuffer,
+#[derive(Debug, Clone, Default)]
+pub struct LoadedContext {
+ pub contexts: Vec<AgentContext>,
+ pub text: String,
+ pub images: Vec<LanguageModelImage>,
}
-#[derive(Debug, Clone)]
-pub struct RulesContext {
- pub id: ContextId,
- pub prompt_id: UserPromptId,
- pub title: SharedString,
- pub text: SharedString,
+impl LoadedContext {
+ pub fn is_empty(&self) -> bool {
+ self.text.is_empty() && self.images.is_empty()
+ }
+
+ pub fn add_to_request_message(&self, request_message: &mut LanguageModelRequestMessage) {
+ if !self.text.is_empty() {
+ request_message
+ .content
+ .push(MessageContent::Text(self.text.to_string()));
+ }
+
+ if !self.images.is_empty() {
+ // Some providers only support image parts after an initial text part
+ if request_message.content.is_empty() {
+ request_message
+ .content
+ .push(MessageContent::Text("Images attached by user:".to_string()));
+ }
+
+ for image in &self.images {
+ request_message
+ .content
+ .push(MessageContent::Image(image.clone()))
+ }
+ }
+ }
}
-/// Formats a collection of contexts into a string representation
-pub fn format_context_as_string<'a>(
- contexts: impl Iterator<Item = &'a AssistantContext>,
- cx: &App,
-) -> Option<String> {
- let mut file_context = Vec::new();
- let mut directory_context = Vec::new();
- let mut symbol_context = Vec::new();
- let mut selection_context = Vec::new();
+/// Loads and formats a collection of contexts.
+pub fn load_context(
+ contexts: Vec<AgentContext>,
+ project: &Entity<Project>,
+ prompt_store: &Option<Entity<PromptStore>>,
+ cx: &mut App,
+) -> Task<ContextLoadResult> {
+ let mut file_tasks = Vec::new();
+ let mut directory_tasks = Vec::new();
+ let mut symbol_tasks = Vec::new();
+ let mut selection_tasks = Vec::new();
let mut fetch_context = Vec::new();
let mut thread_context = Vec::new();
- let mut rules_context = Vec::new();
+ let mut rules_tasks = Vec::new();
+ let mut image_tasks = Vec::new();
- for context in contexts {
+ for context in contexts.iter().cloned() {
match context {
- AssistantContext::File(context) => file_context.push(context),
- AssistantContext::Directory(context) => directory_context.push(context),
- AssistantContext::Symbol(context) => symbol_context.push(context),
- AssistantContext::Selection(context) => selection_context.push(context),
- AssistantContext::FetchedUrl(context) => fetch_context.push(context),
- AssistantContext::Thread(context) => thread_context.push(context),
- AssistantContext::Rules(context) => rules_context.push(context),
- AssistantContext::Image(_) => {}
+ AgentContext::File(context) => file_tasks.extend(context.load(cx)),
+ AgentContext::Directory(context) => {
+ directory_tasks.extend(context.load(project.clone(), cx))
+ }
+ AgentContext::Symbol(context) => symbol_tasks.extend(context.load(cx)),
+ AgentContext::Selection(context) => selection_tasks.extend(context.load(cx)),
+ AgentContext::FetchedUrl(context) => fetch_context.push(context),
+ AgentContext::Thread(context) => thread_context.push(context.load(cx)),
+ AgentContext::Rules(context) => rules_tasks.push(context.load(prompt_store, cx)),
+ AgentContext::Image(context) => image_tasks.push(context.image_task.clone()),
}
}
- if file_context.is_empty()
- && directory_context.is_empty()
- && symbol_context.is_empty()
- && selection_context.is_empty()
- && fetch_context.is_empty()
- && thread_context.is_empty()
- && rules_context.is_empty()
- {
- return None;
- }
+ cx.background_spawn(async move {
+ let (
+ file_context,
+ directory_context,
+ symbol_context,
+ selection_context,
+ rules_context,
+ images,
+ ) = futures::join!(
+ future::join_all(file_tasks),
+ future::join_all(directory_tasks),
+ future::join_all(symbol_tasks),
+ future::join_all(selection_tasks),
+ future::join_all(rules_tasks),
+ future::join_all(image_tasks)
+ );
+
+ let directory_context = directory_context.into_iter().flatten().collect::<Vec<_>>();
+ let rules_context = rules_context.into_iter().flatten().collect::<Vec<_>>();
+ let images = images.into_iter().flatten().collect::<Vec<_>>();
+
+ let mut referenced_buffers = HashSet::default();
+ let mut text = String::new();
+
+ if file_context.is_empty()
+ && directory_context.is_empty()
+ && symbol_context.is_empty()
+ && selection_context.is_empty()
+ && fetch_context.is_empty()
+ && thread_context.is_empty()
+ && rules_context.is_empty()
+ {
+ return ContextLoadResult {
+ loaded_context: LoadedContext {
+ contexts,
+ text,
+ images,
+ },
+ referenced_buffers,
+ };
+ }
- let mut result = String::new();
- result.push_str("\n<context>\n\
- The following items were attached by the user. You don't need to use other tools to read them.\n\n");
+ text.push_str(
+ "\n<context>\n\
+ The following items were attached by the user. \
+ You don't need to use other tools to read them.\n\n",
+ );
- if !file_context.is_empty() {
- result.push_str("<files>\n");
- for context in file_context {
- result.push_str(&context.context_buffer.text);
+ if !file_context.is_empty() {
+ text.push_str("<files>");
+ for (file_text, buffer) in file_context {
+ text.push('\n');
+ text.push_str(&file_text);
+ referenced_buffers.insert(buffer);
+ }
+ text.push_str("</files>\n");
}
- result.push_str("</files>\n");
- }
- if !directory_context.is_empty() {
- result.push_str("<directories>\n");
- for context in directory_context {
- for context_buffer in &context.context_buffers {
- result.push_str(&context_buffer.text);
+ if !directory_context.is_empty() {
+ text.push_str("<directories>");
+ for (file_text, buffer) in directory_context {
+ text.push('\n');
+ text.push_str(&file_text);
+ referenced_buffers.insert(buffer);
}
+ text.push_str("</directories>\n");
}
- result.push_str("</directories>\n");
- }
- if !symbol_context.is_empty() {
- result.push_str("<symbols>\n");
- for context in symbol_context {
- result.push_str(&context.context_symbol.text);
- result.push('\n');
+ if !symbol_context.is_empty() {
+ text.push_str("<symbols>");
+ for (symbol_text, buffer) in symbol_context {
+ text.push('\n');
+ text.push_str(&symbol_text);
+ referenced_buffers.insert(buffer);
+ }
+ text.push_str("</symbols>\n");
}
- result.push_str("</symbols>\n");
- }
- if !selection_context.is_empty() {
- result.push_str("<selections>\n");
- for context in selection_context {
- result.push_str(&context.context_buffer.text);
- result.push('\n');
+ if !selection_context.is_empty() {
+ text.push_str("<selections>");
+ for (selection_text, buffer) in selection_context {
+ text.push('\n');
+ text.push_str(&selection_text);
+ referenced_buffers.insert(buffer);
+ }
+ text.push_str("</selections>\n");
}
- result.push_str("</selections>\n");
- }
- if !fetch_context.is_empty() {
- result.push_str("<fetched_urls>\n");
- for context in &fetch_context {
- result.push_str(&context.url);
- result.push('\n');
- result.push_str(&context.text);
- result.push('\n');
+ if !fetch_context.is_empty() {
+ text.push_str("<fetched_urls>");
+ for context in fetch_context {
+ text.push('\n');
+ text.push_str(&context.url);
+ text.push('\n');
+ text.push_str(&context.text);
+ }
+ text.push_str("</fetched_urls>\n");
+ }
+
+ if !thread_context.is_empty() {
+ text.push_str("<conversation_threads>");
+ for thread_text in thread_context {
+ text.push('\n');
+ text.push_str(&thread_text);
+ }
+ text.push_str("</conversation_threads>\n");
}
- result.push_str("</fetched_urls>\n");
- }
- if !thread_context.is_empty() {
- result.push_str("<conversation_threads>\n");
- for context in &thread_context {
- result.push_str(&context.summary(cx));
- result.push('\n');
- result.push_str(&context.text);
- result.push('\n');
+ if !rules_context.is_empty() {
+ text.push_str(
+ "<user_rules>\n\
+ The user has specified the following rules that should be applied:\n",
+ );
+ for rules_text in rules_context {
+ text.push('\n');
+ text.push_str(&rules_text);
+ }
+ text.push_str("</user_rules>\n");
+ }
+
+ text.push_str("</context>\n");
+
+ ContextLoadResult {
+ loaded_context: LoadedContext {
+ contexts,
+ text,
+ images,
+ },
+ referenced_buffers,
+ }
+ })
+}
+
+fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
+ let mut files = Vec::new();
+
+ for entry in worktree.child_entries(path) {
+ if entry.is_dir() {
+ files.extend(collect_files_in_path(worktree, &entry.path));
+ } else if entry.is_file() {
+ files.push(entry.path.clone());
}
- result.push_str("</conversation_threads>\n");
}
- if !rules_context.is_empty() {
- result.push_str(
- "<user_rules>\n\
- The user has specified the following rules that should be applied:\n\n",
- );
- for context in &rules_context {
- result.push_str(&context.text);
- result.push('\n');
+ files
+}
+
+fn load_file_path_text_as_fenced_codeblock(
+ project: Entity<Project>,
+ worktree: Entity<Worktree>,
+ path: Arc<Path>,
+ cx: &mut App,
+) -> Task<Option<(String, Entity<Buffer>)>> {
+ let worktree_ref = worktree.read(cx);
+ let worktree_id = worktree_ref.id();
+ let full_path = worktree_ref.full_path(&path);
+
+ let open_task = project.update(cx, |project, cx| {
+ project.buffer_store().update(cx, |buffer_store, cx| {
+ let project_path = ProjectPath { worktree_id, path };
+ buffer_store.open_buffer(project_path, cx)
+ })
+ });
+
+ let rope_task = cx.spawn(async move |cx| {
+ let buffer = open_task.await.log_err()?;
+ let rope = buffer
+ .read_with(cx, |buffer, _cx| buffer.as_rope().clone())
+ .log_err()?;
+ Some((rope, buffer))
+ });
+
+ cx.background_spawn(async move {
+ let (rope, buffer) = rope_task.await?;
+ Some((to_fenced_codeblock(&full_path, rope, None), buffer))
+ })
+}
+
+fn to_fenced_codeblock(
+ full_path: &Path,
+ content: Rope,
+ line_range: Option<Range<Point>>,
+) -> String {
+ let line_range_text = line_range.map(|range| {
+ if range.start.row == range.end.row {
+ format!(":{}", range.start.row + 1)
+ } else {
+ format!(":{}-{}", range.start.row + 1, range.end.row + 1)
}
- result.push_str("</user_rules>\n");
+ });
+
+ let path_extension = full_path.extension().and_then(|ext| ext.to_str());
+ let path_string = full_path.to_string_lossy();
+ let capacity = 3
+ + path_extension.map_or(0, |extension| extension.len() + 1)
+ + path_string.len()
+ + line_range_text.as_ref().map_or(0, |text| text.len())
+ + 1
+ + content.len()
+ + 5;
+ let mut buffer = String::with_capacity(capacity);
+
+ buffer.push_str("```");
+
+ if let Some(extension) = path_extension {
+ buffer.push_str(extension);
+ buffer.push(' ');
+ }
+ buffer.push_str(&path_string);
+
+ if let Some(line_range_text) = line_range_text {
+ buffer.push_str(&line_range_text);
+ }
+
+ buffer.push('\n');
+ for chunk in content.chunks() {
+ buffer.push_str(chunk);
}
- result.push_str("</context>\n");
- Some(result)
+ if !buffer.ends_with('\n') {
+ buffer.push('\n');
+ }
+
+ buffer.push_str("```\n");
+
+ debug_assert!(
+ buffer.len() == capacity - 1 || buffer.len() == capacity,
+ "to_fenced_codeblock calculated capacity of {}, but length was {}",
+ capacity,
+ buffer.len(),
+ );
+
+ buffer
+}
+
+/// Wraps `AgentContext` to opt-in to `PartialEq` and `Hash` impls which use a subset of fields
+/// needed for stable context identity.
+#[derive(Debug, Clone, RefCast)]
+#[repr(transparent)]
+pub struct AgentContextKey(pub AgentContext);
+
+impl AsRef<AgentContext> for AgentContextKey {
+ fn as_ref(&self) -> &AgentContext {
+ &self.0
+ }
+}
+
+impl Eq for AgentContextKey {}
+
+impl PartialEq for AgentContextKey {
+ fn eq(&self, other: &Self) -> bool {
+ match &self.0 {
+ AgentContext::File(context) => {
+ if let AgentContext::File(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ AgentContext::Directory(context) => {
+ if let AgentContext::Directory(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ AgentContext::Symbol(context) => {
+ if let AgentContext::Symbol(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ AgentContext::Selection(context) => {
+ if let AgentContext::Selection(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ AgentContext::FetchedUrl(context) => {
+ if let AgentContext::FetchedUrl(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ AgentContext::Thread(context) => {
+ if let AgentContext::Thread(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ AgentContext::Rules(context) => {
+ if let AgentContext::Rules(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ AgentContext::Image(context) => {
+ if let AgentContext::Image(other_context) = &other.0 {
+ return context.eq_for_key(other_context);
+ }
+ }
+ }
+ false
+ }
}
-pub fn attach_context_to_message<'a>(
- message: &mut LanguageModelRequestMessage,
- contexts: impl Iterator<Item = &'a AssistantContext>,
- cx: &App,
-) {
- if let Some(context_string) = format_context_as_string(contexts, cx) {
- message.content.push(context_string.into());
+impl Hash for AgentContextKey {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ match &self.0 {
+ AgentContext::File(context) => context.hash_for_key(state),
+ AgentContext::Directory(context) => context.hash_for_key(state),
+ AgentContext::Symbol(context) => context.hash_for_key(state),
+ AgentContext::Selection(context) => context.hash_for_key(state),
+ AgentContext::FetchedUrl(context) => context.hash_for_key(state),
+ AgentContext::Thread(context) => context.hash_for_key(state),
+ AgentContext::Rules(context) => context.hash_for_key(state),
+ AgentContext::Image(context) => context.hash_for_key(state),
+ }
}
}
@@ -10,8 +10,11 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Result, anyhow};
+pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
+use fetch_context_picker::FetchContextPicker;
+use file_context_picker::FileContextPicker;
use file_context_picker::render_file_context_entry;
use gpui::{
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
@@ -20,10 +23,10 @@ use gpui::{
use language::Buffer;
use multi_buffer::MultiBufferRow;
use project::{Entry, ProjectPath};
-use prompt_store::UserPromptId;
-use rules_context_picker::RulesContextEntry;
+use prompt_store::{PromptStore, UserPromptId};
+use rules_context_picker::{RulesContextEntry, RulesContextPicker};
use symbol_context_picker::SymbolContextPicker;
-use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
+use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
};
@@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt};
use crate::AssistantPanel;
use crate::context::RULES_ICON;
-pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
-use crate::context_picker::fetch_context_picker::FetchContextPicker;
-use crate::context_picker::file_context_picker::FileContextPicker;
-use crate::context_picker::rules_context_picker::RulesContextPicker;
-use crate::context_picker::thread_context_picker::ThreadContextPicker;
use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
@@ -166,6 +164,7 @@ pub(super) struct ContextPicker {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
+ prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@@ -193,6 +192,13 @@ impl ContextPicker {
)
.collect::<Vec<Subscription>>();
+ let prompt_store = thread_store.as_ref().and_then(|thread_store| {
+ thread_store
+ .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
+ .ok()
+ .flatten()
+ });
+
ContextPicker {
mode: ContextPickerState::Default(ContextMenu::build(
window,
@@ -202,6 +208,7 @@ impl ContextPicker {
workspace,
context_store,
thread_store,
+ prompt_store,
_subscriptions: subscriptions,
}
}
@@ -226,7 +233,12 @@ impl ContextPicker {
.workspace
.upgrade()
.map(|workspace| {
- available_context_picker_entries(&self.thread_store, &workspace, cx)
+ available_context_picker_entries(
+ &self.prompt_store,
+ &self.thread_store,
+ &workspace,
+ cx,
+ )
})
.unwrap_or_default();
@@ -304,10 +316,10 @@ impl ContextPicker {
}));
}
ContextPickerMode::Rules => {
- if let Some(thread_store) = self.thread_store.as_ref() {
+ if let Some(prompt_store) = self.prompt_store.as_ref() {
self.mode = ContextPickerState::Rules(cx.new(|cx| {
RulesContextPicker::new(
- thread_store.clone(),
+ prompt_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
@@ -526,6 +538,7 @@ enum RecentEntry {
}
fn available_context_picker_entries(
+ prompt_store: &Option<Entity<PromptStore>>,
thread_store: &Option<WeakEntity<ThreadStore>>,
workspace: &Entity<Workspace>,
cx: &mut App,
@@ -550,6 +563,9 @@ fn available_context_picker_entries(
if thread_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
+ }
+
+ if prompt_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
}
@@ -585,22 +601,21 @@ fn recent_context_picker_entries(
}),
);
- let mut current_threads = context_store.read(cx).thread_ids();
+ let current_threads = context_store.read(cx).thread_ids();
- if let Some(active_thread) = workspace
+ let active_thread_id = workspace
.panel::<AssistantPanel>(cx)
- .map(|panel| panel.read(cx).active_thread(cx))
- {
- current_threads.insert(active_thread.read(cx).id().clone());
- }
+ .map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
recent.extend(
thread_store
.read(cx)
- .threads()
+ .reverse_chronological_threads()
.into_iter()
- .filter(|thread| !current_threads.contains(&thread.id))
+ .filter(|thread| {
+ Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
+ })
.take(2)
.map(|thread| {
RecentEntry::Thread(ThreadContextEntry {
@@ -622,9 +637,7 @@ fn add_selections_as_context(
let selection_ranges = selection_ranges(workspace, cx);
context_store.update(cx, |context_store, cx| {
for (buffer, range) in selection_ranges {
- context_store
- .add_selection(buffer, range, cx)
- .detach_and_log_err(cx);
+ context_store.add_selection(buffer, range, cx);
}
})
}
@@ -15,22 +15,21 @@ use itertools::Itertools;
use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext;
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
-use prompt_store::PromptId;
+use prompt_store::PromptStore;
use rope::Point;
use text::{Anchor, OffsetRangeExt, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::RULES_ICON;
-use crate::context_picker::file_context_picker::search_files;
-use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
-use super::file_context_picker::FileMatch;
+use super::file_context_picker::{FileMatch, search_files};
use super::rules_context_picker::{RulesContextEntry, search_rules};
use super::symbol_context_picker::SymbolMatch;
+use super::symbol_context_picker::search_symbols;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
@@ -38,8 +37,8 @@ use super::{
};
pub(crate) enum Match {
- Symbol(SymbolMatch),
File(FileMatch),
+ Symbol(SymbolMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Rules(RulesContextEntry),
@@ -69,6 +68,7 @@ fn search(
query: String,
cancellation_flag: Arc<AtomicBool>,
recent_entries: Vec<RecentEntry>,
+ prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
workspace: Entity<Workspace>,
cx: &mut App,
@@ -85,6 +85,7 @@ fn search(
.collect()
})
}
+
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
@@ -96,6 +97,7 @@ fn search(
.collect()
})
}
+
Some(ContextPickerMode::Thread) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_threads_task =
@@ -111,6 +113,7 @@ fn search(
Task::ready(Vec::new())
}
}
+
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
@@ -118,10 +121,11 @@ fn search(
Task::ready(Vec::new())
}
}
+
Some(ContextPickerMode::Rules) => {
- if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
+ if let Some(prompt_store) = prompt_store.as_ref() {
let search_rules_task =
- search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
+ search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
@@ -133,6 +137,7 @@ fn search(
Task::ready(Vec::new())
}
}
+
None => {
if query.is_empty() {
let mut matches = recent_entries
@@ -163,7 +168,7 @@ fn search(
.collect::<Vec<_>>();
matches.extend(
- available_context_picker_entries(&thread_store, &workspace, cx)
+ available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
.into_iter()
.map(|mode| {
Match::Entry(EntryMatch {
@@ -180,7 +185,8 @@ fn search(
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
- let entries = available_context_picker_entries(&thread_store, &workspace, cx);
+ let entries =
+ available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
let entry_candidates = entries
.iter()
.enumerate()
@@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider {
move |_, _: &mut Window, cx: &mut App| {
context_store.update(cx, |context_store, cx| {
for (buffer, range) in &selections {
- context_store
- .add_selection(buffer.clone(), range.clone(), cx)
- .detach_and_log_err(cx)
+ context_store.add_selection(
+ buffer.clone(),
+ range.clone(),
+ cx,
+ );
}
});
@@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider {
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
- thread_store: Entity<ThreadStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text_len = new_text.len();
@@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
move |cx| {
- let prompt_uuid = rules.prompt_id;
- let prompt_id = PromptId::User { uuid: prompt_uuid };
- let context_store = context_store.clone();
- let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
- log::error!("Can't add user rules as prompt store is missing.");
- return;
- };
- let prompt_store = prompt_store.read(cx);
- let Some(metadata) = prompt_store.metadata(prompt_id) else {
- return;
- };
- let Some(title) = metadata.title else {
- return;
- };
- let text_task = prompt_store.load(prompt_id, cx);
-
- cx.spawn(async move |cx| {
- let text = text_task.await?;
- context_store.update(cx, |context_store, cx| {
- context_store.add_rules(prompt_uuid, title, text, false, cx)
- })
- })
- .detach_and_log_err(cx);
+ let user_prompt_id = rules.prompt_id;
+ context_store.update(cx, |context_store, cx| {
+ context_store.add_rules(user_prompt_id, false, cx);
+ });
},
)),
}
@@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider {
let url_to_fetch = url_to_fetch.clone();
cx.spawn(async move |cx| {
if context_store.update(cx, |context_store, _| {
- context_store.includes_url(&url_to_fetch).is_some()
+ context_store.includes_url(&url_to_fetch)
})? {
return Ok(());
}
@@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider {
move |cx| {
context_store.update(cx, |context_store, cx| {
let task = if is_directory {
- context_store.add_directory(project_path.clone(), false, cx)
+ Task::ready(context_store.add_directory(&project_path, false, cx))
} else {
context_store.add_file_from_path(project_path.clone(), false, cx)
};
@@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
);
+ let prompt_store = thread_store.as_ref().and_then(|thread_store| {
+ thread_store
+ .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
+ .ok()
+ .flatten()
+ });
+
let search_task = search(
mode,
query,
Arc::<AtomicBool>::default(),
recent_entries,
+ prompt_store,
thread_store.clone(),
workspace.clone(),
cx,
@@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
))
}
+
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
symbol,
excerpt_id,
@@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
workspace.clone(),
cx,
),
+
Match::Thread(ThreadMatch {
thread, is_recent, ..
}) => {
@@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread_store,
))
}
- Match::Rules(user_rules) => {
- let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
- Some(Self::completion_for_rules(
- user_rules,
- excerpt_id,
- source_range.clone(),
- editor.clone(),
- context_store.clone(),
- thread_store,
- ))
- }
+
+ Match::Rules(user_rules) => Some(Self::completion_for_rules(
+ user_rules,
+ excerpt_id,
+ source_range.clone(),
+ editor.clone(),
+ context_store.clone(),
+ )),
+
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,
@@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
context_store.clone(),
http_client.clone(),
)),
+
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
entry,
excerpt_id,
@@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let added = self.context_store.upgrade().map_or(false, |context_store| {
- context_store.read(cx).includes_url(&self.url).is_some()
+ context_store.read(cx).includes_url(&self.url)
});
Some(
@@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
.context_store
.update(cx, |context_store, cx| {
if is_directory {
- context_store.add_directory(project_path, true, cx)
+ Task::ready(context_store.add_directory(&project_path, true, cx))
} else {
- context_store.add_file_from_path(project_path, true, cx)
+ context_store.add_file_from_path(project_path.clone(), true, cx)
}
})
.ok()
@@ -325,11 +325,11 @@ pub fn render_file_context_entry(
path: path.clone(),
};
if is_directory {
- context_store.read(cx).includes_directory(&project_path)
- } else {
context_store
.read(cx)
- .will_include_file_path(&project_path, cx)
+ .path_included_in_directory(&project_path, cx)
+ } else {
+ context_store.read(cx).file_path_included(&project_path, cx)
}
});
@@ -357,7 +357,7 @@ pub fn render_file_context_entry(
})),
)
.when_some(added, |el, added| match added {
- FileInclusion::Direct(_) => el.child(
+ FileInclusion::Direct => el.child(
h_flex()
.w_full()
.justify_end()
@@ -369,9 +369,8 @@ pub fn render_file_context_entry(
)
.child(Label::new("Added").size(LabelSize::Small)),
),
- FileInclusion::InDirectory(directory_project_path) => {
- // TODO: Consider using worktree full_path to include worktree name.
- let directory_path = directory_project_path.path.to_string_lossy().into_owned();
+ FileInclusion::InDirectory { full_path } => {
+ let directory_full_path = full_path.to_string_lossy().into_owned();
el.child(
h_flex()
@@ -385,7 +384,7 @@ pub fn render_file_context_entry(
)
.child(Label::new("Included").size(LabelSize::Small)),
)
- .tooltip(Tooltip::text(format!("in {directory_path}")))
+ .tooltip(Tooltip::text(format!("in {directory_full_path}")))
}
})
}
@@ -1,16 +1,15 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
-use anyhow::anyhow;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
-use prompt_store::{PromptId, UserPromptId};
+use prompt_store::{PromptId, PromptStore, UserPromptId};
use ui::{ListItem, prelude::*};
+use util::ResultExt as _;
use crate::context::RULES_ICON;
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
-use crate::thread_store::ThreadStore;
pub struct RulesContextPicker {
picker: Entity<Picker<RulesContextPickerDelegate>>,
@@ -18,13 +17,13 @@ pub struct RulesContextPicker {
impl RulesContextPicker {
pub fn new(
- thread_store: WeakEntity<ThreadStore>,
+ prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
- let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
+ let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
RulesContextPicker { picker }
@@ -50,7 +49,7 @@ pub struct RulesContextEntry {
}
pub struct RulesContextPickerDelegate {
- thread_store: WeakEntity<ThreadStore>,
+ prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
matches: Vec<RulesContextEntry>,
@@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate {
impl RulesContextPickerDelegate {
pub fn new(
- thread_store: WeakEntity<ThreadStore>,
+ prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
) -> Self {
RulesContextPickerDelegate {
- thread_store,
+ prompt_store,
context_picker,
context_store,
matches: Vec::new(),
@@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
- let Some(thread_store) = self.thread_store.upgrade() else {
- return Task::ready(());
- };
-
- let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
+ let search_task = search_rules(
+ query,
+ Arc::new(AtomicBool::default()),
+ &self.prompt_store,
+ cx,
+ );
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
@@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate {
return;
};
- let Some(thread_store) = self.thread_store.upgrade() else {
- return;
- };
-
- let prompt_id = entry.prompt_id;
-
- let load_rules_task = thread_store.update(cx, |thread_store, cx| {
- thread_store.load_rules(prompt_id, cx)
- });
-
- cx.spawn(async move |this, cx| {
- let (metadata, text) = load_rules_task.await?;
- let Some(title) = metadata.title else {
- return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
- };
- this.update(cx, |this, cx| {
- this.delegate
- .context_store
- .update(cx, |context_store, cx| {
- context_store.add_rules(prompt_id, title, text, true, cx)
- })
- .ok();
+ self.context_store
+ .update(cx, |context_store, cx| {
+ context_store.add_rules(entry.prompt_id, true, cx)
})
- })
- .detach_and_log_err(cx);
+ .log_err();
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
@@ -179,11 +159,10 @@ pub fn render_thread_context_entry(
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
- let added = context_store.upgrade().map_or(false, |ctx_store| {
- ctx_store
+ let added = context_store.upgrade().map_or(false, |context_store| {
+ context_store
.read(cx)
- .includes_user_rules(&user_rules.prompt_id)
- .is_some()
+ .includes_user_rules(user_rules.prompt_id)
});
h_flex()
@@ -218,12 +197,9 @@ pub fn render_thread_context_entry(
pub(crate) fn search_rules(
query: String,
cancellation_flag: Arc<AtomicBool>,
- thread_store: Entity<ThreadStore>,
+ prompt_store: &Entity<PromptStore>,
cx: &mut App,
) -> Task<Vec<RulesContextEntry>> {
- let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
- return Task::ready(vec![]);
- };
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
cx.background_spawn(async move {
search_task
@@ -10,7 +10,6 @@ use gpui::{
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use project::{DocumentSymbol, Symbol};
-use text::OffsetRangeExt;
use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
@@ -228,18 +227,16 @@ pub(crate) fn add_symbol(
)
})?;
- context_store
- .update(cx, move |context_store, cx| {
- context_store.add_symbol(
- buffer,
- name.into(),
- range,
- enclosing_range,
- remove_if_exists,
- cx,
- )
- })?
- .await
+ context_store.update(cx, move |context_store, cx| {
+ context_store.add_symbol(
+ buffer,
+ name.into(),
+ range,
+ enclosing_range,
+ remove_if_exists,
+ cx,
+ )
+ })
})
}
@@ -353,38 +350,13 @@ fn compute_symbol_entries(
context_store: &ContextStore,
cx: &App,
) -> Vec<SymbolEntry> {
- let mut symbol_entries = Vec::with_capacity(symbols.len());
- for SymbolMatch { symbol, .. } in symbols {
- let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
- let is_included = if let Some(symbols_for_path) = symbols_for_path {
- let mut is_included = false;
- for included_symbol_id in symbols_for_path {
- if included_symbol_id.name.as_ref() == symbol.name.as_str() {
- if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
- let snapshot = buffer.read(cx).snapshot();
- let included_symbol_range =
- included_symbol_id.range.to_point_utf16(&snapshot);
-
- if included_symbol_range.start == symbol.range.start.0
- && included_symbol_range.end == symbol.range.end.0
- {
- is_included = true;
- break;
- }
- }
- }
- }
- is_included
- } else {
- false
- };
-
- symbol_entries.push(SymbolEntry {
+ symbols
+ .into_iter()
+ .map(|SymbolMatch { symbol, .. }| SymbolEntry {
+ is_included: context_store.includes_symbol(&symbol, cx),
symbol,
- is_included,
})
- }
- symbol_entries
+ .collect::<Vec<_>>()
}
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {
@@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
- ctx_store.read(cx).includes_thread(&thread.id).is_some()
+ ctx_store.read(cx).includes_thread(&thread.id)
});
h_flex()
@@ -219,7 +219,7 @@ pub(crate) fn search_threads(
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store
.read(cx)
- .threads()
+ .reverse_chronological_threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,
@@ -1,43 +1,35 @@
use std::ops::Range;
-use std::path::Path;
+use std::path::PathBuf;
use std::sync::Arc;
-use anyhow::{Context as _, Result, anyhow};
-use collections::{BTreeMap, HashMap, HashSet};
+use anyhow::{Result, anyhow};
+use collections::{HashSet, IndexSet};
use futures::future::join_all;
-use futures::{self, Future, FutureExt, future};
-use gpui::{App, AppContext as _, Context, Entity, Image, SharedString, Task, WeakEntity};
+use futures::{self, FutureExt};
+use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
use language::Buffer;
use language_model::LanguageModelImage;
-use project::{Project, ProjectEntryId, ProjectItem, ProjectPath, Worktree};
+use project::{Project, ProjectItem, ProjectPath, Symbol};
use prompt_store::UserPromptId;
-use rope::{Point, Rope};
-use text::{Anchor, BufferId, OffsetRangeExt};
-use util::{ResultExt as _, maybe};
+use ref_cast::RefCast as _;
+use text::{Anchor, OffsetRangeExt};
+use util::ResultExt as _;
use crate::ThreadStore;
use crate::context::{
- AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext,
- FetchedUrlContext, FileContext, ImageContext, RulesContext, SelectionContext, SymbolContext,
- ThreadContext,
+ AgentContext, AgentContextKey, ContextId, DirectoryContext, FetchedUrlContext, FileContext,
+ ImageContext, RulesContext, SelectionContext, SymbolContext, ThreadContext,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
pub struct ContextStore {
project: WeakEntity<Project>,
- context: Vec<AssistantContext>,
thread_store: Option<WeakEntity<ThreadStore>>,
- next_context_id: ContextId,
- files: BTreeMap<BufferId, ContextId>,
- directories: HashMap<ProjectPath, ContextId>,
- symbols: HashMap<ContextSymbolId, ContextId>,
- symbol_buffers: HashMap<ContextSymbolId, Entity<Buffer>>,
- symbols_by_path: HashMap<ProjectPath, Vec<ContextSymbolId>>,
- threads: HashMap<ThreadId, ContextId>,
thread_summary_tasks: Vec<Task<()>>,
- fetched_urls: HashMap<String, ContextId>,
- user_rules: HashMap<UserPromptId, ContextId>,
+ next_context_id: ContextId,
+ context_set: IndexSet<AgentContextKey>,
+ context_thread_ids: HashSet<ThreadId>,
}
impl ContextStore {
@@ -48,35 +40,33 @@ impl ContextStore {
Self {
project,
thread_store,
- context: Vec::new(),
- next_context_id: ContextId(0),
- files: BTreeMap::default(),
- directories: HashMap::default(),
- symbols: HashMap::default(),
- symbol_buffers: HashMap::default(),
- symbols_by_path: HashMap::default(),
- threads: HashMap::default(),
thread_summary_tasks: Vec::new(),
- fetched_urls: HashMap::default(),
- user_rules: HashMap::default(),
+ next_context_id: ContextId::zero(),
+ context_set: IndexSet::default(),
+ context_thread_ids: HashSet::default(),
}
}
- pub fn context(&self) -> &Vec<AssistantContext> {
- &self.context
+ pub fn context(&self) -> impl Iterator<Item = &AgentContext> {
+ self.context_set.iter().map(|entry| entry.as_ref())
}
- pub fn context_for_id(&self, id: ContextId) -> Option<&AssistantContext> {
- self.context().iter().find(|context| context.id() == id)
+ pub fn clear(&mut self) {
+ self.context_set.clear();
+ self.context_thread_ids.clear();
}
- pub fn clear(&mut self) {
- self.context.clear();
- self.files.clear();
- self.directories.clear();
- self.threads.clear();
- self.fetched_urls.clear();
- self.user_rules.clear();
+ pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContext> {
+ let existing_context = thread
+ .messages()
+ .flat_map(|message| &message.loaded_context.contexts)
+ .map(AgentContextKey::ref_cast)
+ .collect::<HashSet<_>>();
+ self.context_set
+ .iter()
+ .filter(|context| !existing_context.contains(context))
+ .map(|entry| entry.0.clone())
+ .collect::<Vec<_>>()
}
pub fn add_file_from_path(
@@ -93,241 +83,98 @@ impl ContextStore {
let open_buffer_task = project.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?;
-
let buffer = open_buffer_task.await?;
- let buffer_id = this.update(cx, |_, cx| buffer.read(cx).remote_id())?;
-
- let already_included = this.update(cx, |this, cx| {
- match this.will_include_buffer(buffer_id, &project_path) {
- Some(FileInclusion::Direct(context_id)) => {
- if remove_if_exists {
- this.remove_context(context_id, cx);
- }
- true
- }
- Some(FileInclusion::InDirectory(_)) => true,
- None => false,
- }
- })?;
-
- if already_included {
- return anyhow::Ok(());
- }
-
- let context_buffer = this
- .update(cx, |_, cx| load_context_buffer(buffer, cx))??
- .await;
-
this.update(cx, |this, cx| {
- this.insert_file(context_buffer, cx);
- })?;
-
- anyhow::Ok(())
+ this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
+ })
})
}
pub fn add_file_from_buffer(
&mut self,
+ project_path: &ProjectPath,
buffer: Entity<Buffer>,
+ remove_if_exists: bool,
cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- cx.spawn(async move |this, cx| {
- let context_buffer = this
- .update(cx, |_, cx| load_context_buffer(buffer, cx))??
- .await;
-
- this.update(cx, |this, cx| this.insert_file(context_buffer, cx))?;
+ ) {
+ let context_id = self.next_context_id.post_inc();
+ let context = AgentContext::File(FileContext { buffer, context_id });
- anyhow::Ok(())
- })
- }
+ let already_included = if self.has_context(&context) {
+ if remove_if_exists {
+ self.remove_context(&context, cx);
+ }
+ true
+ } else {
+ self.path_included_in_directory(project_path, cx).is_some()
+ };
- fn insert_file(&mut self, context_buffer: ContextBuffer, cx: &mut Context<Self>) {
- let id = self.next_context_id.post_inc();
- self.files.insert(context_buffer.id, id);
- self.context
- .push(AssistantContext::File(FileContext { id, context_buffer }));
- cx.notify();
+ if !already_included {
+ self.insert_context(context, cx);
+ }
}
pub fn add_directory(
&mut self,
- project_path: ProjectPath,
+ project_path: &ProjectPath,
remove_if_exists: bool,
cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
+ ) -> Result<()> {
let Some(project) = self.project.upgrade() else {
- return Task::ready(Err(anyhow!("failed to read project")));
+ return Err(anyhow!("failed to read project"));
};
let Some(entry_id) = project
.read(cx)
- .entry_for_path(&project_path, cx)
+ .entry_for_path(project_path, cx)
.map(|entry| entry.id)
else {
- return Task::ready(Err(anyhow!("no entry found for directory context")));
+ return Err(anyhow!("no entry found for directory context"));
};
- let already_included = match self.includes_directory(&project_path) {
- Some(FileInclusion::Direct(context_id)) => {
- if remove_if_exists {
- self.remove_context(context_id, cx);
- }
- true
- }
- Some(FileInclusion::InDirectory(_)) => true,
- None => false,
- };
- if already_included {
- return Task::ready(Ok(()));
- }
-
- let worktree_id = project_path.worktree_id;
- cx.spawn(async move |this, cx| {
- let worktree = project.update(cx, |project, cx| {
- project
- .worktree_for_id(worktree_id, cx)
- .ok_or_else(|| anyhow!("no worktree found for {worktree_id:?}"))
- })??;
-
- let files = worktree.update(cx, |worktree, _cx| {
- collect_files_in_path(worktree, &project_path.path)
- })?;
-
- let open_buffers_task = project.update(cx, |project, cx| {
- let tasks = files.iter().map(|file_path| {
- project.open_buffer(
- ProjectPath {
- worktree_id,
- path: file_path.clone(),
- },
- cx,
- )
- });
- future::join_all(tasks)
- })?;
-
- let buffers = open_buffers_task.await;
-
- let context_buffer_tasks = this.update(cx, |_, cx| {
- buffers
- .into_iter()
- .flatten()
- .flat_map(move |buffer| load_context_buffer(buffer, cx).log_err())
- .collect::<Vec<_>>()
- })?;
-
- let context_buffers = future::join_all(context_buffer_tasks).await;
+ let context_id = self.next_context_id.post_inc();
+ let context = AgentContext::Directory(DirectoryContext {
+ entry_id,
+ context_id,
+ });
- if context_buffers.is_empty() {
- let full_path = cx.update(|cx| worktree.read(cx).full_path(&project_path.path))?;
- return Err(anyhow!("No text files found in {}", &full_path.display()));
+ if self.has_context(&context) {
+ if remove_if_exists {
+ self.remove_context(&context, cx);
}
+ } else if self.path_included_in_directory(project_path, cx).is_none() {
+ self.insert_context(context, cx);
+ }
- this.update(cx, |this, cx| {
- this.insert_directory(worktree, entry_id, project_path, context_buffers, cx);
- })?;
-
- anyhow::Ok(())
- })
- }
-
- fn insert_directory(
- &mut self,
- worktree: Entity<Worktree>,
- entry_id: ProjectEntryId,
- project_path: ProjectPath,
- context_buffers: Vec<ContextBuffer>,
- cx: &mut Context<Self>,
- ) {
- let id = self.next_context_id.post_inc();
- let last_path = project_path.path.clone();
- self.directories.insert(project_path, id);
-
- self.context
- .push(AssistantContext::Directory(DirectoryContext {
- id,
- worktree,
- entry_id,
- last_path,
- context_buffers,
- }));
- cx.notify();
+ anyhow::Ok(())
}
pub fn add_symbol(
&mut self,
buffer: Entity<Buffer>,
- symbol_name: SharedString,
- symbol_range: Range<Anchor>,
- symbol_enclosing_range: Range<Anchor>,
+ symbol: SharedString,
+ range: Range<Anchor>,
+ enclosing_range: Range<Anchor>,
remove_if_exists: bool,
cx: &mut Context<Self>,
- ) -> Task<Result<bool>> {
- let buffer_ref = buffer.read(cx);
- let Some(project_path) = buffer_ref.project_path(cx) else {
- return Task::ready(Err(anyhow!("buffer has no path")));
- };
-
- if let Some(symbols_for_path) = self.symbols_by_path.get(&project_path) {
- let mut matching_symbol_id = None;
- for symbol in symbols_for_path {
- if &symbol.name == &symbol_name {
- let snapshot = buffer_ref.snapshot();
- if symbol.range.to_offset(&snapshot) == symbol_range.to_offset(&snapshot) {
- matching_symbol_id = self.symbols.get(symbol).cloned();
- break;
- }
- }
- }
+ ) -> bool {
+ let context_id = self.next_context_id.post_inc();
+ let context = AgentContext::Symbol(SymbolContext {
+ buffer,
+ symbol,
+ range,
+ enclosing_range,
+ context_id,
+ });
- if let Some(id) = matching_symbol_id {
- if remove_if_exists {
- self.remove_context(id, cx);
- }
- return Task::ready(Ok(false));
+ if self.has_context(&context) {
+ if remove_if_exists {
+ self.remove_context(&context, cx);
}
+ return false;
}
- let context_buffer_task =
- match load_context_buffer_range(buffer, symbol_enclosing_range.clone(), cx) {
- Ok((_line_range, context_buffer_task)) => context_buffer_task,
- Err(err) => return Task::ready(Err(err)),
- };
-
- cx.spawn(async move |this, cx| {
- let context_buffer = context_buffer_task.await;
-
- this.update(cx, |this, cx| {
- this.insert_symbol(
- make_context_symbol(
- context_buffer,
- project_path,
- symbol_name,
- symbol_range,
- symbol_enclosing_range,
- ),
- cx,
- )
- })?;
- anyhow::Ok(true)
- })
- }
-
- fn insert_symbol(&mut self, context_symbol: ContextSymbol, cx: &mut Context<Self>) {
- let id = self.next_context_id.post_inc();
- self.symbols.insert(context_symbol.id.clone(), id);
- self.symbols_by_path
- .entry(context_symbol.id.path.clone())
- .or_insert_with(Vec::new)
- .push(context_symbol.id.clone());
- self.symbol_buffers
- .insert(context_symbol.id.clone(), context_symbol.buffer.clone());
- self.context.push(AssistantContext::Symbol(SymbolContext {
- id,
- context_symbol,
- }));
- cx.notify();
+ self.insert_context(context, cx)
}
pub fn add_thread(
@@ -336,24 +183,23 @@ impl ContextStore {
remove_if_exists: bool,
cx: &mut Context<Self>,
) {
- if let Some(context_id) = self.includes_thread(&thread.read(cx).id()) {
+ let context_id = self.next_context_id.post_inc();
+ let context = AgentContext::Thread(ThreadContext { thread, context_id });
+
+ if self.has_context(&context) {
if remove_if_exists {
- self.remove_context(context_id, cx);
+ self.remove_context(&context, cx);
}
} else {
- self.insert_thread(thread, cx);
+ self.insert_context(context, cx);
}
}
- pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
- let tasks = std::mem::take(&mut self.thread_summary_tasks);
-
- cx.spawn(async move |_cx| {
- join_all(tasks).await;
- })
- }
-
- fn insert_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
+ fn start_summarizing_thread_if_needed(
+ &mut self,
+ thread: &Entity<Thread>,
+ cx: &mut Context<Self>,
+ ) {
if let Some(summary_task) =
thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx))
{
@@ -374,106 +220,60 @@ impl ContextStore {
}
}));
}
+ }
- let id = self.next_context_id.post_inc();
-
- let text = thread.read(cx).latest_detailed_summary_or_text();
+ pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
+ let tasks = std::mem::take(&mut self.thread_summary_tasks);
- self.threads.insert(thread.read(cx).id().clone(), id);
- self.context
- .push(AssistantContext::Thread(ThreadContext { id, thread, text }));
- cx.notify();
+ cx.spawn(async move |_cx| {
+ join_all(tasks).await;
+ })
}
pub fn add_rules(
&mut self,
prompt_id: UserPromptId,
- title: impl Into<SharedString>,
- text: impl Into<SharedString>,
remove_if_exists: bool,
cx: &mut Context<ContextStore>,
) {
- if let Some(context_id) = self.includes_user_rules(&prompt_id) {
+ let context_id = self.next_context_id.post_inc();
+ let context = AgentContext::Rules(RulesContext {
+ prompt_id,
+ context_id,
+ });
+
+ if self.has_context(&context) {
if remove_if_exists {
- self.remove_context(context_id, cx);
+ self.remove_context(&context, cx);
}
} else {
- self.insert_user_rules(prompt_id, title, text, cx);
+ self.insert_context(context, cx);
}
}
- pub fn insert_user_rules(
- &mut self,
- prompt_id: UserPromptId,
- title: impl Into<SharedString>,
- text: impl Into<SharedString>,
- cx: &mut Context<ContextStore>,
- ) {
- let id = self.next_context_id.post_inc();
-
- self.user_rules.insert(prompt_id, id);
- self.context.push(AssistantContext::Rules(RulesContext {
- id,
- prompt_id,
- title: title.into(),
- text: text.into(),
- }));
- cx.notify();
- }
-
pub fn add_fetched_url(
&mut self,
url: String,
text: impl Into<SharedString>,
cx: &mut Context<ContextStore>,
) {
- if self.includes_url(&url).is_none() {
- self.insert_fetched_url(url, text, cx);
- }
- }
-
- fn insert_fetched_url(
- &mut self,
- url: String,
- text: impl Into<SharedString>,
- cx: &mut Context<ContextStore>,
- ) {
- let id = self.next_context_id.post_inc();
+ let context = AgentContext::FetchedUrl(FetchedUrlContext {
+ url: url.into(),
+ text: text.into(),
+ context_id: self.next_context_id.post_inc(),
+ });
- self.fetched_urls.insert(url.clone(), id);
- self.context
- .push(AssistantContext::FetchedUrl(FetchedUrlContext {
- id,
- url: url.into(),
- text: text.into(),
- }));
- cx.notify();
+ self.insert_context(context, cx);
}
pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
- let id = self.next_context_id.post_inc();
- self.context.push(AssistantContext::Image(ImageContext {
- id,
+ let context = AgentContext::Image(ImageContext {
original_image: image,
image_task,
- }));
- cx.notify();
- }
-
- pub fn wait_for_images(&self, cx: &App) -> Task<()> {
- let tasks = self
- .context
- .iter()
- .filter_map(|ctx| match ctx {
- AssistantContext::Image(ctx) => Some(ctx.image_task.clone()),
- _ => None,
- })
- .collect::<Vec<_>>();
-
- cx.spawn(async move |_cx| {
- join_all(tasks).await;
- })
+ context_id: self.next_context_id.post_inc(),
+ });
+ self.insert_context(context, cx);
}
pub fn add_selection(
@@ -481,45 +281,21 @@ impl ContextStore {
buffer: Entity<Buffer>,
range: Range<Anchor>,
cx: &mut Context<ContextStore>,
- ) -> Task<Result<()>> {
- cx.spawn(async move |this, cx| {
- let (line_range, context_buffer_task) = this.update(cx, |_, cx| {
- load_context_buffer_range(buffer, range.clone(), cx)
- })??;
-
- let context_buffer = context_buffer_task.await;
-
- this.update(cx, |this, cx| {
- this.insert_selection(context_buffer, range, line_range, cx)
- })?;
-
- anyhow::Ok(())
- })
- }
-
- fn insert_selection(
- &mut self,
- context_buffer: ContextBuffer,
- range: Range<Anchor>,
- line_range: Range<Point>,
- cx: &mut Context<Self>,
) {
- let id = self.next_context_id.post_inc();
- self.context
- .push(AssistantContext::Selection(SelectionContext {
- id,
- range,
- line_range,
- context_buffer,
- }));
- cx.notify();
+ let context_id = self.next_context_id.post_inc();
+ let context = AgentContext::Selection(SelectionContext {
+ buffer,
+ range,
+ context_id,
+ });
+ self.insert_context(context, cx);
}
- pub fn accept_suggested_context(
+ pub fn add_suggested_context(
&mut self,
suggested: &SuggestedContext,
cx: &mut Context<ContextStore>,
- ) -> Task<Result<()>> {
+ ) {
match suggested {
SuggestedContext::File {
buffer,
@@ -527,655 +303,183 @@ impl ContextStore {
name: _,
} => {
if let Some(buffer) = buffer.upgrade() {
- return self.add_file_from_buffer(buffer, cx);
+ let context_id = self.next_context_id.post_inc();
+ self.insert_context(AgentContext::File(FileContext { buffer, context_id }), cx);
};
}
SuggestedContext::Thread { thread, name: _ } => {
if let Some(thread) = thread.upgrade() {
- self.insert_thread(thread, cx);
- };
+ let context_id = self.next_context_id.post_inc();
+ self.insert_context(
+ AgentContext::Thread(ThreadContext { thread, context_id }),
+ cx,
+ );
+ }
}
}
- Task::ready(Ok(()))
}
- pub fn remove_context(&mut self, id: ContextId, cx: &mut Context<Self>) {
- let Some(ix) = self.context.iter().position(|context| context.id() == id) else {
- return;
- };
-
- match self.context.remove(ix) {
- AssistantContext::File(_) => {
- self.files.retain(|_, context_id| *context_id != id);
- }
- AssistantContext::Directory(_) => {
- self.directories.retain(|_, context_id| *context_id != id);
- }
- AssistantContext::Symbol(symbol) => {
- if let Some(symbols_in_path) =
- self.symbols_by_path.get_mut(&symbol.context_symbol.id.path)
- {
- symbols_in_path.retain(|s| {
- self.symbols
- .get(s)
- .map_or(false, |context_id| *context_id != id)
- });
- }
- self.symbol_buffers.remove(&symbol.context_symbol.id);
- self.symbols.retain(|_, context_id| *context_id != id);
- }
- AssistantContext::Selection(_) => {}
- AssistantContext::FetchedUrl(_) => {
- self.fetched_urls.retain(|_, context_id| *context_id != id);
- }
- AssistantContext::Thread(_) => {
- self.threads.retain(|_, context_id| *context_id != id);
- }
- AssistantContext::Rules(RulesContext { prompt_id, .. }) => {
- self.user_rules.remove(&prompt_id);
+ fn insert_context(&mut self, context: AgentContext, cx: &mut Context<Self>) -> bool {
+ match &context {
+ AgentContext::Thread(thread_context) => {
+ self.context_thread_ids
+ .insert(thread_context.thread.read(cx).id().clone());
+ self.start_summarizing_thread_if_needed(&thread_context.thread, cx);
}
- AssistantContext::Image(_) => {}
+ _ => {}
}
-
- cx.notify();
+ let inserted = self.context_set.insert(AgentContextKey(context));
+ if inserted {
+ cx.notify();
+ }
+ inserted
}
- /// Returns whether the buffer is already included directly in the context, or if it will be
- /// included in the context via a directory. Directory inclusion is based on paths rather than
- /// buffer IDs as the directory will be re-scanned.
- pub fn will_include_buffer(
- &self,
- buffer_id: BufferId,
- project_path: &ProjectPath,
- ) -> Option<FileInclusion> {
- if let Some(context_id) = self.files.get(&buffer_id) {
- return Some(FileInclusion::Direct(*context_id));
+ pub fn remove_context(&mut self, context: &AgentContext, cx: &mut Context<Self>) {
+ if self
+ .context_set
+ .shift_remove(AgentContextKey::ref_cast(context))
+ {
+ match context {
+ AgentContext::Thread(thread_context) => {
+ self.context_thread_ids
+ .remove(thread_context.thread.read(cx).id());
+ }
+ _ => {}
+ }
+ cx.notify();
}
+ }
- self.will_include_file_path_via_directory(project_path)
+ pub fn has_context(&mut self, context: &AgentContext) -> bool {
+ self.context_set
+ .contains(AgentContextKey::ref_cast(context))
}
/// Returns whether this file path is already included directly in the context, or if it will be
/// included in the context via a directory.
- pub fn will_include_file_path(
- &self,
- project_path: &ProjectPath,
- cx: &App,
- ) -> Option<FileInclusion> {
- if !self.files.is_empty() {
- let found_file_context = self.context.iter().find(|context| match &context {
- AssistantContext::File(file_context) => {
- let buffer = file_context.context_buffer.buffer.read(cx);
- if let Some(context_path) = buffer.project_path(cx) {
- &context_path == project_path
- } else {
- false
- }
- }
- _ => false,
- });
- if let Some(context) = found_file_context {
- return Some(FileInclusion::Direct(context.id()));
+ pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
+ let project = self.project.upgrade()?.read(cx);
+ self.context().find_map(|context| match context {
+ AgentContext::File(file_context) => FileInclusion::check_file(file_context, path, cx),
+ AgentContext::Directory(directory_context) => {
+ FileInclusion::check_directory(directory_context, path, project, cx)
}
- }
-
- self.will_include_file_path_via_directory(project_path)
+ _ => None,
+ })
}
- fn will_include_file_path_via_directory(
+ pub fn path_included_in_directory(
&self,
- project_path: &ProjectPath,
+ path: &ProjectPath,
+ cx: &App,
) -> Option<FileInclusion> {
- if self.directories.is_empty() {
- return None;
- }
-
- let mut path_buf = project_path.path.to_path_buf();
-
- while path_buf.pop() {
- // TODO: This isn't very efficient. Consider using a better representation of the
- // directories map.
- let directory_project_path = ProjectPath {
- worktree_id: project_path.worktree_id,
- path: path_buf.clone().into(),
- };
- if let Some(_) = self.directories.get(&directory_project_path) {
- return Some(FileInclusion::InDirectory(directory_project_path));
+ let project = self.project.upgrade()?.read(cx);
+ self.context().find_map(|context| match context {
+ AgentContext::Directory(directory_context) => {
+ FileInclusion::check_directory(directory_context, path, project, cx)
}
- }
-
- None
- }
-
- pub fn includes_directory(&self, project_path: &ProjectPath) -> Option<FileInclusion> {
- if let Some(context_id) = self.directories.get(project_path) {
- return Some(FileInclusion::Direct(*context_id));
- }
-
- self.will_include_file_path_via_directory(project_path)
- }
-
- pub fn included_symbol(&self, symbol_id: &ContextSymbolId) -> Option<ContextId> {
- self.symbols.get(symbol_id).copied()
- }
-
- pub fn included_symbols_by_path(&self) -> &HashMap<ProjectPath, Vec<ContextSymbolId>> {
- &self.symbols_by_path
- }
-
- pub fn buffer_for_symbol(&self, symbol_id: &ContextSymbolId) -> Option<Entity<Buffer>> {
- self.symbol_buffers.get(symbol_id).cloned()
+ _ => None,
+ })
}
- pub fn includes_thread(&self, thread_id: &ThreadId) -> Option<ContextId> {
- self.threads.get(thread_id).copied()
+ pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
+ self.context().any(|context| match context {
+ AgentContext::Symbol(context) => {
+ if context.symbol != symbol.name {
+ return false;
+ }
+ let buffer = context.buffer.read(cx);
+ let Some(context_path) = buffer.project_path(cx) else {
+ return false;
+ };
+ if context_path != symbol.path {
+ return false;
+ }
+ let context_range = context.range.to_point_utf16(&buffer.snapshot());
+ context_range.start == symbol.range.start.0
+ && context_range.end == symbol.range.end.0
+ }
+ _ => false,
+ })
}
- pub fn includes_user_rules(&self, prompt_id: &UserPromptId) -> Option<ContextId> {
- self.user_rules.get(prompt_id).copied()
+ pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
+ self.context_thread_ids.contains(thread_id)
}
- pub fn includes_url(&self, url: &str) -> Option<ContextId> {
- self.fetched_urls.get(url).copied()
+ pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
+ self.context_set
+ .contains(&RulesContext::lookup_key(prompt_id))
}
- /// Replaces the context that matches the ID of the new context, if any match.
- fn replace_context(&mut self, new_context: AssistantContext) {
- let id = new_context.id();
- for context in self.context.iter_mut() {
- if context.id() == id {
- *context = new_context;
- break;
- }
- }
+ pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
+ self.context_set
+ .contains(&FetchedUrlContext::lookup_key(url.into()))
}
pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
- self.context
- .iter()
+ self.context()
.filter_map(|context| match context {
- AssistantContext::File(file) => {
- let buffer = file.context_buffer.buffer.read(cx);
+ AgentContext::File(file) => {
+ let buffer = file.buffer.read(cx);
buffer.project_path(cx)
}
- AssistantContext::Directory(_)
- | AssistantContext::Symbol(_)
- | AssistantContext::Selection(_)
- | AssistantContext::FetchedUrl(_)
- | AssistantContext::Thread(_)
- | AssistantContext::Rules(_)
- | AssistantContext::Image(_) => None,
+ AgentContext::Directory(_)
+ | AgentContext::Symbol(_)
+ | AgentContext::Selection(_)
+ | AgentContext::FetchedUrl(_)
+ | AgentContext::Thread(_)
+ | AgentContext::Rules(_)
+ | AgentContext::Image(_) => None,
})
.collect()
}
- pub fn thread_ids(&self) -> HashSet<ThreadId> {
- self.threads.keys().cloned().collect()
+ pub fn thread_ids(&self) -> &HashSet<ThreadId> {
+ &self.context_thread_ids
}
}
pub enum FileInclusion {
- Direct(ContextId),
- InDirectory(ProjectPath),
-}
-
-fn make_context_symbol(
- context_buffer: ContextBuffer,
- path: ProjectPath,
- name: SharedString,
- range: Range<Anchor>,
- enclosing_range: Range<Anchor>,
-) -> ContextSymbol {
- ContextSymbol {
- id: ContextSymbolId { name, range, path },
- buffer_version: context_buffer.version,
- enclosing_range,
- buffer: context_buffer.buffer,
- text: context_buffer.text,
- }
+ Direct,
+ InDirectory { full_path: PathBuf },
}
-fn load_context_buffer_range(
- buffer: Entity<Buffer>,
- range: Range<Anchor>,
- cx: &App,
-) -> Result<(Range<Point>, Task<ContextBuffer>)> {
- let buffer_ref = buffer.read(cx);
- let id = buffer_ref.remote_id();
-
- let file = buffer_ref.file().context("context buffer missing path")?;
- let full_path = file.full_path(cx);
-
- // Important to collect version at the same time as content so that staleness logic is correct.
- let version = buffer_ref.version();
- let content = buffer_ref.text_for_range(range.clone()).collect::<Rope>();
- let line_range = range.to_point(&buffer_ref.snapshot());
-
- // Build the text on a background thread.
- let task = cx.background_spawn({
- let line_range = line_range.clone();
- async move {
- let text = to_fenced_codeblock(&full_path, content, Some(line_range));
- ContextBuffer {
- id,
- buffer,
- last_full_path: full_path.into(),
- version,
- text,
- }
- }
- });
-
- Ok((line_range, task))
-}
-
-fn load_context_buffer(buffer: Entity<Buffer>, cx: &App) -> Result<Task<ContextBuffer>> {
- let buffer_ref = buffer.read(cx);
- let id = buffer_ref.remote_id();
-
- let file = buffer_ref.file().context("context buffer missing path")?;
- let full_path = file.full_path(cx);
-
- // Important to collect version at the same time as content so that staleness logic is correct.
- let version = buffer_ref.version();
- let content = buffer_ref.as_rope().clone();
-
- // Build the text on a background thread.
- Ok(cx.background_spawn(async move {
- let text = to_fenced_codeblock(&full_path, content, None);
- ContextBuffer {
- id,
- buffer,
- last_full_path: full_path.into(),
- version,
- text,
- }
- }))
-}
-
-fn to_fenced_codeblock(
- path: &Path,
- content: Rope,
- line_range: Option<Range<Point>>,
-) -> SharedString {
- let line_range_text = line_range.map(|range| {
- if range.start.row == range.end.row {
- format!(":{}", range.start.row + 1)
+impl FileInclusion {
+ fn check_file(file_context: &FileContext, path: &ProjectPath, cx: &App) -> Option<Self> {
+ let file_path = file_context.buffer.read(cx).project_path(cx)?;
+ if path == &file_path {
+ Some(FileInclusion::Direct)
} else {
- format!(":{}-{}", range.start.row + 1, range.end.row + 1)
- }
- });
-
- let path_extension = path.extension().and_then(|ext| ext.to_str());
- let path_string = path.to_string_lossy();
- let capacity = 3
- + path_extension.map_or(0, |extension| extension.len() + 1)
- + path_string.len()
- + line_range_text.as_ref().map_or(0, |text| text.len())
- + 1
- + content.len()
- + 5;
- let mut buffer = String::with_capacity(capacity);
-
- buffer.push_str("```");
-
- if let Some(extension) = path_extension {
- buffer.push_str(extension);
- buffer.push(' ');
- }
- buffer.push_str(&path_string);
-
- if let Some(line_range_text) = line_range_text {
- buffer.push_str(&line_range_text);
- }
-
- buffer.push('\n');
- for chunk in content.chunks() {
- buffer.push_str(&chunk);
- }
-
- if !buffer.ends_with('\n') {
- buffer.push('\n');
- }
-
- buffer.push_str("```\n");
-
- debug_assert!(
- buffer.len() == capacity - 1 || buffer.len() == capacity,
- "to_fenced_codeblock calculated capacity of {}, but length was {}",
- capacity,
- buffer.len(),
- );
-
- buffer.into()
-}
-
-fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
- let mut files = Vec::new();
-
- for entry in worktree.child_entries(path) {
- if entry.is_dir() {
- files.extend(collect_files_in_path(worktree, &entry.path));
- } else if entry.is_file() {
- files.push(entry.path.clone());
- }
- }
-
- files
-}
-
-pub fn refresh_context_store_text(
- context_store: Entity<ContextStore>,
- changed_buffers: &HashSet<Entity<Buffer>>,
- cx: &App,
-) -> impl Future<Output = Vec<ContextId>> + use<> {
- let mut tasks = Vec::new();
-
- for context in &context_store.read(cx).context {
- let id = context.id();
-
- let task = maybe!({
- match context {
- AssistantContext::File(file_context) => {
- // TODO: Should refresh if the path has changed, as it's in the text.
- if changed_buffers.is_empty()
- || changed_buffers.contains(&file_context.context_buffer.buffer)
- {
- let context_store = context_store.clone();
- return refresh_file_text(context_store, file_context, cx);
- }
- }
- AssistantContext::Directory(directory_context) => {
- let directory_path = directory_context.project_path(cx)?;
- let should_refresh = directory_path.path != directory_context.last_path
- || changed_buffers.is_empty()
- || changed_buffers.iter().any(|buffer| {
- let Some(buffer_path) = buffer.read(cx).project_path(cx) else {
- return false;
- };
- buffer_path.starts_with(&directory_path)
- });
-
- if should_refresh {
- let context_store = context_store.clone();
- return refresh_directory_text(
- context_store,
- directory_context,
- directory_path,
- cx,
- );
- }
- }
- AssistantContext::Symbol(symbol_context) => {
- // TODO: Should refresh if the path has changed, as it's in the text.
- if changed_buffers.is_empty()
- || changed_buffers.contains(&symbol_context.context_symbol.buffer)
- {
- let context_store = context_store.clone();
- return refresh_symbol_text(context_store, symbol_context, cx);
- }
- }
- AssistantContext::Selection(selection_context) => {
- // TODO: Should refresh if the path has changed, as it's in the text.
- if changed_buffers.is_empty()
- || changed_buffers.contains(&selection_context.context_buffer.buffer)
- {
- let context_store = context_store.clone();
- return refresh_selection_text(context_store, selection_context, cx);
- }
- }
- AssistantContext::Thread(thread_context) => {
- if changed_buffers.is_empty() {
- let context_store = context_store.clone();
- return Some(refresh_thread_text(context_store, thread_context, cx));
- }
- }
- // Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
- // and doing the caching properly could be tricky (unless it's already handled by
- // the HttpClient?).
- AssistantContext::FetchedUrl(_) => {}
- AssistantContext::Rules(user_rules_context) => {
- let context_store = context_store.clone();
- return Some(refresh_user_rules(context_store, user_rules_context, cx));
- }
- AssistantContext::Image(_) => {}
- }
-
None
- });
-
- if let Some(task) = task {
- tasks.push(task.map(move |_| id));
}
}
- future::join_all(tasks)
-}
-
-fn refresh_file_text(
- context_store: Entity<ContextStore>,
- file_context: &FileContext,
- cx: &App,
-) -> Option<Task<()>> {
- let id = file_context.id;
- let task = refresh_context_buffer(&file_context.context_buffer, cx);
- if let Some(task) = task {
- Some(cx.spawn(async move |cx| {
- let context_buffer = task.await;
- context_store
- .update(cx, |context_store, _| {
- let new_file_context = FileContext { id, context_buffer };
- context_store.replace_context(AssistantContext::File(new_file_context));
- })
- .ok();
- }))
- } else {
- None
- }
-}
-
-fn refresh_directory_text(
- context_store: Entity<ContextStore>,
- directory_context: &DirectoryContext,
- directory_path: ProjectPath,
- cx: &App,
-) -> Option<Task<()>> {
- let mut stale = false;
- let futures = directory_context
- .context_buffers
- .iter()
- .map(|context_buffer| {
- if let Some(refresh_task) = refresh_context_buffer(context_buffer, cx) {
- stale = true;
- future::Either::Left(refresh_task)
+ fn check_directory(
+ directory_context: &DirectoryContext,
+ path: &ProjectPath,
+ project: &Project,
+ cx: &App,
+ ) -> Option<Self> {
+ let worktree = project
+ .worktree_for_entry(directory_context.entry_id, cx)?
+ .read(cx);
+ let entry = worktree.entry_for_id(directory_context.entry_id)?;
+ let directory_path = ProjectPath {
+ worktree_id: worktree.id(),
+ path: entry.path.clone(),
+ };
+ if path.starts_with(&directory_path) {
+ if path == &directory_path {
+ Some(FileInclusion::Direct)
} else {
- future::Either::Right(future::ready((*context_buffer).clone()))
- }
- })
- .collect::<Vec<_>>();
-
- if !stale {
- return None;
- }
-
- let context_buffers = future::join_all(futures);
-
- let id = directory_context.id;
- let worktree = directory_context.worktree.clone();
- let entry_id = directory_context.entry_id;
- let last_path = directory_path.path;
- Some(cx.spawn(async move |cx| {
- let context_buffers = context_buffers.await;
- context_store
- .update(cx, |context_store, _| {
- let new_directory_context = DirectoryContext {
- id,
- worktree,
- entry_id,
- last_path,
- context_buffers,
- };
- context_store.replace_context(AssistantContext::Directory(new_directory_context));
- })
- .ok();
- }))
-}
-
-fn refresh_symbol_text(
- context_store: Entity<ContextStore>,
- symbol_context: &SymbolContext,
- cx: &App,
-) -> Option<Task<()>> {
- let id = symbol_context.id;
- let task = refresh_context_symbol(&symbol_context.context_symbol, cx);
- if let Some(task) = task {
- Some(cx.spawn(async move |cx| {
- let context_symbol = task.await;
- context_store
- .update(cx, |context_store, _| {
- let new_symbol_context = SymbolContext { id, context_symbol };
- context_store.replace_context(AssistantContext::Symbol(new_symbol_context));
- })
- .ok();
- }))
- } else {
- None
- }
-}
-
-fn refresh_selection_text(
- context_store: Entity<ContextStore>,
- selection_context: &SelectionContext,
- cx: &App,
-) -> Option<Task<()>> {
- let id = selection_context.id;
- let range = selection_context.range.clone();
- let task = refresh_context_excerpt(&selection_context.context_buffer, range.clone(), cx);
- if let Some(task) = task {
- Some(cx.spawn(async move |cx| {
- let (line_range, context_buffer) = task.await;
- context_store
- .update(cx, |context_store, _| {
- let new_selection_context = SelectionContext {
- id,
- range,
- line_range,
- context_buffer,
- };
- context_store
- .replace_context(AssistantContext::Selection(new_selection_context));
+ Some(FileInclusion::InDirectory {
+ full_path: worktree.full_path(&entry.path),
})
- .ok();
- }))
- } else {
- None
- }
-}
-
-fn refresh_thread_text(
- context_store: Entity<ContextStore>,
- thread_context: &ThreadContext,
- cx: &App,
-) -> Task<()> {
- let id = thread_context.id;
- let thread = thread_context.thread.clone();
- cx.spawn(async move |cx| {
- context_store
- .update(cx, |context_store, cx| {
- let text = thread.read(cx).latest_detailed_summary_or_text();
- context_store.replace_context(AssistantContext::Thread(ThreadContext {
- id,
- thread,
- text,
- }));
- })
- .ok();
- })
-}
-
-fn refresh_user_rules(
- context_store: Entity<ContextStore>,
- user_rules_context: &RulesContext,
- cx: &App,
-) -> Task<()> {
- let id = user_rules_context.id;
- let prompt_id = user_rules_context.prompt_id;
- let Some(thread_store) = context_store.read(cx).thread_store.as_ref() else {
- return Task::ready(());
- };
- let Ok(load_task) = thread_store.read_with(cx, |thread_store, cx| {
- thread_store.load_rules(prompt_id, cx)
- }) else {
- return Task::ready(());
- };
- cx.spawn(async move |cx| {
- if let Ok((metadata, text)) = load_task.await {
- if let Some(title) = metadata.title.clone() {
- context_store
- .update(cx, |context_store, _cx| {
- context_store.replace_context(AssistantContext::Rules(RulesContext {
- id,
- prompt_id,
- title,
- text: text.into(),
- }));
- })
- .ok();
- return;
}
+ } else {
+ None
}
- context_store
- .update(cx, |context_store, cx| {
- context_store.remove_context(id, cx);
- })
- .ok();
- })
-}
-
-fn refresh_context_buffer(context_buffer: &ContextBuffer, cx: &App) -> Option<Task<ContextBuffer>> {
- let buffer = context_buffer.buffer.read(cx);
- if buffer.version.changed_since(&context_buffer.version) {
- load_context_buffer(context_buffer.buffer.clone(), cx).log_err()
- } else {
- None
- }
-}
-
-fn refresh_context_excerpt(
- context_buffer: &ContextBuffer,
- range: Range<Anchor>,
- cx: &App,
-) -> Option<impl Future<Output = (Range<Point>, ContextBuffer)> + use<>> {
- let buffer = context_buffer.buffer.read(cx);
- if buffer.version.changed_since(&context_buffer.version) {
- let (line_range, context_buffer_task) =
- load_context_buffer_range(context_buffer.buffer.clone(), range, cx).log_err()?;
- Some(context_buffer_task.map(move |context_buffer| (line_range, context_buffer)))
- } else {
- None
- }
-}
-
-fn refresh_context_symbol(
- context_symbol: &ContextSymbol,
- cx: &App,
-) -> Option<impl Future<Output = ContextSymbol> + use<>> {
- let buffer = context_symbol.buffer.read(cx);
- let project_path = buffer.project_path(cx)?;
- if buffer.version.changed_since(&context_symbol.buffer_version) {
- let (_line_range, context_buffer_task) = load_context_buffer_range(
- context_symbol.buffer.clone(),
- context_symbol.enclosing_range.clone(),
- cx,
- )
- .log_err()?;
- let name = context_symbol.id.name.clone();
- let range = context_symbol.id.range.clone();
- let enclosing_range = context_symbol.enclosing_range.clone();
- Some(context_buffer_task.map(move |context_buffer| {
- make_context_symbol(context_buffer, project_path, name, range, enclosing_range)
- }))
- } else {
- None
}
}
@@ -12,9 +12,9 @@ use itertools::Itertools;
use language::Buffer;
use project::ProjectItem;
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
-use workspace::{Workspace, notifications::NotifyResultExt};
+use workspace::Workspace;
-use crate::context::{ContextId, ContextKind};
+use crate::context::{AgentContext, ContextKind};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
use crate::thread::Thread;
@@ -32,6 +32,7 @@ pub struct ContextStrip {
focus_handle: FocusHandle,
suggest_context_kind: SuggestContextKind,
workspace: WeakEntity<Workspace>,
+ thread_store: Option<WeakEntity<ThreadStore>>,
_subscriptions: Vec<Subscription>,
focused_index: Option<usize>,
children_bounds: Option<Vec<Bounds<Pixels>>>,
@@ -73,12 +74,31 @@ impl ContextStrip {
focus_handle,
suggest_context_kind,
workspace,
+ thread_store,
_subscriptions: subscriptions,
focused_index: None,
children_bounds: None,
}
}
+ fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
+ if let Some(workspace) = self.workspace.upgrade() {
+ let project = workspace.read(cx).project().read(cx);
+ let prompt_store = self
+ .thread_store
+ .as_ref()
+ .and_then(|thread_store| thread_store.upgrade())
+ .and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
+ self.context_store
+ .read(cx)
+ .context()
+ .flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
+ .collect::<Vec<_>>()
+ } else {
+ Vec::new()
+ }
+ }
+
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
match self.suggest_context_kind {
SuggestContextKind::File => self.suggested_file(cx),
@@ -93,22 +113,19 @@ impl ContextStrip {
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
let active_buffer = active_buffer_entity.read(cx);
-
let project_path = active_buffer.project_path(cx)?;
if self
.context_store
.read(cx)
- .will_include_buffer(active_buffer.remote_id(), &project_path)
+ .file_path_included(&project_path, cx)
.is_some()
{
return None;
}
let file_name = active_buffer.file()?.file_name(cx);
-
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
-
Some(SuggestedContext::File {
name: file_name.to_string_lossy().into_owned().into(),
buffer: active_buffer_entity.downgrade(),
@@ -135,7 +152,6 @@ impl ContextStrip {
.context_store
.read(cx)
.includes_thread(active_thread.id())
- .is_some()
{
return None;
}
@@ -272,12 +288,12 @@ impl ContextStrip {
best.map(|(index, _, _)| index)
}
- fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
+ fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
- crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
+ crate::active_thread::open_context(context, workspace, window, cx);
}
fn remove_focused_context(
@@ -287,17 +303,17 @@ impl ContextStrip {
cx: &mut Context<Self>,
) {
if let Some(index) = self.focused_index {
- let mut is_empty = false;
+ let added_contexts = self.added_contexts(cx);
+ let Some(context) = added_contexts.get(index) else {
+ return;
+ };
self.context_store.update(cx, |this, cx| {
- if let Some(item) = this.context().get(index) {
- this.remove_context(item.id(), cx);
- }
-
- is_empty = this.context().is_empty();
+ this.remove_context(&context.context, cx);
});
- if is_empty {
+ let is_now_empty = added_contexts.len() == 1;
+ if is_now_empty {
cx.emit(ContextStripEvent::BlurredEmpty);
} else {
self.focused_index = Some(index.saturating_sub(1));
@@ -306,49 +322,28 @@ impl ContextStrip {
}
}
- fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
+ fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
// We only suggest one item after the actual context
- self.focused_index == Some(context.len())
+ self.focused_index == Some(added_contexts.len())
}
fn accept_suggested_context(
&mut self,
_: &AcceptSuggestedContext,
- window: &mut Window,
+ _window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(suggested) = self.suggested_context(cx) {
- let context_store = self.context_store.read(cx);
-
- if self.is_suggested_focused(context_store.context()) {
- self.add_suggested_context(&suggested, window, cx);
+ if self.is_suggested_focused(&self.added_contexts(cx)) {
+ self.add_suggested_context(&suggested, cx);
}
}
}
- fn add_suggested_context(
- &mut self,
- suggested: &SuggestedContext,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let task = self.context_store.update(cx, |context_store, cx| {
- context_store.accept_suggested_context(&suggested, cx)
+ fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
+ self.context_store.update(cx, |context_store, cx| {
+ context_store.add_suggested_context(&suggested, cx)
});
-
- cx.spawn_in(window, async move |this, cx| {
- match task.await.notify_async_err(cx) {
- None => {}
- Some(()) => {
- if let Some(this) = this.upgrade() {
- this.update(cx, |_, cx| cx.notify())?;
- }
- }
- }
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
-
cx.notify();
}
}
@@ -361,17 +356,10 @@ impl Focusable for ContextStrip {
impl Render for ContextStrip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let context_store = self.context_store.read(cx);
- let context = context_store.context();
let context_picker = self.context_picker.clone();
let focus_handle = self.focus_handle.clone();
- let suggested_context = self.suggested_context(cx);
-
- let added_contexts = context
- .iter()
- .map(|c| AddedContext::new(c, cx))
- .collect::<Vec<_>>();
+ let added_contexts = self.added_contexts(cx);
let dupe_names = added_contexts
.iter()
.map(|c| c.name.clone())
@@ -380,6 +368,14 @@ impl Render for ContextStrip {
.filter(|(a, b)| a == b)
.map(|(a, _)| a)
.collect::<HashSet<SharedString>>();
+ let no_added_context = added_contexts.is_empty();
+
+ let suggested_context = self.suggested_context(cx).map(|suggested_context| {
+ (
+ suggested_context,
+ self.is_suggested_focused(&added_contexts),
+ )
+ });
h_flex()
.flex_wrap()
@@ -436,7 +432,7 @@ impl Render for ContextStrip {
})
.with_handle(self.context_picker_menu_handle.clone()),
)
- .when(context.is_empty() && suggested_context.is_none(), {
+ .when(no_added_context && suggested_context.is_none(), {
|parent| {
parent.child(
h_flex()
@@ -466,16 +462,17 @@ impl Render for ContextStrip {
.enumerate()
.map(|(i, added_context)| {
let name = added_context.name.clone();
- let id = added_context.id;
+ let context = added_context.context.clone();
ContextPill::added(
added_context,
dupe_names.contains(&name),
self.focused_index == Some(i),
Some({
+ let context = context.clone();
let context_store = self.context_store.clone();
Rc::new(cx.listener(move |_this, _event, _window, cx| {
context_store.update(cx, |this, cx| {
- this.remove_context(id, cx);
+ this.remove_context(&context, cx);
});
cx.notify();
}))
@@ -484,7 +481,7 @@ impl Render for ContextStrip {
.on_click({
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
if event.down.click_count > 1 {
- this.open_context(id, window, cx);
+ this.open_context(&context, window, cx);
} else {
this.focused_index = Some(i);
}
@@ -493,22 +490,22 @@ impl Render for ContextStrip {
})
}),
)
- .when_some(suggested_context, |el, suggested| {
+ .when_some(suggested_context, |el, (suggested, focused)| {
el.child(
ContextPill::suggested(
suggested.name().clone(),
suggested.icon_path(),
suggested.kind(),
- self.is_suggested_focused(&context),
+ focused,
)
.on_click(Rc::new(cx.listener(
- move |this, _event, window, cx| {
- this.add_suggested_context(&suggested, window, cx);
+ move |this, _event, _window, cx| {
+ this.add_suggested_context(&suggested, cx);
},
))),
)
})
- .when(!context.is_empty(), {
+ .when(!no_added_context, {
move |parent| {
parent.child(
IconButton::new("remove-all-context", IconName::Eraser)
@@ -534,6 +531,7 @@ impl Render for ContextStrip {
)
}
})
+ .into_any()
}
}
@@ -51,7 +51,10 @@ impl HistoryStore {
return history_entries;
}
- for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
+ for thread in self
+ .thread_store
+ .update(cx, |this, _cx| this.reverse_chronological_threads())
+ {
history_entries.push(HistoryEntry::Thread(thread));
}
@@ -32,6 +32,7 @@ use project::LspAction;
use project::Project;
use project::{CodeAction, ProjectTransaction};
use prompt_store::PromptBuilder;
+use prompt_store::PromptStore;
use settings::{Settings, SettingsStore};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
@@ -245,9 +246,13 @@ impl InlineAssistant {
.map_or(false, |model| model.provider.is_authenticated(cx))
};
- let thread_store = workspace
+ let assistant_panel = workspace
.panel::<AssistantPanel>(cx)
- .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
+ .map(|assistant_panel| assistant_panel.read(cx));
+ let prompt_store = assistant_panel
+ .and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
+ let thread_store =
+ assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
let handle_assist =
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
@@ -257,6 +262,7 @@ impl InlineAssistant {
&active_editor,
cx.entity().downgrade(),
workspace.project().downgrade(),
+ prompt_store,
thread_store,
window,
cx,
@@ -269,6 +275,7 @@ impl InlineAssistant {
&active_terminal,
cx.entity().downgrade(),
workspace.project().downgrade(),
+ prompt_store,
thread_store,
window,
cx,
@@ -323,6 +330,7 @@ impl InlineAssistant {
editor: &Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
+ prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -437,6 +445,8 @@ impl InlineAssistant {
range.clone(),
None,
context_store.clone(),
+ project.clone(),
+ prompt_store.clone(),
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@@ -525,6 +535,7 @@ impl InlineAssistant {
initial_transaction_id: Option<TransactionId>,
focus: bool,
workspace: Entity<Workspace>,
+ prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -543,7 +554,7 @@ impl InlineAssistant {
}
let project = workspace.read(cx).project().downgrade();
- let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
+ let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
let codegen = cx.new(|cx| {
BufferCodegen::new(
@@ -551,6 +562,8 @@ impl InlineAssistant {
range.clone(),
initial_transaction_id,
context_store.clone(),
+ project,
+ prompt_store,
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
let editor = self.editor.clone();
let workspace = self.workspace.clone();
let thread_store = self.thread_store.clone();
+ let prompt_store = PromptStore::global(cx);
window.spawn(cx, async move |cx| {
let workspace = workspace.upgrade().context("workspace was released")?;
let editor = editor.upgrade().context("editor was released")?;
@@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
})?
.context("invalid range")?;
+ let prompt_store = prompt_store.await.ok();
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
let assist_id = assistant.suggest_assist(
&editor,
@@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
None,
true,
workspace,
+ prompt_store,
thread_store,
window,
cx,
@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
-use crate::context::{AssistantContext, format_context_as_string};
+use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
@@ -13,6 +13,8 @@ use editor::{
};
use file_icons::FileIcons;
use fs::Fs;
+use futures::future::Shared;
+use futures::{FutureExt as _, future};
use gpui::{
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
@@ -22,6 +24,7 @@ use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelReques
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
+use prompt_store::PromptStore;
use settings::Settings;
use std::time::Duration;
use theme::ThemeSettings;
@@ -31,7 +34,7 @@ use workspace::Workspace;
use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
-use crate::context_store::{ContextStore, refresh_context_store_text};
+use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::thread::{Thread, TokenUsageRatio};
@@ -49,9 +52,12 @@ pub struct MessageEditor {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
context_store: Entity<ContextStore>,
+ prompt_store: Option<Entity<PromptStore>>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>,
+ last_loaded_context: Option<ContextLoadResult>,
+ context_load_task: Option<Shared<Task<()>>>,
profile_selector: Entity<ProfileSelector>,
edits_expanded: bool,
editor_is_expanded: bool,
@@ -68,6 +74,7 @@ impl MessageEditor {
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
+ prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
thread: Entity<Thread>,
window: &mut Window,
@@ -135,13 +142,11 @@ impl MessageEditor {
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
cx.subscribe(&editor, |this, _, event, cx| match event {
- EditorEvent::BufferEdited => {
- this.message_or_context_changed(true, cx);
- }
+ EditorEvent::BufferEdited => this.handle_message_changed(cx),
_ => {}
}),
cx.observe(&context_store, |this, _, cx| {
- this.message_or_context_changed(false, cx);
+ this.handle_context_changed(cx)
}),
];
@@ -152,8 +157,11 @@ impl MessageEditor {
incompatible_tools_state: incompatible_tools.clone(),
workspace,
context_store,
+ prompt_store,
context_strip,
context_picker_menu_handle,
+ context_load_task: None,
+ last_loaded_context: None,
model_selector: cx.new(|cx| {
AssistantModelSelector::new(
fs.clone(),
@@ -175,6 +183,10 @@ impl MessageEditor {
}
}
+ pub fn context_store(&self) -> &Entity<ContextStore> {
+ &self.context_store
+ }
+
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
cx.notify();
}
@@ -214,6 +226,7 @@ impl MessageEditor {
) {
self.context_picker_menu_handle.toggle(window, cx);
}
+
pub fn remove_all_context(
&mut self,
_: &RemoveAllContext,
@@ -270,57 +283,44 @@ impl MessageEditor {
self.last_estimated_token_count.take();
cx.emit(MessageEditorEvent::EstimatedTokenCount);
- let refresh_task =
- refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
- let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
-
let thread = self.thread.clone();
- let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
+ let context_task = self.wait_for_context(cx);
let window_handle = window.window_handle();
- cx.spawn(async move |this, cx| {
- let checkpoint = checkpoint.await.ok();
- refresh_task.await;
- wait_for_images.await;
+ cx.spawn(async move |_this, cx| {
+ let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
+ let loaded_context = loaded_context.unwrap_or_default();
thread
.update(cx, |thread, cx| {
- let context = context_store.read(cx).context().clone();
- thread.insert_user_message(user_message, context, checkpoint, cx);
+ thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
})
.log_err();
- context_store
- .update(cx, |context_store, cx| {
- let excerpt_ids = context_store
- .context()
- .iter()
- .filter(|ctx| {
- matches!(
- ctx,
- AssistantContext::Selection(_) | AssistantContext::Image(_)
- )
- })
- .map(|ctx| ctx.id())
- .collect::<Vec<_>>();
-
- for id in excerpt_ids {
- context_store.remove_context(id, cx);
- }
+ thread
+ .update(cx, |thread, cx| {
+ thread.advance_prompt_id();
+ thread.send_to_model(model, Some(window_handle), cx);
})
.log_err();
+ })
+ .detach();
+ }
+ fn wait_for_summaries(&mut self, cx: &mut Context<Self>) -> Task<()> {
+ let context_store = self.context_store.clone();
+ cx.spawn(async move |this, cx| {
if let Some(wait_for_summaries) = context_store
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
- .log_err()
+ .ok()
{
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = true;
cx.notify();
})
- .log_err();
+ .ok();
wait_for_summaries.await;
@@ -328,18 +328,9 @@ impl MessageEditor {
this.waiting_for_summaries_to_send = false;
cx.notify();
})
- .log_err();
+ .ok();
}
-
- // Send to model after summaries are done
- thread
- .update(cx, |thread, cx| {
- thread.advance_prompt_id();
- thread.send_to_model(model, Some(window_handle), cx);
- })
- .log_err();
})
- .detach();
}
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -1015,6 +1006,49 @@ impl MessageEditor {
self.update_token_count_task.is_some()
}
+ fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
+ self.message_or_context_changed(true, cx);
+ }
+
+ fn handle_context_changed(&mut self, cx: &mut Context<Self>) {
+ let summaries_task = self.wait_for_summaries(cx);
+ let load_task = cx.spawn(async move |this, cx| {
+ // Waits for detailed summaries before `load_context`, as it directly reads these from
+ // the thread. TODO: Would be cleaner to have context loading await on summarization.
+ summaries_task.await;
+ let Ok(load_task) = this.update(cx, |this, cx| {
+ let new_context = this.context_store.read_with(cx, |context_store, cx| {
+ context_store.new_context_for_thread(this.thread.read(cx))
+ });
+ load_context(new_context, &this.project, &this.prompt_store, cx)
+ }) else {
+ return;
+ };
+ let result = load_task.await;
+ this.update(cx, |this, cx| {
+ this.last_loaded_context = Some(result);
+ this.context_load_task = None;
+ this.message_or_context_changed(false, cx);
+ })
+ .ok();
+ });
+ // Replace existing load task, if any, causing it to be cancelled.
+ self.context_load_task = Some(load_task.shared());
+ }
+
+ fn wait_for_context(&self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
+ if let Some(context_load_task) = self.context_load_task.clone() {
+ cx.spawn(async move |this, cx| {
+ context_load_task.await;
+ this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
+ .ok()
+ .flatten()
+ })
+ } else {
+ Task::ready(self.last_loaded_context.clone())
+ }
+ }
+
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
@@ -1024,9 +1058,7 @@ impl MessageEditor {
return;
};
- let context_store = self.context_store.clone();
let editor = self.editor.clone();
- let thread = self.thread.clone();
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
if debounce {
@@ -1035,27 +1067,33 @@ impl MessageEditor {
.await;
}
- let token_count = if let Some(task) = cx.update(|cx| {
- let context = context_store.read(cx).context().iter();
- let new_context = thread.read(cx).filter_new_context(context);
- let context_text =
- format_context_as_string(new_context, cx).unwrap_or(String::new());
+ let token_count = if let Some(task) = this.update(cx, |this, cx| {
+ let loaded_context = this
+ .last_loaded_context
+ .as_ref()
+ .map(|context_load_result| &context_load_result.loaded_context);
let message_text = editor.read(cx).text(cx);
- let content = context_text + &message_text;
-
- if content.is_empty() {
+ if message_text.is_empty()
+ && loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
+ {
return None;
}
+ let mut request_message = LanguageModelRequestMessage {
+ role: language_model::Role::User,
+ content: Vec::new(),
+ cache: false,
+ };
+
+ if let Some(loaded_context) = loaded_context {
+ loaded_context.add_to_request_message(&mut request_message);
+ }
+
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
- messages: vec![LanguageModelRequestMessage {
- role: language_model::Role::User,
- content: vec![content.into()],
- cache: false,
- }],
+ messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,
@@ -32,7 +32,7 @@ impl TerminalCodegen {
}
}
- pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
+ pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
@@ -45,6 +45,7 @@ impl TerminalCodegen {
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(async move |this, cx| {
+ let prompt = prompt_task.await;
let model_telemetry_id = model.telemetry_id();
let model_provider_id = model.provider_id();
let response = model.stream_completion_text(prompt, &cx).await;
@@ -1,4 +1,4 @@
-use crate::context::attach_context_to_message;
+use crate::context::load_context;
use crate::context_store::ContextStore;
use crate::inline_prompt_editor::{
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
@@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
use collections::{HashMap, VecDeque};
use editor::{MultiBuffer, actions::SelectAll};
use fs::Fs;
-use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
+use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use project::Project;
-use prompt_store::PromptBuilder;
+use prompt_store::{PromptBuilder, PromptStore};
use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::TerminalView;
@@ -69,6 +69,7 @@ impl TerminalInlineAssistant {
terminal_view: &Entity<TerminalView>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
+ prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -109,6 +110,7 @@ impl TerminalInlineAssistant {
prompt_editor,
workspace.clone(),
context_store,
+ prompt_store,
window,
cx,
);
@@ -196,11 +198,11 @@ impl TerminalInlineAssistant {
.log_err();
let codegen = assist.codegen.clone();
- let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
+ let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
return;
};
- codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+ codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
}
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
@@ -217,7 +219,7 @@ impl TerminalInlineAssistant {
&self,
assist_id: TerminalInlineAssistId,
cx: &mut App,
- ) -> Result<LanguageModelRequest> {
+ ) -> Result<Task<LanguageModelRequest>> {
let assist = self.assists.get(&assist_id).context("invalid assist")?;
let shell = std::env::var("SHELL").ok();
@@ -246,28 +248,40 @@ impl TerminalInlineAssistant {
&latest_output,
)?;
- let mut request_message = LanguageModelRequestMessage {
- role: Role::User,
- content: vec![],
- cache: false,
- };
-
- attach_context_to_message(
- &mut request_message,
- assist.context_store.read(cx).context().iter(),
- cx,
- );
-
- request_message.content.push(prompt.into());
-
- Ok(LanguageModelRequest {
- thread_id: None,
- prompt_id: None,
- messages: vec![request_message],
- tools: Vec::new(),
- stop: Vec::new(),
- temperature: None,
- })
+ let contexts = assist
+ .context_store
+ .read(cx)
+ .context()
+ .cloned()
+ .collect::<Vec<_>>();
+ let context_load_task = assist.workspace.update(cx, |workspace, cx| {
+ let project = workspace.project();
+ load_context(contexts, project, &assist.prompt_store, cx)
+ })?;
+
+ Ok(cx.background_spawn(async move {
+ let mut request_message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![],
+ cache: false,
+ };
+
+ context_load_task
+ .await
+ .loaded_context
+ .add_to_request_message(&mut request_message);
+
+ request_message.content.push(prompt.into());
+
+ LanguageModelRequest {
+ thread_id: None,
+ prompt_id: None,
+ messages: vec![request_message],
+ tools: Vec::new(),
+ stop: Vec::new(),
+ temperature: None,
+ }
+ }))
}
fn finish_assist(
@@ -380,6 +394,7 @@ struct TerminalInlineAssist {
codegen: Entity<TerminalCodegen>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
+ prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@@ -390,6 +405,7 @@ impl TerminalInlineAssist {
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
+ prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut App,
) -> Self {
@@ -400,6 +416,7 @@ impl TerminalInlineAssist {
codegen: codegen.clone(),
workspace: workspace.clone(),
context_store,
+ prompt_store,
_subscriptions: vec![
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
TerminalInlineAssistant::update_global(cx, |this, cx| {
@@ -8,7 +8,7 @@ use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
-use collections::{BTreeMap, HashMap};
+use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
@@ -18,9 +18,9 @@ use gpui::{
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
- LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
- LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
+ LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+ LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
};
@@ -35,7 +35,7 @@ use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
-use crate::context::{AssistantContext, ContextId, format_context_as_string};
+use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse, SharedProjectContext,
@@ -98,8 +98,7 @@ pub struct Message {
pub id: MessageId,
pub role: Role,
pub segments: Vec<MessageSegment>,
- pub context: String,
- pub images: Vec<LanguageModelImage>,
+ pub loaded_context: LoadedContext,
}
impl Message {
@@ -138,8 +137,8 @@ impl Message {
pub fn to_string(&self) -> String {
let mut result = String::new();
- if !self.context.is_empty() {
- result.push_str(&self.context);
+ if !self.loaded_context.text.is_empty() {
+ result.push_str(&self.loaded_context.text);
}
for segment in &self.segments {
@@ -294,8 +293,6 @@ pub struct Thread {
messages: Vec<Message>,
next_message_id: MessageId,
last_prompt_id: PromptId,
- context: BTreeMap<ContextId, AssistantContext>,
- context_by_message: HashMap<MessageId, Vec<ContextId>>,
project_context: SharedProjectContext,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
@@ -345,8 +342,6 @@ impl Thread {
messages: Vec::new(),
next_message_id: MessageId(0),
last_prompt_id: PromptId::new(),
- context: BTreeMap::default(),
- context_by_message: HashMap::default(),
project_context: system_prompt,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
@@ -418,14 +413,15 @@ impl Thread {
}
})
.collect(),
- context: message.context,
- images: Vec::new(),
+ loaded_context: LoadedContext {
+ contexts: Vec::new(),
+ text: message.context,
+ images: Vec::new(),
+ },
})
.collect(),
next_message_id,
last_prompt_id: PromptId::new(),
- context: BTreeMap::default(),
- context_by_message: HashMap::default(),
project_context,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
@@ -660,21 +656,17 @@ impl Thread {
return;
};
for deleted_message in self.messages.drain(message_ix..) {
- self.context_by_message.remove(&deleted_message.id);
self.checkpoints_by_message.remove(&deleted_message.id);
}
cx.notify();
}
- pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
- self.context_by_message
- .get(&id)
+ pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
+ self.messages
+ .iter()
+ .find(|message| message.id == id)
.into_iter()
- .flat_map(|context| {
- context
- .iter()
- .filter_map(|context_id| self.context.get(&context_id))
- })
+ .flat_map(|message| message.loaded_context.contexts.iter())
}
pub fn is_turn_end(&self, ix: usize) -> bool {
@@ -736,91 +728,27 @@ impl Thread {
self.tool_use.tool_result_card(id).cloned()
}
- /// Filter out contexts that have already been included in previous messages
- pub fn filter_new_context<'a>(
- &self,
- context: impl Iterator<Item = &'a AssistantContext>,
- ) -> impl Iterator<Item = &'a AssistantContext> {
- context.filter(|ctx| self.is_context_new(ctx))
- }
-
- fn is_context_new(&self, context: &AssistantContext) -> bool {
- !self.context.contains_key(&context.id())
- }
-
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
- context: Vec<AssistantContext>,
+ loaded_context: ContextLoadResult,
git_checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
- let text = text.into();
-
- let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
-
- let new_context: Vec<_> = context
- .into_iter()
- .filter(|ctx| self.is_context_new(ctx))
- .collect();
-
- if !new_context.is_empty() {
- if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
- if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
- message.context = context_string;
- }
- }
-
- if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
- message.images = new_context
- .iter()
- .filter_map(|context| {
- if let AssistantContext::Image(image_context) = context {
- image_context.image_task.clone().now_or_never().flatten()
- } else {
- None
- }
- })
- .collect::<Vec<_>>();
- }
-
+ if !loaded_context.referenced_buffers.is_empty() {
self.action_log.update(cx, |log, cx| {
- // Track all buffers added as context
- for ctx in &new_context {
- match ctx {
- AssistantContext::File(file_ctx) => {
- log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
- }
- AssistantContext::Directory(dir_ctx) => {
- for context_buffer in &dir_ctx.context_buffers {
- log.track_buffer(context_buffer.buffer.clone(), cx);
- }
- }
- AssistantContext::Symbol(symbol_ctx) => {
- log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
- }
- AssistantContext::Selection(selection_context) => {
- log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
- }
- AssistantContext::FetchedUrl(_)
- | AssistantContext::Thread(_)
- | AssistantContext::Rules(_)
- | AssistantContext::Image(_) => {}
- }
+ for buffer in loaded_context.referenced_buffers {
+ log.track_buffer(buffer, cx);
}
});
}
- let context_ids = new_context
- .iter()
- .map(|context| context.id())
- .collect::<Vec<_>>();
- self.context.extend(
- new_context
- .into_iter()
- .map(|context| (context.id(), context)),
+ let message_id = self.insert_message(
+ Role::User,
+ vec![MessageSegment::Text(text.into())],
+ loaded_context.loaded_context,
+ cx,
);
- self.context_by_message.insert(message_id, context_ids);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
@@ -834,10 +762,19 @@ impl Thread {
message_id
}
+ pub fn insert_assistant_message(
+ &mut self,
+ segments: Vec<MessageSegment>,
+ cx: &mut Context<Self>,
+ ) -> MessageId {
+ self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
+ }
+
pub fn insert_message(
&mut self,
role: Role,
segments: Vec<MessageSegment>,
+ loaded_context: LoadedContext,
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
@@ -845,8 +782,7 @@ impl Thread {
id,
role,
segments,
- context: String::new(),
- images: Vec::new(),
+ loaded_context,
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
@@ -875,7 +811,6 @@ impl Thread {
return false;
};
self.messages.remove(index);
- self.context_by_message.remove(&id);
self.touch_updated_at();
cx.emit(ThreadEvent::MessageDeleted(id));
true
@@ -962,7 +897,7 @@ impl Thread {
content: tool_result.content.clone(),
})
.collect(),
- context: message.context.clone(),
+ context: message.loaded_context.text.clone(),
})
.collect(),
initial_project_snapshot,
@@ -1080,26 +1015,9 @@ impl Thread {
cache: false,
};
- if !message.context.is_empty() {
- request_message
- .content
- .push(MessageContent::Text(message.context.to_string()));
- }
-
- if !message.images.is_empty() {
- // Some providers only support image parts after an initial text part
- if request_message.content.is_empty() {
- request_message
- .content
- .push(MessageContent::Text("Images attached by user:".to_string()));
- }
-
- for image in &message.images {
- request_message
- .content
- .push(MessageContent::Image(image.clone()))
- }
- }
+ message
+ .loaded_context
+ .add_to_request_message(&mut request_message);
for segment in &message.segments {
match segment {
@@ -1301,11 +1219,11 @@ impl Thread {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
- request_assistant_message_id = Some(thread.insert_message(
- Role::Assistant,
- vec![MessageSegment::Text(String::new())],
- cx,
- ));
+ request_assistant_message_id =
+ Some(thread.insert_assistant_message(
+ vec![MessageSegment::Text(String::new())],
+ cx,
+ ));
}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@@ -1334,11 +1252,11 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
- request_assistant_message_id = Some(thread.insert_message(
- Role::Assistant,
- vec![MessageSegment::Text(chunk.to_string())],
- cx,
- ));
+ request_assistant_message_id =
+ Some(thread.insert_assistant_message(
+ vec![MessageSegment::Text(chunk.to_string())],
+ cx,
+ ));
};
}
}
@@ -1361,14 +1279,14 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
- request_assistant_message_id = Some(thread.insert_message(
- Role::Assistant,
- vec![MessageSegment::Thinking {
- text: chunk.to_string(),
- signature,
- }],
- cx,
- ));
+ request_assistant_message_id =
+ Some(thread.insert_assistant_message(
+ vec![MessageSegment::Thinking {
+ text: chunk.to_string(),
+ signature,
+ }],
+ cx,
+ ));
};
}
}
@@ -1376,7 +1294,7 @@ impl Thread {
let last_assistant_message_id = request_assistant_message_id
.unwrap_or_else(|| {
let new_assistant_message_id =
- thread.insert_message(Role::Assistant, vec![], cx);
+ thread.insert_assistant_message(vec![], cx);
request_assistant_message_id =
Some(new_assistant_message_id);
new_assistant_message_id
@@ -2097,8 +2015,16 @@ impl Thread {
}
)?;
- if !message.context.is_empty() {
- writeln!(markdown, "{}", message.context)?;
+ if !message.loaded_context.text.is_empty() {
+ writeln!(markdown, "{}", message.loaded_context.text)?;
+ }
+
+ if !message.loaded_context.images.is_empty() {
+ writeln!(
+ markdown,
+ "\n{} images attached as context.\n",
+ message.loaded_context.images.len()
+ )?;
}
for segment in &message.segments {
@@ -2373,7 +2299,7 @@ struct PendingCompletion {
#[cfg(test)]
mod tests {
use super::*;
- use crate::{ThreadStore, context_store::ContextStore, thread_store};
+ use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use context_server::ContextServerSettings;
use editor::EditorSettings;
@@ -2404,12 +2330,14 @@ mod tests {
.await
.unwrap();
- let context =
- context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
+ let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
+ let loaded_context = cx
+ .update(|cx| load_context(vec![context], &project, &None, cx))
+ .await;
// Insert user message with context
let message_id = thread.update(cx, |thread, cx| {
- thread.insert_user_message("Please explain this code", vec![context], None, cx)
+ thread.insert_user_message("Please explain this code", loaded_context, None, cx)
});
// Check content and context in message object
@@ -2443,7 +2371,7 @@ fn main() {{
message.segments[0],
MessageSegment::Text("Please explain this code".to_string())
);
- assert_eq!(message.context, expected_context);
+ assert_eq!(message.loaded_context.text, expected_context);
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2470,48 +2398,50 @@ fn main() {{
let (_, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
- // Open files individually
+ // First message with context 1
add_file_to_context(&project, &context_store, "test/file1.rs", cx)
.await
.unwrap();
- add_file_to_context(&project, &context_store, "test/file2.rs", cx)
- .await
- .unwrap();
- add_file_to_context(&project, &context_store, "test/file3.rs", cx)
- .await
- .unwrap();
-
- // Get the context objects
- let contexts = context_store.update(cx, |store, _| store.context().clone());
- assert_eq!(contexts.len(), 3);
-
- // First message with context 1
+ let new_contexts = context_store.update(cx, |store, cx| {
+ store.new_context_for_thread(thread.read(cx))
+ });
+ assert_eq!(new_contexts.len(), 1);
+ let loaded_context = cx
+ .update(|cx| load_context(new_contexts, &project, &None, cx))
+ .await;
let message1_id = thread.update(cx, |thread, cx| {
- thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
+ thread.insert_user_message("Message 1", loaded_context, None, cx)
});
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
+ add_file_to_context(&project, &context_store, "test/file2.rs", cx)
+ .await
+ .unwrap();
+ let new_contexts = context_store.update(cx, |store, cx| {
+ store.new_context_for_thread(thread.read(cx))
+ });
+ assert_eq!(new_contexts.len(), 1);
+ let loaded_context = cx
+ .update(|cx| load_context(new_contexts, &project, &None, cx))
+ .await;
let message2_id = thread.update(cx, |thread, cx| {
- thread.insert_user_message(
- "Message 2",
- vec![contexts[0].clone(), contexts[1].clone()],
- None,
- cx,
- )
+ thread.insert_user_message("Message 2", loaded_context, None, cx)
});
// Third message with all three contexts (contexts 1 and 2 should be skipped)
+ //
+ add_file_to_context(&project, &context_store, "test/file3.rs", cx)
+ .await
+ .unwrap();
+ let new_contexts = context_store.update(cx, |store, cx| {
+ store.new_context_for_thread(thread.read(cx))
+ });
+ assert_eq!(new_contexts.len(), 1);
+ let loaded_context = cx
+ .update(|cx| load_context(new_contexts, &project, &None, cx))
+ .await;
let message3_id = thread.update(cx, |thread, cx| {
- thread.insert_user_message(
- "Message 3",
- vec![
- contexts[0].clone(),
- contexts[1].clone(),
- contexts[2].clone(),
- ],
- None,
- cx,
- )
+ thread.insert_user_message("Message 3", loaded_context, None, cx)
});
// Check what contexts are included in each message
@@ -2524,16 +2454,16 @@ fn main() {{
});
// First message should include context 1
- assert!(message1.context.contains("file1.rs"));
+ assert!(message1.loaded_context.text.contains("file1.rs"));
// Second message should include only context 2 (not 1)
- assert!(!message2.context.contains("file1.rs"));
- assert!(message2.context.contains("file2.rs"));
+ assert!(!message2.loaded_context.text.contains("file1.rs"));
+ assert!(message2.loaded_context.text.contains("file2.rs"));
// Third message should include only context 3 (not 1 or 2)
- assert!(!message3.context.contains("file1.rs"));
- assert!(!message3.context.contains("file2.rs"));
- assert!(message3.context.contains("file3.rs"));
+ assert!(!message3.loaded_context.text.contains("file1.rs"));
+ assert!(!message3.loaded_context.text.contains("file2.rs"));
+ assert!(message3.loaded_context.text.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2570,7 +2500,12 @@ fn main() {{
// Insert user message without any context (empty context vector)
let message_id = thread.update(cx, |thread, cx| {
- thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
+ thread.insert_user_message(
+ "What is the best way to learn Rust?",
+ ContextLoadResult::default(),
+ None,
+ cx,
+ )
});
// Check content and context in message object
@@ -2583,7 +2518,7 @@ fn main() {{
message.segments[0],
MessageSegment::Text("What is the best way to learn Rust?".to_string())
);
- assert_eq!(message.context, "");
+ assert_eq!(message.loaded_context.text, "");
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2596,12 +2531,17 @@ fn main() {{
// Add second message, also without context
let message2_id = thread.update(cx, |thread, cx| {
- thread.insert_user_message("Are there any good books?", vec![], None, cx)
+ thread.insert_user_message(
+ "Are there any good books?",
+ ContextLoadResult::default(),
+ None,
+ cx,
+ )
});
let message2 =
thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
- assert_eq!(message2.context, "");
+ assert_eq!(message2.loaded_context.text, "");
// Check that both messages appear in the request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2635,12 +2575,14 @@ fn main() {{
.await
.unwrap();
- let context =
- context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
+ let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
+ let loaded_context = cx
+ .update(|cx| load_context(vec![context], &project, &None, cx))
+ .await;
// Insert user message with the buffer as context
thread.update(cx, |thread, cx| {
- thread.insert_user_message("Explain this code", vec![context], None, cx)
+ thread.insert_user_message("Explain this code", loaded_context, None, cx)
});
// Create a request and check that it doesn't have a stale buffer warning yet
@@ -2668,7 +2610,12 @@ fn main() {{
// Insert another user message without context
thread.update(cx, |thread, cx| {
- thread.insert_user_message("What does the code do now?", vec![], None, cx)
+ thread.insert_user_message(
+ "What does the code do now?",
+ ContextLoadResult::default(),
+ None,
+ cx,
+ )
});
// Create a new request and check for the stale buffer warning
@@ -2735,6 +2682,7 @@ fn main() {{
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
+ None,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
@@ -2759,15 +2707,15 @@ fn main() {{
.unwrap();
let buffer = project
- .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
+ .update(cx, |project, cx| {
+ project.open_buffer(buffer_path.clone(), cx)
+ })
.await
.unwrap();
- context_store
- .update(cx, |store, cx| {
- store.add_file_from_buffer(buffer.clone(), cx)
- })
- .await?;
+ context_store.update(cx, |context_store, cx| {
+ context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
+ });
Ok(buffer)
}
@@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
use prompt_store::{
- ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
- RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
+ ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
+ UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
@@ -82,12 +82,11 @@ impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
+ prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Result<Entity<Self>>> {
- let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
- let prompt_store = prompt_store.await.ok();
let (thread_store, ready_rx) = cx.update(|cx| {
let mut option_ready_rx = None;
let thread_store = cx.new(|cx| {
@@ -349,25 +348,8 @@ impl ThreadStore {
self.context_server_manager.clone()
}
- pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
- self.prompt_store.clone()
- }
-
- pub fn load_rules(
- &self,
- prompt_id: UserPromptId,
- cx: &App,
- ) -> Task<Result<(PromptMetadata, String)>> {
- let prompt_id = PromptId::User { uuid: prompt_id };
- let Some(prompt_store) = self.prompt_store.as_ref() else {
- return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
- };
- let prompt_store = prompt_store.read(cx);
- let Some(metadata) = prompt_store.metadata(prompt_id) else {
- return Task::ready(Err(anyhow!("User rules not found in library.")));
- };
- let text_task = prompt_store.load(prompt_id, cx);
- cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
+ pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
+ &self.prompt_store
}
pub fn tools(&self) -> Entity<ToolWorkingSet> {
@@ -379,16 +361,12 @@ impl ThreadStore {
self.threads.len()
}
- pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
+ pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
threads
}
- pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
- self.threads().into_iter().take(limit).collect()
- }
-
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
cx.new(|cx| {
Thread::new(
@@ -1,14 +1,13 @@
-use std::sync::Arc;
use std::{rc::Rc, time::Duration};
use file_icons::FileIcons;
-use futures::FutureExt;
-use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
-use gpui::{ClickEvent, Task};
-use language_model::LanguageModelImage;
+use gpui::{Animation, AnimationExt as _, ClickEvent, Entity, MouseButton, pulsating_between};
+use project::Project;
+use prompt_store::PromptStore;
+use text::OffsetRangeExt;
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
-use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
+use crate::context::{AgentContext, ContextKind, ImageStatus};
#[derive(IntoElement)]
pub enum ContextPill {
@@ -73,9 +72,7 @@ impl ContextPill {
pub fn id(&self) -> ElementId {
match self {
- Self::Added { context, .. } => {
- ElementId::NamedInteger("context-pill".into(), context.id.0)
- }
+ Self::Added { context, .. } => context.context.element_id("context-pill".into()),
Self::Suggested { .. } => "suggested-context-pill".into(),
}
}
@@ -199,14 +196,17 @@ impl RenderOnce for ContextPill {
)
.when_some(on_remove.as_ref(), |element, on_remove| {
element.child(
- IconButton::new(("remove", context.id.0), IconName::Close)
- .shape(IconButtonShape::Square)
- .icon_size(IconSize::XSmall)
- .tooltip(Tooltip::text("Remove Context"))
- .on_click({
- let on_remove = on_remove.clone();
- move |event, window, cx| on_remove(event, window, cx)
- }),
+ IconButton::new(
+ context.context.element_id("remove".into()),
+ IconName::Close,
+ )
+ .shape(IconButtonShape::Square)
+ .icon_size(IconSize::XSmall)
+ .tooltip(Tooltip::text("Remove Context"))
+ .on_click({
+ let on_remove = on_remove.clone();
+ move |event, window, cx| on_remove(event, window, cx)
+ }),
)
})
.when_some(on_click.as_ref(), |element, on_click| {
@@ -262,9 +262,11 @@ pub enum ContextStatus {
Error { message: SharedString },
}
-#[derive(RegisterComponent)]
+// TODO: Component commented out due to new dependency on `Project`.
+//
+// #[derive(RegisterComponent)]
pub struct AddedContext {
- pub id: ContextId,
+ pub context: AgentContext,
pub kind: ContextKind,
pub name: SharedString,
pub parent: Option<SharedString>,
@@ -275,10 +277,19 @@ pub struct AddedContext {
}
impl AddedContext {
- pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
+ /// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
+ /// `None` if `DirectoryContext` or `RulesContext` no longer exist.
+ ///
+ /// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
+ pub fn new(
+ context: AgentContext,
+ prompt_store: Option<&Entity<PromptStore>>,
+ project: &Project,
+ cx: &App,
+ ) -> Option<AddedContext> {
match context {
- AssistantContext::File(file_context) => {
- let full_path = file_context.context_buffer.full_path(cx);
+ AgentContext::File(ref file_context) => {
+ let full_path = file_context.buffer.read(cx).file()?.full_path(cx);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -289,8 +300,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
- AddedContext {
- id: file_context.id,
+ Some(AddedContext {
kind: ContextKind::File,
name,
parent,
@@ -298,18 +308,16 @@ impl AddedContext {
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: None,
- }
+ context,
+ })
}
- AssistantContext::Directory(directory_context) => {
- let worktree = directory_context.worktree.read(cx);
- // If the directory no longer exists, use its last known path.
- let full_path = worktree
- .entry_for_id(directory_context.entry_id)
- .map_or_else(
- || directory_context.last_path.clone(),
- |entry| worktree.full_path(&entry.path).into(),
- );
+ AgentContext::Directory(ref directory_context) => {
+ let worktree = project
+ .worktree_for_entry(directory_context.entry_id, cx)?
+ .read(cx);
+ let entry = worktree.entry_for_id(directory_context.entry_id)?;
+ let full_path = worktree.full_path(&entry.path);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -320,8 +328,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
- AddedContext {
- id: directory_context.id,
+ Some(AddedContext {
kind: ContextKind::Directory,
name,
parent,
@@ -329,33 +336,34 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
- }
+ context,
+ })
}
- AssistantContext::Symbol(symbol_context) => AddedContext {
- id: symbol_context.id,
+ AgentContext::Symbol(ref symbol_context) => Some(AddedContext {
kind: ContextKind::Symbol,
- name: symbol_context.context_symbol.id.name.clone(),
+ name: symbol_context.symbol.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
- },
+ context,
+ }),
- AssistantContext::Selection(selection_context) => {
- let full_path = selection_context.context_buffer.full_path(cx);
+ AgentContext::Selection(ref selection_context) => {
+ let buffer = selection_context.buffer.read(cx);
+ let full_path = buffer.file()?.full_path(cx);
let mut full_path_string = full_path.to_string_lossy().into_owned();
let mut name = full_path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| full_path_string.clone());
- let line_range_text = format!(
- " ({}-{})",
- selection_context.line_range.start.row + 1,
- selection_context.line_range.end.row + 1
- );
+ let line_range = selection_context.range.to_point(&buffer.snapshot());
+
+ let line_range_text =
+ format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
full_path_string.push_str(&line_range_text);
name.push_str(&line_range_text);
@@ -365,16 +373,17 @@ impl AddedContext {
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
- AddedContext {
- id: selection_context.id,
+ Some(AddedContext {
kind: ContextKind::Selection,
name: name.into(),
parent,
tooltip: None,
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
+ render_preview: None,
+ /*
render_preview: Some(Rc::new({
- let content = selection_context.context_buffer.text.clone();
+ let content = selection_context.text.clone();
move |_, cx| {
div()
.id("context-pill-selection-preview")
@@ -385,11 +394,12 @@ impl AddedContext {
.into_any_element()
}
})),
- }
+ */
+ context,
+ })
}
- AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
- id: fetched_url_context.id,
+ AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext {
kind: ContextKind::FetchedUrl,
name: fetched_url_context.url.clone(),
parent: None,
@@ -397,12 +407,12 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
- },
+ context,
+ }),
- AssistantContext::Thread(thread_context) => AddedContext {
- id: thread_context.id,
+ AgentContext::Thread(ref thread_context) => Some(AddedContext {
kind: ContextKind::Thread,
- name: thread_context.summary(cx),
+ name: thread_context.name(cx),
parent: None,
tooltip: None,
icon_path: None,
@@ -418,36 +428,41 @@ impl AddedContext {
ContextStatus::Ready
},
render_preview: None,
- },
+ context,
+ }),
- AssistantContext::Rules(user_rules_context) => AddedContext {
- id: user_rules_context.id,
- kind: ContextKind::Rules,
- name: user_rules_context.title.clone(),
- parent: None,
- tooltip: None,
- icon_path: None,
- status: ContextStatus::Ready,
- render_preview: None,
- },
+ AgentContext::Rules(ref user_rules_context) => {
+ let name = prompt_store
+ .as_ref()?
+ .read(cx)
+ .metadata(user_rules_context.prompt_id.into())?
+ .title?;
+ Some(AddedContext {
+ kind: ContextKind::Rules,
+ name: name.clone(),
+ parent: None,
+ tooltip: None,
+ icon_path: None,
+ status: ContextStatus::Ready,
+ render_preview: None,
+ context,
+ })
+ }
- AssistantContext::Image(image_context) => AddedContext {
- id: image_context.id,
+ AgentContext::Image(ref image_context) => Some(AddedContext {
kind: ContextKind::Image,
name: "Image".into(),
parent: None,
tooltip: None,
icon_path: None,
- status: if image_context.is_loading() {
- ContextStatus::Loading {
+ status: match image_context.status() {
+ ImageStatus::Loading => ContextStatus::Loading {
message: "Loadingβ¦".into(),
- }
- } else if image_context.is_error() {
- ContextStatus::Error {
+ },
+ ImageStatus::Error => ContextStatus::Error {
message: "Failed to load image".into(),
- }
- } else {
- ContextStatus::Ready
+ },
+ ImageStatus::Ready => ContextStatus::Ready,
},
render_preview: Some(Rc::new({
let image = image_context.original_image.clone();
@@ -458,7 +473,8 @@ impl AddedContext {
.into_any_element()
}
})),
- },
+ context,
+ }),
}
}
}
@@ -478,6 +494,8 @@ impl Render for ContextPillPreview {
}
}
+// TODO: Component commented out due to new dependency on `Project`.
+/*
impl Component for AddedContext {
fn scope() -> ComponentScope {
ComponentScope::Agent
@@ -487,12 +505,13 @@ impl Component for AddedContext {
"AddedContext"
}
- fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
+ fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
+ let next_context_id = ContextId::zero();
let image_ready = (
"Ready",
AddedContext::new(
- &AssistantContext::Image(ImageContext {
- id: ContextId(0),
+ AgentContext::Image(ImageContext {
+ context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
}),
@@ -503,8 +522,8 @@ impl Component for AddedContext {
let image_loading = (
"Loading",
AddedContext::new(
- &AssistantContext::Image(ImageContext {
- id: ContextId(1),
+ AgentContext::Image(ImageContext {
+ context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: cx
.background_spawn(async move {
@@ -520,8 +539,8 @@ impl Component for AddedContext {
let image_error = (
"Error",
AddedContext::new(
- &AssistantContext::Image(ImageContext {
- id: ContextId(2),
+ AgentContext::Image(ImageContext {
+ context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(None).shared(),
}),
@@ -544,5 +563,8 @@ impl Component for AddedContext {
)
.into_any(),
)
+
+ None
}
}
+*/
@@ -25,7 +25,7 @@ use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
};
use project::Project;
-use prompt_store::{PromptBuilder, PromptId, UserPromptId};
+use prompt_store::{PromptBuilder, UserPromptId};
use rules_library::{RulesLibrary, open_rules_library};
use search::{BufferSearchBar, buffer_search::DivRegistrar};
@@ -1059,9 +1059,9 @@ impl AssistantPanel {
None,
))
}),
- action.prompt_to_select.map(|uuid| PromptId::User {
- uuid: UserPromptId(uuid),
- }),
+ action
+ .prompt_to_select
+ .map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);
@@ -10,7 +10,7 @@ use crate::{
ToolMetrics,
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
};
-use agent::ThreadEvent;
+use agent::{ContextLoadResult, ThreadEvent};
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use buffer_diff::DiffHunkStatus;
@@ -115,7 +115,12 @@ impl ExampleContext {
pub fn push_user_message(&mut self, text: impl ToString) {
self.app
.update_entity(&self.agent_thread, |thread, cx| {
- thread.insert_user_message(text.to_string(), vec![], None, cx);
+ thread.insert_user_message(
+ text.to_string(),
+ ContextLoadResult::default(),
+ None,
+ cx,
+ );
})
.unwrap();
}
@@ -218,8 +218,14 @@ impl ExampleInstance {
});
let tools = cx.new(|_| ToolWorkingSet::default());
- let thread_store =
- ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
+ let prompt_store = None;
+ let thread_store = ThreadStore::load(
+ project.clone(),
+ tools,
+ prompt_store,
+ app_state.prompt_builder.clone(),
+ cx,
+ );
let meta = self.thread.meta();
let this = self.clone();
@@ -60,9 +60,7 @@ pub enum PromptId {
impl PromptId {
pub fn new() -> PromptId {
- PromptId::User {
- uuid: UserPromptId::new(),
- }
+ UserPromptId::new().into()
}
pub fn is_built_in(&self) -> bool {
@@ -70,6 +68,12 @@ impl PromptId {
}
}
+impl From<UserPromptId> for PromptId {
+ fn from(uuid: UserPromptId) -> Self {
+ PromptId::User { uuid }
+ }
+}
+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
@@ -227,9 +231,7 @@ impl PromptStore {
.collect::<heed::Result<HashMap<_, _>>>()?;
for (prompt_id_v1, metadata_v1) in metadata_v1 {
- let prompt_id_v2 = PromptId::User {
- uuid: UserPromptId(prompt_id_v1.0),
- };
+ let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
continue;
};