diff --git a/Cargo.lock b/Cargo.lock index 0b24221bb6594478b70e50be0c03e2456b97e402..d33d31d9fdc5ab0e7819cbaf1d15c0a149d56627 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,7 +142,7 @@ dependencies = [ "agent_servers", "agent_settings", "anyhow", - "assistant_context", + "assistant_text_thread", "chrono", "client", "clock", @@ -315,9 +315,9 @@ dependencies = [ "ai_onboarding", "anyhow", "arrayvec", - "assistant_context", "assistant_slash_command", "assistant_slash_commands", + "assistant_text_thread", "audio", "buffer_diff", "chrono", @@ -803,107 +803,107 @@ dependencies = [ ] [[package]] -name = "assistant_context" +name = "assistant_slash_command" version = "0.1.0" dependencies = [ - "agent_settings", "anyhow", - "assistant_slash_command", - "assistant_slash_commands", - "chrono", - "client", - "clock", - "cloud_llm_client", + "async-trait", "collections", - "context_server", - "fs", + "derive_more 0.99.20", + "extension", "futures 0.3.31", - "fuzzy", "gpui", - "indoc", "language", "language_model", - "log", - "open_ai", "parking_lot", - "paths", "pretty_assertions", - "project", - "prompt_store", - "proto", - "rand 0.9.2", - "regex", - "rpc", "serde", "serde_json", - "settings", - "smallvec", - "smol", - "telemetry_events", - "text", "ui", - "unindent", "util", - "uuid", "workspace", - "zed_env_vars", ] [[package]] -name = "assistant_slash_command" +name = "assistant_slash_commands" version = "0.1.0" dependencies = [ "anyhow", - "async-trait", + "assistant_slash_command", + "chrono", "collections", - "derive_more 0.99.20", - "extension", + "context_server", + "editor", + "feature_flags", + "fs", "futures 0.3.31", + "fuzzy", + "globset", "gpui", + "html_to_markdown", + "http_client", "language", - "language_model", - "parking_lot", "pretty_assertions", + "project", + "prompt_store", + "rope", "serde", "serde_json", + "settings", + "smol", + "text", "ui", "util", "workspace", + "worktree", + "zlog", ] [[package]] -name = "assistant_slash_commands" +name = "assistant_text_thread" version = "0.1.0" dependencies = [ + "agent_settings", "anyhow", "assistant_slash_command", + "assistant_slash_commands", "chrono", + "client", + "clock", + "cloud_llm_client", "collections", "context_server", - "editor", - "feature_flags", "fs", "futures 0.3.31", "fuzzy", - "globset", "gpui", - "html_to_markdown", - "http_client", + "indoc", "language", + "language_model", + "log", + "open_ai", + "parking_lot", + "paths", "pretty_assertions", "project", "prompt_store", - "rope", + "proto", + "rand 0.9.2", + "regex", + "rpc", "serde", "serde_json", "settings", + "smallvec", "smol", + "telemetry_events", "text", "ui", + "unindent", "util", + "uuid", "workspace", - "worktree", - "zlog", + "zed_env_vars", ] [[package]] @@ -3324,8 +3324,8 @@ version = "0.44.0" dependencies = [ "agent_settings", "anyhow", - "assistant_context", "assistant_slash_command", + "assistant_text_thread", "async-trait", "async-tungstenite", "audio", diff --git a/Cargo.toml b/Cargo.toml index c0c0ffc1508aaa51465db7a30cccfcfa04fd8467..e0682924bc377d40a0711630c77ad4dd000515b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ "crates/anthropic", "crates/askpass", "crates/assets", - "crates/assistant_context", + "crates/assistant_text_thread", "crates/assistant_slash_command", "crates/assistant_slash_commands", "crates/audio", @@ -246,7 +246,7 @@ ai_onboarding = { path = "crates/ai_onboarding" } anthropic = { path = "crates/anthropic" } askpass = { path = "crates/askpass" } assets = { path = "crates/assets" } -assistant_context = { path = "crates/assistant_context" } +assistant_text_thread = { path = "crates/assistant_text_thread" } assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_slash_commands = { path = "crates/assistant_slash_commands" } audio = { path = "crates/audio" } diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 9e5b6ad66096b784bfb496b71ef1ee5cb30005cb..e0f2d9dcb97e298dd3c906e3f902974821efcdc0 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -24,7 +24,7 @@ agent-client-protocol.workspace = true agent_servers.workspace = true agent_settings.workspace = true anyhow.workspace = true -assistant_context.workspace = true +assistant_text_thread.workspace = true chrono.workspace = true client.workspace = true cloud_llm_client.workspace = true @@ -76,7 +76,7 @@ zstd.workspace = true [dev-dependencies] agent_servers = { workspace = true, "features" = ["test-support"] } -assistant_context = { workspace = true, "features" = ["test-support"] } +assistant_text_thread = { workspace = true, "features" = ["test-support"] } client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } context_server = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 65eb25e6ac9d005fc2e18901a56287e2938e5bb8..63ee0adf191cbe309229c57b950d11ca7a3680e3 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -1266,8 +1266,9 @@ mod internal_tests { ) .await; let project = Project::test(fs.clone(), [], cx).await; - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = + cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let agent = NativeAgent::new( project.clone(), history_store, @@ -1327,8 +1328,9 @@ mod internal_tests { let fs = FakeFs::new(cx.executor()); fs.insert_tree("/", json!({ "a": {} })).await; let project = Project::test(fs.clone(), [], cx).await; - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = + cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let connection = NativeAgentConnection( NativeAgent::new( project.clone(), @@ -1402,8 +1404,9 @@ mod internal_tests { .await; let project = Project::test(fs.clone(), [], cx).await; - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = + cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); // Create the agent and connection let agent = NativeAgent::new( @@ -1474,8 +1477,9 @@ mod internal_tests { ) .await; let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = + cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let agent = NativeAgent::new( project.clone(), history_store.clone(), diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index c342110f3ee289b6e84241517b69fe9a86efcf16..3bfbd99677feed5db53d96d2fa96316ac49abce4 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -2,12 +2,12 @@ use crate::{DbThread, DbThreadMetadata, ThreadsDatabase}; use acp_thread::MentionUri; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; -use assistant_context::{AssistantContext, SavedContextMetadata}; +use assistant_text_thread::{SavedTextThreadMetadata, TextThread}; use chrono::{DateTime, Utc}; use db::kvp::KEY_VALUE_STORE; use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*}; use itertools::Itertools; -use paths::contexts_dir; +use paths::text_threads_dir; use project::Project; use serde::{Deserialize, Serialize}; use std::{collections::VecDeque, path::Path, rc::Rc, sync::Arc, time::Duration}; @@ -50,21 +50,23 @@ pub fn load_agent_thread( #[derive(Clone, Debug)] pub enum HistoryEntry { AcpThread(DbThreadMetadata), - TextThread(SavedContextMetadata), + TextThread(SavedTextThreadMetadata), } impl HistoryEntry { pub fn updated_at(&self) -> DateTime { match self { HistoryEntry::AcpThread(thread) => thread.updated_at, - HistoryEntry::TextThread(context) => context.mtime.to_utc(), + HistoryEntry::TextThread(text_thread) => text_thread.mtime.to_utc(), } } pub fn id(&self) -> HistoryEntryId { match self { HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()), - HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()), + HistoryEntry::TextThread(text_thread) => { + HistoryEntryId::TextThread(text_thread.path.clone()) + } } } @@ -74,9 +76,9 @@ impl HistoryEntry { id: thread.id.clone(), name: thread.title.to_string(), }, - HistoryEntry::TextThread(context) => MentionUri::TextThread { - path: context.path.as_ref().to_owned(), - name: context.title.to_string(), + HistoryEntry::TextThread(text_thread) => MentionUri::TextThread { + path: text_thread.path.as_ref().to_owned(), + name: text_thread.title.to_string(), }, } } @@ -90,7 +92,7 @@ impl HistoryEntry { &thread.title } } - HistoryEntry::TextThread(context) => &context.title, + HistoryEntry::TextThread(text_thread) => &text_thread.title, } } } @@ -120,7 +122,7 @@ enum SerializedRecentOpen { pub struct HistoryStore { threads: Vec, entries: Vec, - text_thread_store: Entity, + text_thread_store: Entity, recently_opened_entries: VecDeque, _subscriptions: Vec, _save_recently_opened_entries_task: Task<()>, @@ -128,7 +130,7 @@ pub struct HistoryStore { impl HistoryStore { pub fn new( - text_thread_store: Entity, + text_thread_store: Entity, cx: &mut Context, ) -> Self { let subscriptions = @@ -192,16 +194,16 @@ impl HistoryStore { cx: &mut Context, ) -> Task> { self.text_thread_store - .update(cx, |store, cx| store.delete_local_context(path, cx)) + .update(cx, |store, cx| store.delete_local(path, cx)) } pub fn load_text_thread( &self, path: Arc, cx: &mut Context, - ) -> Task>> { + ) -> Task>> { self.text_thread_store - .update(cx, |store, cx| store.open_local_context(path, cx)) + .update(cx, |store, cx| store.open_local(path, cx)) } pub fn reload(&self, cx: &mut Context) { @@ -243,7 +245,7 @@ impl HistoryStore { history_entries.extend( self.text_thread_store .read(cx) - .unordered_contexts() + .unordered_text_threads() .cloned() .map(HistoryEntry::TextThread), ); @@ -278,14 +280,14 @@ impl HistoryStore { let context_entries = self .text_thread_store .read(cx) - .unordered_contexts() - .flat_map(|context| { + .unordered_text_threads() + .flat_map(|text_thread| { self.recently_opened_entries .iter() .enumerate() .flat_map(|(index, entry)| match entry { - HistoryEntryId::TextThread(path) if &context.path == path => { - Some((index, HistoryEntry::TextThread(context.clone()))) + HistoryEntryId::TextThread(path) if &text_thread.path == path => { + Some((index, HistoryEntry::TextThread(text_thread.clone()))) } _ => None, }) @@ -347,7 +349,7 @@ impl HistoryStore { acp::SessionId(id.as_str().into()), )), SerializedRecentOpen::TextThread(file_name) => Some( - HistoryEntryId::TextThread(contexts_dir().join(file_name).into()), + HistoryEntryId::TextThread(text_threads_dir().join(file_name).into()), ), }) .collect(); diff --git a/crates/agent/src/native_agent_server.rs b/crates/agent/src/native_agent_server.rs index 0dde0ff98552d4292a4391d2aec4f36419228a25..b28009223b7a7f2232b440282a0d6f61907f442c 100644 --- a/crates/agent/src/native_agent_server.rs +++ b/crates/agent/src/native_agent_server.rs @@ -81,7 +81,7 @@ impl AgentServer for NativeAgentServer { mod tests { use super::*; - use assistant_context::ContextStore; + use assistant_text_thread::TextThreadStore; use gpui::AppContext; agent_servers::e2e_tests::common_e2e_tests!( @@ -116,8 +116,9 @@ mod tests { }); let history = cx.update(|cx| { - let context_store = cx.new(move |cx| ContextStore::fake(project.clone(), cx)); - cx.new(move |cx| HistoryStore::new(context_store, cx)) + let text_thread_store = + cx.new(move |cx| TextThreadStore::fake(project.clone(), cx)); + cx.new(move |cx| HistoryStore::new(text_thread_store, cx)) }); NativeAgentServer::new(fs.clone(), history) diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 66b006893e50b9c59701eff850adb7747f96e3b5..ddddbfc5279ca23fb95527892e929b23b8cefbf6 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -1834,8 +1834,9 @@ async fn test_agent_connection(cx: &mut TestAppContext) { fake_fs.insert_tree(path!("/test"), json!({})).await; let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await; let cwd = Path::new("/test"); - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = + cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); // Create agent and connection let agent = NativeAgent::new( diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index f763d6f91e45d1e8b5a035c22fbb7ab65de93dd9..724b53a017911edbd6e9dd88c410daf794889d4e 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -25,7 +25,7 @@ agent_settings.workspace = true ai_onboarding.workspace = true anyhow.workspace = true arrayvec.workspace = true -assistant_context.workspace = true +assistant_text_thread.workspace = true assistant_slash_command.workspace = true assistant_slash_commands.workspace = true audio.workspace = true @@ -102,7 +102,7 @@ zed_actions.workspace = true [dev-dependencies] acp_thread = { workspace = true, features = ["test-support"] } agent = { workspace = true, features = ["test-support"] } -assistant_context = { workspace = true, features = ["test-support"] } +assistant_text_thread = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] } db = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index 8123c4a422b9d95a2da45e75ceb4079675d845fd..4c058b984f4fa24074ea9e9d81e43c1d73d87d1f 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -402,7 +402,7 @@ mod tests { use agent::HistoryStore; use agent_client_protocol as acp; use agent_settings::AgentSettings; - use assistant_context::ContextStore; + use assistant_text_thread::TextThreadStore; use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; use editor::{EditorSettings, RowInfo}; use fs::FakeFs; @@ -466,8 +466,8 @@ mod tests { connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx) }); - let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let view_state = cx.new(|_cx| { EntryViewState::new( diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index 53a26d9fabdd59e93efbc615ce5be5d1c2d492fb..c24cefcf2d5fc04baffeb9f3d1a1ecaf9dd05268 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -629,12 +629,12 @@ impl MessageEditor { path: PathBuf, cx: &mut Context, ) -> Task> { - let context = self.history_store.update(cx, |store, cx| { + let text_thread_task = self.history_store.update(cx, |store, cx| { store.load_text_thread(path.as_path().into(), cx) }); cx.spawn(async move |_, cx| { - let context = context.await?; - let xml = context.update(cx, |context, cx| context.to_xml(cx))?; + let text_thread = text_thread_task.await?; + let xml = text_thread.update(cx, |text_thread, cx| text_thread.to_xml(cx))?; Ok(Mention::Text { content: xml, tracked_buffers: Vec::new(), @@ -1591,7 +1591,7 @@ mod tests { use acp_thread::MentionUri; use agent::{HistoryStore, outline}; use agent_client_protocol as acp; - use assistant_context::ContextStore; + use assistant_text_thread::TextThreadStore; use editor::{AnchorRangeExt as _, Editor, EditorMode}; use fs::FakeFs; use futures::StreamExt as _; @@ -1622,8 +1622,8 @@ mod tests { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let message_editor = cx.update(|window, cx| { cx.new(|cx| { @@ -1727,8 +1727,8 @@ mod tests { .await; let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; - let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default())); // Start with no available commands - simulating Claude which doesn't support slash commands let available_commands = Rc::new(RefCell::new(vec![])); @@ -1891,8 +1891,8 @@ mod tests { let mut cx = VisualTestContext::from_window(*window, cx); - let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default())); let available_commands = Rc::new(RefCell::new(vec![ acp::AvailableCommand { @@ -2131,8 +2131,8 @@ mod tests { opened_editors.push(buffer); } - let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default())); let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| { @@ -2658,8 +2658,8 @@ mod tests { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); let message_editor = cx.update(|window, cx| { cx.new(|cx| { diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index aacae785a1f6ba727089c053588e6f0bc2ae24a2..d96c3b3219717b3ffa7310d207a323bc5fb222b0 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -324,8 +324,8 @@ impl AcpThreadHistory { HistoryEntry::AcpThread(thread) => self .history_store .update(cx, |this, cx| this.delete_thread(thread.id.clone(), cx)), - HistoryEntry::TextThread(context) => self.history_store.update(cx, |this, cx| { - this.delete_text_thread(context.path.clone(), cx) + HistoryEntry::TextThread(text_thread) => self.history_store.update(cx, |this, cx| { + this.delete_text_thread(text_thread.path.clone(), cx) }), }; task.detach_and_log_err(cx); @@ -635,12 +635,12 @@ impl RenderOnce for AcpHistoryEntryElement { }); } } - HistoryEntry::TextThread(context) => { + HistoryEntry::TextThread(text_thread) => { if let Some(panel) = workspace.read(cx).panel::(cx) { panel.update(cx, |panel, cx| { panel .open_saved_text_thread( - context.path.clone(), + text_thread.path.clone(), window, cx, ) diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index adf279c82036e8f8219c5647f016ec4fc887a046..8e5396590fe0170b536075bff210c859435a4b3c 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -5414,9 +5414,11 @@ impl AcpThreadView { HistoryEntry::AcpThread(thread) => self.history_store.update(cx, |history, cx| { history.delete_thread(thread.id.clone(), cx) }), - HistoryEntry::TextThread(context) => self.history_store.update(cx, |history, cx| { - history.delete_text_thread(context.path.clone(), cx) - }), + HistoryEntry::TextThread(text_thread) => { + self.history_store.update(cx, |history, cx| { + history.delete_text_thread(text_thread.path.clone(), cx) + }) + } }; task.detach_and_log_err(cx); } @@ -5735,7 +5737,7 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { pub(crate) mod tests { use acp_thread::StubAgentConnection; use agent_client_protocol::SessionId; - use assistant_context::ContextStore; + use assistant_text_thread::TextThreadStore; use editor::EditorSettings; use fs::FakeFs; use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext}; @@ -5898,10 +5900,10 @@ pub(crate) mod tests { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let context_store = - cx.update(|_window, cx| cx.new(|cx| ContextStore::fake(project.clone(), cx))); + let text_thread_store = + cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); let history_store = - cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(context_store, cx))); + cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(text_thread_store, cx))); let thread_view = cx.update(|window, cx| { cx.new(|cx| { @@ -6170,10 +6172,10 @@ pub(crate) mod tests { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let context_store = - cx.update(|_window, cx| cx.new(|cx| ContextStore::fake(project.clone(), cx))); + let text_thread_store = + cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); let history_store = - cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(context_store, cx))); + cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(text_thread_store, cx))); let connection = Rc::new(StubAgentConnection::new()); let thread_view = cx.update(|window, cx| { diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 19f56b26b5b9621b92c307690baefd332da183b0..deb202832469eaa16b3eab3bced0236dc5467c53 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -36,8 +36,8 @@ use crate::{ use agent_settings::AgentSettings; use ai_onboarding::AgentPanelOnboarding; use anyhow::{Result, anyhow}; -use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; +use assistant_text_thread::{TextThread, TextThreadEvent, TextThreadSummary}; use client::{UserStore, zed_urls}; use cloud_llm_client::{Plan, PlanV1, PlanV2, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; @@ -199,7 +199,7 @@ enum ActiveView { thread_view: Entity, }, TextThread { - context_editor: Entity, + text_thread_editor: Entity, title_editor: Entity, buffer_search_bar: Entity, _subscriptions: Vec, @@ -301,13 +301,13 @@ impl ActiveView { } pub fn text_thread( - context_editor: Entity, + text_thread_editor: Entity, acp_history_store: Entity, language_registry: Arc, window: &mut Window, cx: &mut App, ) -> Self { - let title = context_editor.read(cx).title(cx).to_string(); + let title = text_thread_editor.read(cx).title(cx).to_string(); let editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); @@ -323,7 +323,7 @@ impl ActiveView { let subscriptions = vec![ window.subscribe(&editor, cx, { { - let context_editor = context_editor.clone(); + let text_thread_editor = text_thread_editor.clone(); move |editor, event, window, cx| match event { EditorEvent::BufferEdited => { if suppress_first_edit { @@ -332,19 +332,19 @@ impl ActiveView { } let new_summary = editor.read(cx).text(cx); - context_editor.update(cx, |context_editor, cx| { - context_editor - .context() - .update(cx, |assistant_context, cx| { - assistant_context.set_custom_summary(new_summary, cx); + text_thread_editor.update(cx, |text_thread_editor, cx| { + text_thread_editor + .text_thread() + .update(cx, |text_thread, cx| { + text_thread.set_custom_summary(new_summary, cx); }) }) } EditorEvent::Blurred => { if editor.read(cx).text(cx).is_empty() { - let summary = context_editor + let summary = text_thread_editor .read(cx) - .context() + .text_thread() .read(cx) .summary() .or_default(); @@ -358,17 +358,17 @@ impl ActiveView { } } }), - window.subscribe(&context_editor.read(cx).context().clone(), cx, { + window.subscribe(&text_thread_editor.read(cx).text_thread().clone(), cx, { let editor = editor.clone(); - move |assistant_context, event, window, cx| match event { - ContextEvent::SummaryGenerated => { - let summary = assistant_context.read(cx).summary().or_default(); + move |text_thread, event, window, cx| match event { + TextThreadEvent::SummaryGenerated => { + let summary = text_thread.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); }) } - ContextEvent::PathChanged { old_path, new_path } => { + TextThreadEvent::PathChanged { old_path, new_path } => { acp_history_store.update(cx, |history_store, cx| { if let Some(old_path) = old_path { history_store @@ -389,11 +389,11 @@ impl ActiveView { let buffer_search_bar = cx.new(|cx| BufferSearchBar::new(Some(language_registry), window, cx)); buffer_search_bar.update(cx, |buffer_search_bar, cx| { - buffer_search_bar.set_active_pane_item(Some(&context_editor), window, cx) + buffer_search_bar.set_active_pane_item(Some(&text_thread_editor), window, cx) }); Self::TextThread { - context_editor, + text_thread_editor, title_editor: editor, buffer_search_bar, _subscriptions: subscriptions, @@ -410,7 +410,7 @@ pub struct AgentPanel { language_registry: Arc, acp_history: Entity, history_store: Entity, - text_thread_store: Entity, + text_thread_store: Entity, prompt_store: Option>, context_server_registry: Entity, inline_assist_context_store: Entity, @@ -474,7 +474,7 @@ impl AgentPanel { let text_thread_store = workspace .update(cx, |workspace, cx| { let project = workspace.project().clone(); - assistant_context::ContextStore::new( + assistant_text_thread::TextThreadStore::new( project, prompt_builder, slash_commands, @@ -512,7 +512,7 @@ impl AgentPanel { fn new( workspace: &Workspace, - text_thread_store: Entity, + text_thread_store: Entity, prompt_store: Option>, window: &mut Window, cx: &mut Context, @@ -565,8 +565,8 @@ impl AgentPanel { DefaultView::TextThread => { let context = text_thread_store.update(cx, |store, cx| store.create(cx)); let lsp_adapter_delegate = make_lsp_adapter_delegate(&project.clone(), cx).unwrap(); - let context_editor = cx.new(|cx| { - let mut editor = TextThreadEditor::for_context( + let text_thread_editor = cx.new(|cx| { + let mut editor = TextThreadEditor::for_text_thread( context, fs.clone(), workspace.clone(), @@ -579,7 +579,7 @@ impl AgentPanel { editor }); ActiveView::text_thread( - context_editor, + text_thread_editor, history_store.clone(), language_registry.clone(), window, @@ -736,8 +736,8 @@ impl AgentPanel { .log_err() .flatten(); - let context_editor = cx.new(|cx| { - let mut editor = TextThreadEditor::for_context( + let text_thread_editor = cx.new(|cx| { + let mut editor = TextThreadEditor::for_text_thread( context, self.fs.clone(), self.workspace.clone(), @@ -757,7 +757,7 @@ impl AgentPanel { self.set_active_view( ActiveView::text_thread( - context_editor.clone(), + text_thread_editor.clone(), self.history_store.clone(), self.language_registry.clone(), window, @@ -766,7 +766,7 @@ impl AgentPanel { window, cx, ); - context_editor.focus_handle(cx).focus(window); + text_thread_editor.focus_handle(cx).focus(window); } fn external_thread( @@ -905,20 +905,20 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) -> Task> { - let context = self + let text_thread_task = self .history_store .update(cx, |store, cx| store.load_text_thread(path, cx)); cx.spawn_in(window, async move |this, cx| { - let context = context.await?; + let text_thread = text_thread_task.await?; this.update_in(cx, |this, window, cx| { - this.open_text_thread(context, window, cx); + this.open_text_thread(text_thread, window, cx); }) }) } pub(crate) fn open_text_thread( &mut self, - context: Entity, + text_thread: Entity, window: &mut Window, cx: &mut Context, ) { @@ -926,8 +926,8 @@ impl AgentPanel { .log_err() .flatten(); let editor = cx.new(|cx| { - TextThreadEditor::for_context( - context, + TextThreadEditor::for_text_thread( + text_thread, self.fs.clone(), self.workspace.clone(), self.project.clone(), @@ -965,8 +965,10 @@ impl AgentPanel { ActiveView::ExternalAgentThread { thread_view } => { thread_view.focus_handle(cx).focus(window); } - ActiveView::TextThread { context_editor, .. } => { - context_editor.focus_handle(cx).focus(window); + ActiveView::TextThread { + text_thread_editor, .. + } => { + text_thread_editor.focus_handle(cx).focus(window); } ActiveView::History | ActiveView::Configuration => {} } @@ -1183,9 +1185,11 @@ impl AgentPanel { } } - pub(crate) fn active_context_editor(&self) -> Option> { + pub(crate) fn active_text_thread_editor(&self) -> Option> { match &self.active_view { - ActiveView::TextThread { context_editor, .. } => Some(context_editor.clone()), + ActiveView::TextThread { + text_thread_editor, .. + } => Some(text_thread_editor.clone()), _ => None, } } @@ -1206,16 +1210,16 @@ impl AgentPanel { let new_is_special = new_is_history || new_is_config; match &new_view { - ActiveView::TextThread { context_editor, .. } => { - self.history_store.update(cx, |store, cx| { - if let Some(path) = context_editor.read(cx).context().read(cx).path() { - store.push_recently_opened_entry( - agent::HistoryEntryId::TextThread(path.clone()), - cx, - ) - } - }) - } + ActiveView::TextThread { + text_thread_editor, .. + } => self.history_store.update(cx, |store, cx| { + if let Some(path) = text_thread_editor.read(cx).text_thread().read(cx).path() { + store.push_recently_opened_entry( + agent::HistoryEntryId::TextThread(path.clone()), + cx, + ) + } + }), ActiveView::ExternalAgentThread { .. } => {} ActiveView::History | ActiveView::Configuration => {} } @@ -1372,7 +1376,9 @@ impl Focusable for AgentPanel { match &self.active_view { ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx), ActiveView::History => self.acp_history.focus_handle(cx), - ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx), + ActiveView::TextThread { + text_thread_editor, .. + } => text_thread_editor.focus_handle(cx), ActiveView::Configuration => { if let Some(configuration) = self.configuration.as_ref() { configuration.focus_handle(cx) @@ -1507,17 +1513,17 @@ impl AgentPanel { } ActiveView::TextThread { title_editor, - context_editor, + text_thread_editor, .. } => { - let summary = context_editor.read(cx).context().read(cx).summary(); + let summary = text_thread_editor.read(cx).text_thread().read(cx).summary(); match summary { - ContextSummary::Pending => Label::new(ContextSummary::DEFAULT) + TextThreadSummary::Pending => Label::new(TextThreadSummary::DEFAULT) .color(Color::Muted) .truncate() .into_any_element(), - ContextSummary::Content(summary) => { + TextThreadSummary::Content(summary) => { if summary.done { div() .w_full() @@ -1530,17 +1536,17 @@ impl AgentPanel { .into_any_element() } } - ContextSummary::Error => h_flex() + TextThreadSummary::Error => h_flex() .w_full() .child(title_editor.clone()) .child( IconButton::new("retry-summary-generation", IconName::RotateCcw) .icon_size(IconSize::Small) .on_click({ - let context_editor = context_editor.clone(); + let text_thread_editor = text_thread_editor.clone(); move |_, _window, cx| { - context_editor.update(cx, |context_editor, cx| { - context_editor.regenerate_summary(cx); + text_thread_editor.update(cx, |text_thread_editor, cx| { + text_thread_editor.regenerate_summary(cx); }); } }) @@ -2243,7 +2249,7 @@ impl AgentPanel { fn render_text_thread( &self, - context_editor: &Entity, + text_thread_editor: &Entity, buffer_search_bar: &Entity, window: &mut Window, cx: &mut Context, @@ -2277,7 +2283,7 @@ impl AgentPanel { ) }) }) - .child(context_editor.clone()) + .child(text_thread_editor.clone()) .child(self.render_drag_target(cx)) } @@ -2353,10 +2359,12 @@ impl AgentPanel { thread_view.insert_dragged_files(paths, added_worktrees, window, cx); }); } - ActiveView::TextThread { context_editor, .. } => { - context_editor.update(cx, |context_editor, cx| { + ActiveView::TextThread { + text_thread_editor, .. + } => { + text_thread_editor.update(cx, |text_thread_editor, cx| { TextThreadEditor::insert_dragged_files( - context_editor, + text_thread_editor, paths, added_worktrees, window, @@ -2427,7 +2435,7 @@ impl Render for AgentPanel { .child(self.render_drag_target(cx)), ActiveView::History => parent.child(self.acp_history.clone()), ActiveView::TextThread { - context_editor, + text_thread_editor, buffer_search_bar, .. } => { @@ -2450,7 +2458,7 @@ impl Render for AgentPanel { } }) .child(self.render_text_thread( - context_editor, + text_thread_editor, buffer_search_bar, window, cx, @@ -2528,17 +2536,17 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist { pub struct ConcreteAssistantPanelDelegate; impl AgentPanelDelegate for ConcreteAssistantPanelDelegate { - fn active_context_editor( + fn active_text_thread_editor( &self, workspace: &mut Workspace, _window: &mut Window, cx: &mut Context, ) -> Option> { let panel = workspace.panel::(cx)?; - panel.read(cx).active_context_editor() + panel.read(cx).active_text_thread_editor() } - fn open_saved_context( + fn open_local_text_thread( &self, workspace: &mut Workspace, path: Arc, @@ -2554,10 +2562,10 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate { }) } - fn open_remote_context( + fn open_remote_text_thread( &self, _workspace: &mut Workspace, - _context_id: assistant_context::ContextId, + _text_thread_id: assistant_text_thread::TextThreadId, _window: &mut Window, _cx: &mut Context, ) -> Task>> { @@ -2588,15 +2596,15 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate { thread_view.update(cx, |thread_view, cx| { thread_view.insert_selections(window, cx); }); - } else if let Some(context_editor) = panel.active_context_editor() { + } else if let Some(text_thread_editor) = panel.active_text_thread_editor() { let snapshot = buffer.read(cx).snapshot(cx); let selection_ranges = selection_ranges .into_iter() .map(|range| range.to_point(&snapshot)) .collect::>(); - context_editor.update(cx, |context_editor, cx| { - context_editor.quote_ranges(selection_ranges, snapshot, window, cx) + text_thread_editor.update(cx, |text_thread_editor, cx| { + text_thread_editor.quote_ranges(selection_ranges, snapshot, window, cx) }); } }); diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index cc0d212a86f5db3b0d5cf8ad4b0457689512f33c..7869aa4e0191f393a05ff1b2c0307bccaef41dc8 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -250,7 +250,7 @@ pub fn init( ) { AgentSettings::register(cx); - assistant_context::init(client.clone(), cx); + assistant_text_thread::init(client.clone(), cx); rules_library::init(cx); if !is_eval { // Initializing the language model from the user settings messes with the eval, so we only initialize them when diff --git a/crates/agent_ui/src/context.rs b/crates/agent_ui/src/context.rs index 3d0600605153fd8343205f3889953c100bde7a7a..2a1ff4a1d9d3e0bb6c8b128cf7f944e9ed3ff657 100644 --- a/crates/agent_ui/src/context.rs +++ b/crates/agent_ui/src/context.rs @@ -1,5 +1,5 @@ use agent::outline; -use assistant_context::AssistantContext; +use assistant_text_thread::TextThread; use futures::future; use futures::{FutureExt, future::Shared}; use gpui::{App, AppContext as _, ElementId, Entity, SharedString, Task}; @@ -581,7 +581,7 @@ impl Display for ThreadContext { #[derive(Debug, Clone)] pub struct TextThreadContextHandle { - pub context: Entity, + pub text_thread: Entity, pub context_id: ContextId, } @@ -595,20 +595,20 @@ pub struct TextThreadContext { impl TextThreadContextHandle { // pub fn lookup_key() -> pub fn eq_for_key(&self, other: &Self) -> bool { - self.context == other.context + self.text_thread == other.text_thread } pub fn hash_for_key(&self, state: &mut H) { - self.context.hash(state) + self.text_thread.hash(state) } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary().or_default() + self.text_thread.read(cx).summary().or_default() } fn load(self, cx: &App) -> Task> { let title = self.title(cx); - let text = self.context.read(cx).to_xml(cx); + let text = self.text_thread.read(cx).to_xml(cx); let context = AgentContext::TextThread(TextThreadContext { title, text: text.into(), diff --git a/crates/agent_ui/src/context_store.rs b/crates/agent_ui/src/context_store.rs index e2ee1cd0c94fd6132719ffcc0bd352865b5f9cf9..18aa59c8f716d59e4a0d717904b09472494c4dbc 100644 --- a/crates/agent_ui/src/context_store.rs +++ b/crates/agent_ui/src/context_store.rs @@ -5,7 +5,7 @@ use crate::context::{ }; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; -use assistant_context::AssistantContext; +use assistant_text_thread::TextThread; use collections::{HashSet, IndexSet}; use futures::{self, FutureExt}; use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity}; @@ -200,13 +200,13 @@ impl ContextStore { pub fn add_text_thread( &mut self, - context: Entity, + text_thread: Entity, remove_if_exists: bool, cx: &mut Context, ) -> Option { let context_id = self.next_context_id.post_inc(); let context = AgentContextHandle::TextThread(TextThreadContextHandle { - context, + text_thread, context_id, }); @@ -353,21 +353,15 @@ impl ContextStore { ); }; } - // SuggestedContext::Thread { thread, name: _ } => { - // if let Some(thread) = thread.upgrade() { - // let context_id = self.next_context_id.post_inc(); - // self.insert_context( - // AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }), - // cx, - // ); - // } - // } - SuggestedContext::TextThread { context, name: _ } => { - if let Some(context) = context.upgrade() { + SuggestedContext::TextThread { + text_thread, + name: _, + } => { + if let Some(text_thread) = text_thread.upgrade() { let context_id = self.next_context_id.post_inc(); self.insert_context( AgentContextHandle::TextThread(TextThreadContextHandle { - context, + text_thread, context_id, }), cx, @@ -392,7 +386,7 @@ impl ContextStore { // } AgentContextHandle::TextThread(text_thread_context) => { self.context_text_thread_paths - .extend(text_thread_context.context.read(cx).path().cloned()); + .extend(text_thread_context.text_thread.read(cx).path().cloned()); } _ => {} } @@ -414,7 +408,7 @@ impl ContextStore { .remove(thread_context.thread.read(cx).id()); } AgentContextHandle::TextThread(text_thread_context) => { - if let Some(path) = text_thread_context.context.read(cx).path() { + if let Some(path) = text_thread_context.text_thread.read(cx).path() { self.context_text_thread_paths.remove(path); } } @@ -538,13 +532,9 @@ pub enum SuggestedContext { icon_path: Option, buffer: WeakEntity, }, - // Thread { - // name: SharedString, - // thread: WeakEntity, - // }, TextThread { name: SharedString, - context: WeakEntity, + text_thread: WeakEntity, }, } @@ -552,7 +542,6 @@ impl SuggestedContext { pub fn name(&self) -> &SharedString { match self { Self::File { name, .. } => name, - // Self::Thread { name, .. } => name, Self::TextThread { name, .. } => name, } } @@ -560,7 +549,6 @@ impl SuggestedContext { pub fn icon_path(&self) -> Option { match self { Self::File { icon_path, .. } => icon_path.clone(), - // Self::Thread { .. } => None, Self::TextThread { .. } => None, } } @@ -568,7 +556,6 @@ impl SuggestedContext { pub fn kind(&self) -> ContextKind { match self { Self::File { .. } => ContextKind::File, - // Self::Thread { .. } => ContextKind::Thread, Self::TextThread { .. } => ContextKind::TextThread, } } diff --git a/crates/agent_ui/src/context_strip.rs b/crates/agent_ui/src/context_strip.rs index 3eaf59aba39cbaef12e7a4079209956e0e8bed17..d2393ac4f612cebc6cf97d10a38894e7022e53b9 100644 --- a/crates/agent_ui/src/context_strip.rs +++ b/crates/agent_ui/src/context_strip.rs @@ -132,19 +132,19 @@ impl ContextStrip { let workspace = self.workspace.upgrade()?; let panel = workspace.read(cx).panel::(cx)?.read(cx); - if let Some(active_context_editor) = panel.active_context_editor() { - let context = active_context_editor.read(cx).context(); - let weak_context = context.downgrade(); - let context = context.read(cx); - let path = context.path()?; + if let Some(active_text_thread_editor) = panel.active_text_thread_editor() { + let text_thread = active_text_thread_editor.read(cx).text_thread(); + let weak_text_thread = text_thread.downgrade(); + let text_thread = text_thread.read(cx); + let path = text_thread.path()?; if self.context_store.read(cx).includes_text_thread(path) { return None; } Some(SuggestedContext::TextThread { - name: context.summary().or_default(), - context: weak_context, + name: text_thread.summary().or_default(), + text_thread: weak_text_thread, }) } else { None @@ -332,7 +332,7 @@ impl ContextStrip { AgentContextHandle::TextThread(text_thread_context) => { workspace.update(cx, |workspace, cx| { if let Some(panel) = workspace.panel::(cx) { - let context = text_thread_context.context.clone(); + let context = text_thread_context.text_thread.clone(); window.defer(cx, move |window, cx| { panel.update(cx, |panel, cx| { panel.open_text_thread(context, window, cx) diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 0f7f1f1d78056553f758796c7e6b2f14781fce0f..b05dba59e6b19fa5091903882748de853cd9cb93 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -1508,8 +1508,8 @@ impl InlineAssistant { return Some(InlineAssistTarget::Terminal(terminal_view)); } - let context_editor = agent_panel - .and_then(|panel| panel.read(cx).active_context_editor()) + let text_thread_editor = agent_panel + .and_then(|panel| panel.read(cx).active_text_thread_editor()) .and_then(|editor| { let editor = &editor.read(cx).editor().clone(); if editor.read(cx).is_focused(window) { @@ -1519,8 +1519,8 @@ impl InlineAssistant { } }); - if let Some(context_editor) = context_editor { - Some(InlineAssistTarget::Editor(context_editor)) + if let Some(text_thread_editor) = text_thread_editor { + Some(InlineAssistTarget::Editor(text_thread_editor)) } else if let Some(workspace_editor) = workspace .active_item(cx) .and_then(|item| item.act_as::(cx)) diff --git a/crates/agent_ui/src/slash_command_picker.rs b/crates/agent_ui/src/slash_command_picker.rs index a6bb61510cbeb557e22018c73082bba17d177d7e..0c3cf37599887fe8e97dcdc67bb0bd7e28a744a7 100644 --- a/crates/agent_ui/src/slash_command_picker.rs +++ b/crates/agent_ui/src/slash_command_picker.rs @@ -155,8 +155,8 @@ impl PickerDelegate for SlashCommandDelegate { match command { SlashCommandEntry::Info(info) => { self.active_context_editor - .update(cx, |context_editor, cx| { - context_editor.insert_command(&info.name, window, cx) + .update(cx, |text_thread_editor, cx| { + text_thread_editor.insert_command(&info.name, window, cx) }) .ok(); } diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 5aa6f1f6d9405dc7556cb87c82d5300308f059d1..667ccb8938b892dcf59232d5cd7ea8dda04bc4b2 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -74,10 +74,10 @@ use workspace::{ use zed_actions::agent::{AddSelectionToThread, ToggleModelSelector}; use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker}; -use assistant_context::{ - AssistantContext, CacheStatus, Content, ContextEvent, ContextId, InvokedSlashCommandId, - InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, MessageStatus, - PendingSlashCommandStatus, ThoughtProcessOutputSection, +use assistant_text_thread::{ + CacheStatus, Content, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId, + MessageMetadata, MessageStatus, PendingSlashCommandStatus, TextThread, TextThreadEvent, + TextThreadId, ThoughtProcessOutputSection, }; actions!( @@ -126,14 +126,14 @@ pub enum ThoughtProcessStatus { } pub trait AgentPanelDelegate { - fn active_context_editor( + fn active_text_thread_editor( &self, workspace: &mut Workspace, window: &mut Window, cx: &mut Context, ) -> Option>; - fn open_saved_context( + fn open_local_text_thread( &self, workspace: &mut Workspace, path: Arc, @@ -141,10 +141,10 @@ pub trait AgentPanelDelegate { cx: &mut Context, ) -> Task>; - fn open_remote_context( + fn open_remote_text_thread( &self, workspace: &mut Workspace, - context_id: ContextId, + text_thread_id: TextThreadId, window: &mut Window, cx: &mut Context, ) -> Task>>; @@ -177,7 +177,7 @@ struct GlobalAssistantPanelDelegate(Arc); impl Global for GlobalAssistantPanelDelegate {} pub struct TextThreadEditor { - context: Entity, + text_thread: Entity, fs: Arc, slash_commands: Arc, workspace: WeakEntity, @@ -223,8 +223,8 @@ impl TextThreadEditor { .detach(); } - pub fn for_context( - context: Entity, + pub fn for_text_thread( + text_thread: Entity, fs: Arc, workspace: WeakEntity, project: Entity, @@ -233,14 +233,14 @@ impl TextThreadEditor { cx: &mut Context, ) -> Self { let completion_provider = SlashCommandCompletionProvider::new( - context.read(cx).slash_commands().clone(), + text_thread.read(cx).slash_commands().clone(), Some(cx.entity().downgrade()), Some(workspace.clone()), ); let editor = cx.new(|cx| { let mut editor = - Editor::for_buffer(context.read(cx).buffer().clone(), None, window, cx); + Editor::for_buffer(text_thread.read(cx).buffer().clone(), None, window, cx); editor.disable_scrollbars_and_minimap(window, cx); editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx); editor.set_show_line_numbers(false, cx); @@ -264,18 +264,24 @@ impl TextThreadEditor { }); let _subscriptions = vec![ - cx.observe(&context, |_, _, cx| cx.notify()), - cx.subscribe_in(&context, window, Self::handle_context_event), + cx.observe(&text_thread, |_, _, cx| cx.notify()), + cx.subscribe_in(&text_thread, window, Self::handle_text_thread_event), cx.subscribe_in(&editor, window, Self::handle_editor_event), cx.subscribe_in(&editor, window, Self::handle_editor_search_event), cx.observe_global_in::(window, Self::settings_changed), ]; - let slash_command_sections = context.read(cx).slash_command_output_sections().to_vec(); - let thought_process_sections = context.read(cx).thought_process_output_sections().to_vec(); - let slash_commands = context.read(cx).slash_commands().clone(); + let slash_command_sections = text_thread + .read(cx) + .slash_command_output_sections() + .to_vec(); + let thought_process_sections = text_thread + .read(cx) + .thought_process_output_sections() + .to_vec(); + let slash_commands = text_thread.read(cx).slash_commands().clone(); let mut this = Self { - context, + text_thread, slash_commands, editor, lsp_adapter_delegate, @@ -337,8 +343,8 @@ impl TextThreadEditor { }); } - pub fn context(&self) -> &Entity { - &self.context + pub fn text_thread(&self) -> &Entity { + &self.text_thread } pub fn editor(&self) -> &Entity { @@ -350,9 +356,9 @@ impl TextThreadEditor { self.editor.update(cx, |editor, cx| { editor.insert(&format!("/{command_name}\n\n"), window, cx) }); - let command = self.context.update(cx, |context, cx| { - context.reparse(cx); - context.parsed_slash_commands()[0].clone() + let command = self.text_thread.update(cx, |text_thread, cx| { + text_thread.reparse(cx); + text_thread.parsed_slash_commands()[0].clone() }); self.run_command( command.source_range, @@ -375,11 +381,14 @@ impl TextThreadEditor { fn send_to_model(&mut self, window: &mut Window, cx: &mut Context) { self.last_error = None; - if let Some(user_message) = self.context.update(cx, |context, cx| context.assist(cx)) { + if let Some(user_message) = self + .text_thread + .update(cx, |text_thread, cx| text_thread.assist(cx)) + { let new_selection = { let cursor = user_message .start - .to_offset(self.context.read(cx).buffer().read(cx)); + .to_offset(self.text_thread.read(cx).buffer().read(cx)); cursor..cursor }; self.editor.update(cx, |editor, cx| { @@ -403,8 +412,8 @@ impl TextThreadEditor { self.last_error = None; if self - .context - .update(cx, |context, cx| context.cancel_last_assist(cx)) + .text_thread + .update(cx, |text_thread, cx| text_thread.cancel_last_assist(cx)) { return; } @@ -419,13 +428,13 @@ impl TextThreadEditor { cx: &mut Context, ) { let cursors = self.cursors(cx); - self.context.update(cx, |context, cx| { - let messages = context + self.text_thread.update(cx, |text_thread, cx| { + let messages = text_thread .messages_for_offsets(cursors, cx) .into_iter() .map(|message| message.id) .collect(); - context.cycle_message_roles(messages, cx) + text_thread.cycle_message_roles(messages, cx) }); } @@ -491,11 +500,11 @@ impl TextThreadEditor { let selections = self.editor.read(cx).selections.disjoint_anchors_arc(); let mut commands_by_range = HashMap::default(); let workspace = self.workspace.clone(); - self.context.update(cx, |context, cx| { - context.reparse(cx); + self.text_thread.update(cx, |text_thread, cx| { + text_thread.reparse(cx); for selection in selections.iter() { if let Some(command) = - context.pending_command_for_position(selection.head().text_anchor, cx) + text_thread.pending_command_for_position(selection.head().text_anchor, cx) { commands_by_range .entry(command.source_range.clone()) @@ -533,14 +542,14 @@ impl TextThreadEditor { cx: &mut Context, ) { if let Some(command) = self.slash_commands.command(name, cx) { - let context = self.context.read(cx); - let sections = context + let text_thread = self.text_thread.read(cx); + let sections = text_thread .slash_command_output_sections() .iter() - .filter(|section| section.is_valid(context.buffer().read(cx))) + .filter(|section| section.is_valid(text_thread.buffer().read(cx))) .cloned() .collect::>(); - let snapshot = context.buffer().read(cx).snapshot(); + let snapshot = text_thread.buffer().read(cx).snapshot(); let output = command.run( arguments, §ions, @@ -550,8 +559,8 @@ impl TextThreadEditor { window, cx, ); - self.context.update(cx, |context, cx| { - context.insert_command_output( + self.text_thread.update(cx, |text_thread, cx| { + text_thread.insert_command_output( command_range, name, output, @@ -562,32 +571,32 @@ impl TextThreadEditor { } } - fn handle_context_event( + fn handle_text_thread_event( &mut self, - _: &Entity, - event: &ContextEvent, + _: &Entity, + event: &TextThreadEvent, window: &mut Window, cx: &mut Context, ) { - let context_editor = cx.entity().downgrade(); + let text_thread_editor = cx.entity().downgrade(); match event { - ContextEvent::MessagesEdited => { + TextThreadEvent::MessagesEdited => { self.update_message_headers(cx); self.update_image_blocks(cx); - self.context.update(cx, |context, cx| { - context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx); + self.text_thread.update(cx, |text_thread, cx| { + text_thread.save(Some(Duration::from_millis(500)), self.fs.clone(), cx); }); } - ContextEvent::SummaryChanged => { + TextThreadEvent::SummaryChanged => { cx.emit(EditorEvent::TitleChanged); - self.context.update(cx, |context, cx| { - context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx); + self.text_thread.update(cx, |text_thread, cx| { + text_thread.save(Some(Duration::from_millis(500)), self.fs.clone(), cx); }); } - ContextEvent::SummaryGenerated => {} - ContextEvent::PathChanged { .. } => {} - ContextEvent::StartedThoughtProcess(range) => { + TextThreadEvent::SummaryGenerated => {} + TextThreadEvent::PathChanged { .. } => {} + TextThreadEvent::StartedThoughtProcess(range) => { let creases = self.insert_thought_process_output_sections( [( ThoughtProcessOutputSection { @@ -600,7 +609,7 @@ impl TextThreadEditor { ); self.pending_thought_process = Some((creases[0], range.start)); } - ContextEvent::EndedThoughtProcess(end) => { + TextThreadEvent::EndedThoughtProcess(end) => { if let Some((crease_id, start)) = self.pending_thought_process.take() { self.editor.update(cx, |editor, cx| { let multi_buffer_snapshot = editor.buffer().read(cx).snapshot(cx); @@ -626,7 +635,7 @@ impl TextThreadEditor { ); } } - ContextEvent::StreamedCompletion => { + TextThreadEvent::StreamedCompletion => { self.editor.update(cx, |editor, cx| { if let Some(scroll_position) = self.scroll_position { let snapshot = editor.snapshot(window, cx); @@ -641,7 +650,7 @@ impl TextThreadEditor { } }); } - ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => { + TextThreadEvent::ParsedSlashCommandsUpdated { removed, updated } => { self.editor.update(cx, |editor, cx| { let buffer = editor.buffer().read(cx).snapshot(cx); let (&excerpt_id, _, _) = buffer.as_singleton().unwrap(); @@ -657,12 +666,12 @@ impl TextThreadEditor { updated.iter().map(|command| { let workspace = self.workspace.clone(); let confirm_command = Arc::new({ - let context_editor = context_editor.clone(); + let text_thread_editor = text_thread_editor.clone(); let command = command.clone(); move |window: &mut Window, cx: &mut App| { - context_editor - .update(cx, |context_editor, cx| { - context_editor.run_command( + text_thread_editor + .update(cx, |text_thread_editor, cx| { + text_thread_editor.run_command( command.source_range.clone(), &command.name, &command.arguments, @@ -712,17 +721,17 @@ impl TextThreadEditor { ); }) } - ContextEvent::InvokedSlashCommandChanged { command_id } => { + TextThreadEvent::InvokedSlashCommandChanged { command_id } => { self.update_invoked_slash_command(*command_id, window, cx); } - ContextEvent::SlashCommandOutputSectionAdded { section } => { + TextThreadEvent::SlashCommandOutputSectionAdded { section } => { self.insert_slash_command_output_sections([section.clone()], false, window, cx); } - ContextEvent::Operation(_) => {} - ContextEvent::ShowAssistError(error_message) => { + TextThreadEvent::Operation(_) => {} + TextThreadEvent::ShowAssistError(error_message) => { self.last_error = Some(AssistError::Message(error_message.clone())); } - ContextEvent::ShowPaymentRequiredError => { + TextThreadEvent::ShowPaymentRequiredError => { self.last_error = Some(AssistError::PaymentRequired); } } @@ -735,14 +744,14 @@ impl TextThreadEditor { cx: &mut Context, ) { if let Some(invoked_slash_command) = - self.context.read(cx).invoked_slash_command(&command_id) + self.text_thread.read(cx).invoked_slash_command(&command_id) && let InvokedSlashCommandStatus::Finished = invoked_slash_command.status { let run_commands_in_ranges = invoked_slash_command.run_commands_in_ranges.clone(); for range in run_commands_in_ranges { - let commands = self.context.update(cx, |context, cx| { - context.reparse(cx); - context + let commands = self.text_thread.update(cx, |text_thread, cx| { + text_thread.reparse(cx); + text_thread .pending_commands_for_range(range.clone(), cx) .to_vec() }); @@ -763,7 +772,7 @@ impl TextThreadEditor { self.editor.update(cx, |editor, cx| { if let Some(invoked_slash_command) = - self.context.read(cx).invoked_slash_command(&command_id) + self.text_thread.read(cx).invoked_slash_command(&command_id) { if let InvokedSlashCommandStatus::Finished = invoked_slash_command.status { let buffer = editor.buffer().read(cx).snapshot(cx); @@ -790,7 +799,7 @@ impl TextThreadEditor { let buffer = editor.buffer().read(cx).snapshot(cx); let (&excerpt_id, _buffer_id, _buffer_snapshot) = buffer.as_singleton().unwrap(); - let context = self.context.downgrade(); + let context = self.text_thread.downgrade(); let range = buffer .anchor_range_in_excerpt(excerpt_id, invoked_slash_command.range.clone()) .unwrap(); @@ -1020,7 +1029,7 @@ impl TextThreadEditor { let render_block = |message: MessageMetadata| -> RenderBlock { Arc::new({ - let context = self.context.clone(); + let text_thread = self.text_thread.clone(); move |cx| { let message_id = MessageId(message.timestamp); @@ -1093,10 +1102,10 @@ impl TextThreadEditor { ) }) .on_click({ - let context = context.clone(); + let text_thread = text_thread.clone(); move |_, _window, cx| { - context.update(cx, |context, cx| { - context.cycle_message_roles( + text_thread.update(cx, |text_thread, cx| { + text_thread.cycle_message_roles( HashSet::from_iter(Some(message_id)), cx, ) @@ -1158,11 +1167,11 @@ impl TextThreadEditor { .icon_position(IconPosition::Start) .tooltip(Tooltip::text("View Details")) .on_click({ - let context = context.clone(); + let text_thread = text_thread.clone(); let error = error.clone(); move |_, _window, cx| { - context.update(cx, |_, cx| { - cx.emit(ContextEvent::ShowAssistError( + text_thread.update(cx, |_, cx| { + cx.emit(TextThreadEvent::ShowAssistError( error.clone(), )); }); @@ -1205,7 +1214,7 @@ impl TextThreadEditor { }; let mut new_blocks = vec![]; let mut block_index_to_message = vec![]; - for message in self.context.read(cx).messages(cx) { + for message in self.text_thread.read(cx).messages(cx) { if blocks_to_remove.remove(&message.id).is_some() { // This is an old message that we might modify. let Some((meta, block_id)) = old_blocks.get_mut(&message.id) else { @@ -1246,18 +1255,18 @@ impl TextThreadEditor { ) -> Option<(String, bool)> { const CODE_FENCE_DELIMITER: &str = "```"; - let context_editor = context_editor_view.read(cx).editor.clone(); - context_editor.update(cx, |context_editor, cx| { - let display_map = context_editor.display_snapshot(cx); - if context_editor + let text_thread_editor = context_editor_view.read(cx).editor.clone(); + text_thread_editor.update(cx, |text_thread_editor, cx| { + let display_map = text_thread_editor.display_snapshot(cx); + if text_thread_editor .selections .newest::(&display_map) .is_empty() { - let snapshot = context_editor.buffer().read(cx).snapshot(cx); + let snapshot = text_thread_editor.buffer().read(cx).snapshot(cx); let (_, _, snapshot) = snapshot.as_singleton()?; - let head = context_editor + let head = text_thread_editor .selections .newest::(&display_map) .head(); @@ -1277,8 +1286,8 @@ impl TextThreadEditor { (!text.is_empty()).then_some((text, true)) } else { - let selection = context_editor.selections.newest_adjusted(&display_map); - let buffer = context_editor.buffer().read(cx).snapshot(cx); + let selection = text_thread_editor.selections.newest_adjusted(&display_map); + let buffer = text_thread_editor.buffer().read(cx).snapshot(cx); let selected_text = buffer.text_for_range(selection.range()).collect::(); (!selected_text.is_empty()).then_some((selected_text, false)) @@ -1296,7 +1305,7 @@ impl TextThreadEditor { return; }; let Some(context_editor_view) = - agent_panel_delegate.active_context_editor(workspace, window, cx) + agent_panel_delegate.active_text_thread_editor(workspace, window, cx) else { return; }; @@ -1324,7 +1333,7 @@ impl TextThreadEditor { let result = maybe!({ let agent_panel_delegate = ::try_global(cx)?; let context_editor_view = - agent_panel_delegate.active_context_editor(workspace, window, cx)?; + agent_panel_delegate.active_text_thread_editor(workspace, window, cx)?; Self::get_selection_or_code_block(&context_editor_view, cx) }); let Some((text, is_code_block)) = result else { @@ -1361,7 +1370,7 @@ impl TextThreadEditor { return; }; let Some(context_editor_view) = - agent_panel_delegate.active_context_editor(workspace, window, cx) + agent_panel_delegate.active_text_thread_editor(workspace, window, cx) else { return; }; @@ -1622,29 +1631,33 @@ impl TextThreadEditor { ) }); - let context = self.context.read(cx); + let text_thread = self.text_thread.read(cx); let mut text = String::new(); // If selection is empty, we want to copy the entire line if selection.range().is_empty() { - let snapshot = context.buffer().read(cx).snapshot(); + let snapshot = text_thread.buffer().read(cx).snapshot(); let point = snapshot.offset_to_point(selection.range().start); selection.start = snapshot.point_to_offset(Point::new(point.row, 0)); selection.end = snapshot .point_to_offset(cmp::min(Point::new(point.row + 1, 0), snapshot.max_point())); - for chunk in context.buffer().read(cx).text_for_range(selection.range()) { + for chunk in text_thread + .buffer() + .read(cx) + .text_for_range(selection.range()) + { text.push_str(chunk); } } else { - for message in context.messages(cx) { + for message in text_thread.messages(cx) { if message.offset_range.start >= selection.range().end { break; } else if message.offset_range.end >= selection.range().start { let range = cmp::max(message.offset_range.start, selection.range().start) ..cmp::min(message.offset_range.end, selection.range().end); if !range.is_empty() { - for chunk in context.buffer().read(cx).text_for_range(range) { + for chunk in text_thread.buffer().read(cx).text_for_range(range) { text.push_str(chunk); } if message.offset_range.end < selection.range().end { @@ -1755,7 +1768,7 @@ impl TextThreadEditor { }); }); - self.context.update(cx, |context, cx| { + self.text_thread.update(cx, |text_thread, cx| { for image in images { let Some(render_image) = image.to_image_data(cx.svg_renderer()).log_err() else { @@ -1765,7 +1778,7 @@ impl TextThreadEditor { let image_task = LanguageModelImage::from_image(Arc::new(image), cx).shared(); for image_position in image_positions.iter() { - context.insert_content( + text_thread.insert_content( Content::Image { anchor: image_position.text_anchor, image_id, @@ -1786,7 +1799,7 @@ impl TextThreadEditor { let excerpt_id = *buffer.as_singleton().unwrap().0; let old_blocks = std::mem::take(&mut self.image_blocks); let new_blocks = self - .context + .text_thread .read(cx) .contents(cx) .map( @@ -1834,36 +1847,36 @@ impl TextThreadEditor { } fn split(&mut self, _: &Split, _window: &mut Window, cx: &mut Context) { - self.context.update(cx, |context, cx| { + self.text_thread.update(cx, |text_thread, cx| { let selections = self.editor.read(cx).selections.disjoint_anchors_arc(); for selection in selections.as_ref() { let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx); let range = selection .map(|endpoint| endpoint.to_offset(&buffer)) .range(); - context.split_message(range, cx); + text_thread.split_message(range, cx); } }); } fn save(&mut self, _: &Save, _window: &mut Window, cx: &mut Context) { - self.context.update(cx, |context, cx| { - context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx) + self.text_thread.update(cx, |text_thread, cx| { + text_thread.save(Some(Duration::from_millis(500)), self.fs.clone(), cx) }); } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary().or_default() + self.text_thread.read(cx).summary().or_default() } pub fn regenerate_summary(&mut self, cx: &mut Context) { - self.context - .update(cx, |context, cx| context.summarize(true, cx)); + self.text_thread + .update(cx, |text_thread, cx| text_thread.summarize(true, cx)); } fn render_remaining_tokens(&self, cx: &App) -> Option> { let (token_count_color, token_count, max_token_count, tooltip) = - match token_state(&self.context, cx)? { + match token_state(&self.text_thread, cx)? { TokenState::NoTokensLeft { max_token_count, token_count, @@ -1911,7 +1924,7 @@ impl TextThreadEditor { fn render_send_button(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let focus_handle = self.focus_handle(cx); - let (style, tooltip) = match token_state(&self.context, cx) { + let (style, tooltip) = match token_state(&self.text_thread, cx) { Some(TokenState::NoTokensLeft { .. }) => ( ButtonStyle::Tinted(TintColor::Error), Some(Tooltip::text("Token limit reached")(window, cx)), @@ -1986,7 +1999,7 @@ impl TextThreadEditor { } fn render_burn_mode_toggle(&self, cx: &mut Context) -> Option { - let context = self.context().read(cx); + let text_thread = self.text_thread().read(cx); let active_model = LanguageModelRegistry::read_global(cx) .default_model() .map(|default| default.model)?; @@ -1994,7 +2007,7 @@ impl TextThreadEditor { return None; } - let active_completion_mode = context.completion_mode(); + let active_completion_mode = text_thread.completion_mode(); let burn_mode_enabled = active_completion_mode == CompletionMode::Burn; let icon = if burn_mode_enabled { IconName::ZedBurnModeOn @@ -2009,8 +2022,8 @@ impl TextThreadEditor { .toggle_state(burn_mode_enabled) .selected_icon_color(Color::Error) .on_click(cx.listener(move |this, _event, _window, cx| { - this.context().update(cx, |context, _cx| { - context.set_completion_mode(match active_completion_mode { + this.text_thread().update(cx, |text_thread, _cx| { + text_thread.set_completion_mode(match active_completion_mode { CompletionMode::Burn => CompletionMode::Normal, CompletionMode::Normal => CompletionMode::Burn, }); @@ -2637,10 +2650,10 @@ impl FollowableItem for TextThreadEditor { } fn to_state_proto(&self, window: &Window, cx: &App) -> Option { - let context = self.context.read(cx); + let text_thread = self.text_thread.read(cx); Some(proto::view::Variant::ContextEditor( proto::view::ContextEditor { - context_id: context.id().to_proto(), + context_id: text_thread.id().to_proto(), editor: if let Some(proto::view::Variant::Editor(proto)) = self.editor.read(cx).to_state_proto(window, cx) { @@ -2666,22 +2679,22 @@ impl FollowableItem for TextThreadEditor { unreachable!() }; - let context_id = ContextId::from_proto(state.context_id); + let text_thread_id = TextThreadId::from_proto(state.context_id); let editor_state = state.editor?; let project = workspace.read(cx).project().clone(); let agent_panel_delegate = ::try_global(cx)?; - let context_editor_task = workspace.update(cx, |workspace, cx| { - agent_panel_delegate.open_remote_context(workspace, context_id, window, cx) + let text_thread_editor_task = workspace.update(cx, |workspace, cx| { + agent_panel_delegate.open_remote_text_thread(workspace, text_thread_id, window, cx) }); Some(window.spawn(cx, async move |cx| { - let context_editor = context_editor_task.await?; - context_editor - .update_in(cx, |context_editor, window, cx| { - context_editor.remote_id = Some(id); - context_editor.editor.update(cx, |editor, cx| { + let text_thread_editor = text_thread_editor_task.await?; + text_thread_editor + .update_in(cx, |text_thread_editor, window, cx| { + text_thread_editor.remote_id = Some(id); + text_thread_editor.editor.update(cx, |editor, cx| { editor.apply_update_proto( &project, proto::update_view::Variant::Editor(proto::update_view::Editor { @@ -2698,7 +2711,7 @@ impl FollowableItem for TextThreadEditor { }) })? .await?; - Ok(context_editor) + Ok(text_thread_editor) })) } @@ -2745,7 +2758,7 @@ impl FollowableItem for TextThreadEditor { } fn dedup(&self, existing: &Self, _window: &Window, cx: &App) -> Option { - if existing.context.read(cx).id() == self.context.read(cx).id() { + if existing.text_thread.read(cx).id() == self.text_thread.read(cx).id() { Some(item::Dedup::KeepExisting) } else { None @@ -2757,17 +2770,17 @@ enum PendingSlashCommand {} fn invoked_slash_command_fold_placeholder( command_id: InvokedSlashCommandId, - context: WeakEntity, + text_thread: WeakEntity, ) -> FoldPlaceholder { FoldPlaceholder { constrain_width: false, merge_adjacent: false, render: Arc::new(move |fold_id, _, cx| { - let Some(context) = context.upgrade() else { + let Some(text_thread) = text_thread.upgrade() else { return Empty.into_any(); }; - let Some(command) = context.read(cx).invoked_slash_command(&command_id) else { + let Some(command) = text_thread.read(cx).invoked_slash_command(&command_id) else { return Empty.into_any(); }; @@ -2808,14 +2821,15 @@ enum TokenState { }, } -fn token_state(context: &Entity, cx: &App) -> Option { +fn token_state(text_thread: &Entity, cx: &App) -> Option { const WARNING_TOKEN_THRESHOLD: f32 = 0.8; let model = LanguageModelRegistry::read_global(cx) .default_model()? .model; - let token_count = context.read(cx).token_count()?; - let max_token_count = model.max_token_count_for_mode(context.read(cx).completion_mode().into()); + let token_count = text_thread.read(cx).token_count()?; + let max_token_count = + model.max_token_count_for_mode(text_thread.read(cx).completion_mode().into()); let token_state = if max_token_count.saturating_sub(token_count) == 0 { TokenState::NoTokensLeft { max_token_count, @@ -2927,7 +2941,7 @@ mod tests { #[gpui::test] async fn test_copy_paste_whole_message(cx: &mut TestAppContext) { - let (context, context_editor, mut cx) = setup_context_editor_text(vec![ + let (context, text_thread_editor, mut cx) = setup_text_thread_editor_text(vec![ (Role::User, "What is the Zed editor?"), ( Role::Assistant, @@ -2937,8 +2951,8 @@ mod tests { ],cx).await; // Select & Copy whole user message - assert_copy_paste_context_editor( - &context_editor, + assert_copy_paste_text_thread_editor( + &text_thread_editor, message_range(&context, 0, &mut cx), indoc! {" What is the Zed editor? @@ -2949,8 +2963,8 @@ mod tests { ); // Select & Copy whole assistant message - assert_copy_paste_context_editor( - &context_editor, + assert_copy_paste_text_thread_editor( + &text_thread_editor, message_range(&context, 1, &mut cx), indoc! {" What is the Zed editor? @@ -2964,7 +2978,7 @@ mod tests { #[gpui::test] async fn test_copy_paste_no_selection(cx: &mut TestAppContext) { - let (context, context_editor, mut cx) = setup_context_editor_text( + let (context, text_thread_editor, mut cx) = setup_text_thread_editor_text( vec![ (Role::User, "user1"), (Role::Assistant, "assistant1"), @@ -2977,8 +2991,8 @@ mod tests { // Copy and paste first assistant message let message_2_range = message_range(&context, 1, &mut cx); - assert_copy_paste_context_editor( - &context_editor, + assert_copy_paste_text_thread_editor( + &text_thread_editor, message_2_range.start..message_2_range.start, indoc! {" user1 @@ -2991,8 +3005,8 @@ mod tests { // Copy and cut second assistant message let message_3_range = message_range(&context, 2, &mut cx); - assert_copy_paste_context_editor( - &context_editor, + assert_copy_paste_text_thread_editor( + &text_thread_editor, message_3_range.start..message_3_range.start, indoc! {" user1 @@ -3079,29 +3093,29 @@ mod tests { } } - async fn setup_context_editor_text( + async fn setup_text_thread_editor_text( messages: Vec<(Role, &str)>, cx: &mut TestAppContext, ) -> ( - Entity, + Entity, Entity, VisualTestContext, ) { cx.update(init_test); let fs = FakeFs::new(cx.executor()); - let context = create_context_with_messages(messages, cx); + let text_thread = create_text_thread_with_messages(messages, cx); let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); let workspace = window.root(cx).unwrap(); let mut cx = VisualTestContext::from_window(*window, cx); - let context_editor = window + let text_thread_editor = window .update(&mut cx, |_, window, cx| { cx.new(|cx| { - TextThreadEditor::for_context( - context.clone(), + TextThreadEditor::for_text_thread( + text_thread.clone(), fs, workspace.downgrade(), project, @@ -3113,59 +3127,59 @@ mod tests { }) .unwrap(); - (context, context_editor, cx) + (text_thread, text_thread_editor, cx) } fn message_range( - context: &Entity, + text_thread: &Entity, message_ix: usize, cx: &mut TestAppContext, ) -> Range { - context.update(cx, |context, cx| { - context + text_thread.update(cx, |text_thread, cx| { + text_thread .messages(cx) .nth(message_ix) .unwrap() .anchor_range - .to_offset(&context.buffer().read(cx).snapshot()) + .to_offset(&text_thread.buffer().read(cx).snapshot()) }) } - fn assert_copy_paste_context_editor( - context_editor: &Entity, + fn assert_copy_paste_text_thread_editor( + text_thread_editor: &Entity, range: Range, expected_text: &str, cx: &mut VisualTestContext, ) { - context_editor.update_in(cx, |context_editor, window, cx| { - context_editor.editor.update(cx, |editor, cx| { + text_thread_editor.update_in(cx, |text_thread_editor, window, cx| { + text_thread_editor.editor.update(cx, |editor, cx| { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([range]) }); }); - context_editor.copy(&Default::default(), window, cx); + text_thread_editor.copy(&Default::default(), window, cx); - context_editor.editor.update(cx, |editor, cx| { + text_thread_editor.editor.update(cx, |editor, cx| { editor.move_to_end(&Default::default(), window, cx); }); - context_editor.paste(&Default::default(), window, cx); + text_thread_editor.paste(&Default::default(), window, cx); - context_editor.editor.update(cx, |editor, cx| { + text_thread_editor.editor.update(cx, |editor, cx| { assert_eq!(editor.text(cx), expected_text); }); }); } - fn create_context_with_messages( + fn create_text_thread_with_messages( mut messages: Vec<(Role, &str)>, cx: &mut TestAppContext, - ) -> Entity { + ) -> Entity { let registry = Arc::new(LanguageRegistry::test(cx.executor())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); cx.new(|cx| { - let mut context = AssistantContext::local( + let mut text_thread = TextThread::local( registry, None, None, @@ -3173,33 +3187,33 @@ mod tests { Arc::new(SlashCommandWorkingSet::default()), cx, ); - let mut message_1 = context.messages(cx).next().unwrap(); + let mut message_1 = text_thread.messages(cx).next().unwrap(); let (role, text) = messages.remove(0); loop { if role == message_1.role { - context.buffer().update(cx, |buffer, cx| { + text_thread.buffer().update(cx, |buffer, cx| { buffer.edit([(message_1.offset_range, text)], None, cx); }); break; } let mut ids = HashSet::default(); ids.insert(message_1.id); - context.cycle_message_roles(ids, cx); - message_1 = context.messages(cx).next().unwrap(); + text_thread.cycle_message_roles(ids, cx); + message_1 = text_thread.messages(cx).next().unwrap(); } let mut last_message_id = message_1.id; for (role, text) in messages { - context.insert_message_after(last_message_id, role, MessageStatus::Done, cx); - let message = context.messages(cx).last().unwrap(); + text_thread.insert_message_after(last_message_id, role, MessageStatus::Done, cx); + let message = text_thread.messages(cx).last().unwrap(); last_message_id = message.id; - context.buffer().update(cx, |buffer, cx| { + text_thread.buffer().update(cx, |buffer, cx| { buffer.edit([(message.offset_range, text)], None, cx); }) } - context + text_thread }) } diff --git a/crates/agent_ui/src/ui/context_pill.rs b/crates/agent_ui/src/ui/context_pill.rs index 43d3799d697e28d43c71fc6e6e77cc058eaec5b2..89bf618a16d3fb8e7abc5afaf34ee6e8bb43ab67 100644 --- a/crates/agent_ui/src/ui/context_pill.rs +++ b/crates/agent_ui/src/ui/context_pill.rs @@ -497,9 +497,9 @@ impl AddedContext { icon_path: None, status: ContextStatus::Ready, render_hover: { - let context = handle.context.clone(); + let text_thread = handle.text_thread.clone(); Some(Rc::new(move |_, cx| { - let text = context.read(cx).to_xml(cx); + let text = text_thread.read(cx).to_xml(cx); ContextPillHover::new_text(text.into(), cx).into() })) }, diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_text_thread/Cargo.toml similarity index 95% rename from crates/assistant_context/Cargo.toml rename to crates/assistant_text_thread/Cargo.toml index 2d3e8bc4080a314c480bb11e459a745cb7ce6704..8dfdfa3828340217456088a246eee5b1568a7a77 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_text_thread/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "assistant_context" +name = "assistant_text_thread" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,7 +9,7 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/assistant_context.rs" +path = "src/assistant_text_thread.rs" [features] test-support = [] diff --git a/crates/assistant_context/LICENSE-GPL b/crates/assistant_text_thread/LICENSE-GPL similarity index 100% rename from crates/assistant_context/LICENSE-GPL rename to crates/assistant_text_thread/LICENSE-GPL diff --git a/crates/assistant_text_thread/src/assistant_text_thread.rs b/crates/assistant_text_thread/src/assistant_text_thread.rs new file mode 100644 index 0000000000000000000000000000000000000000..7eab9800d5d6f43ba8eabec0682961e073781ace --- /dev/null +++ b/crates/assistant_text_thread/src/assistant_text_thread.rs @@ -0,0 +1,15 @@ +#[cfg(test)] +mod assistant_text_thread_tests; +mod text_thread; +mod text_thread_store; + +pub use crate::text_thread::*; +pub use crate::text_thread_store::*; + +use client::Client; +use gpui::App; +use std::sync::Arc; + +pub fn init(client: Arc, _: &mut App) { + text_thread_store::init(&client.into()); +} diff --git a/crates/assistant_context/src/assistant_context_tests.rs b/crates/assistant_text_thread/src/assistant_text_thread_tests.rs similarity index 75% rename from crates/assistant_context/src/assistant_context_tests.rs rename to crates/assistant_text_thread/src/assistant_text_thread_tests.rs index 2d987f9f845b471438cfb3eb0667fbc36161c53c..fbd5dcafa6e142538f1f5821bc9e0a89ccbfd881 100644 --- a/crates/assistant_context/src/assistant_context_tests.rs +++ b/crates/assistant_text_thread/src/assistant_text_thread_tests.rs @@ -1,6 +1,6 @@ use crate::{ - AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, ContextSummary, - InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, + CacheStatus, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, TextThread, + TextThreadEvent, TextThreadId, TextThreadOperation, TextThreadSummary, }; use anyhow::Result; use assistant_slash_command::{ @@ -47,8 +47,8 @@ fn test_inserting_and_removing_messages(cx: &mut App) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new(|cx| { - AssistantContext::local( + let text_thread = cx.new(|cx| { + TextThread::local( registry, None, None, @@ -57,21 +57,21 @@ fn test_inserting_and_removing_messages(cx: &mut App) { cx, ) }); - let buffer = context.read(cx).buffer.clone(); + let buffer = text_thread.read(cx).buffer().clone(); - let message_1 = context.read(cx).message_anchors[0].clone(); + let message_1 = text_thread.read(cx).message_anchors[0].clone(); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![(message_1.id, Role::User, 0..0)] ); - let message_2 = context.update(cx, |context, cx| { + let message_2 = text_thread.update(cx, |context, cx| { context .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) .unwrap() }); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..1), (message_2.id, Role::Assistant, 1..1) @@ -82,20 +82,20 @@ fn test_inserting_and_removing_messages(cx: &mut App) { buffer.edit([(0..0, "1"), (1..1, "2")], None, cx) }); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..2), (message_2.id, Role::Assistant, 2..3) ] ); - let message_3 = context.update(cx, |context, cx| { + let message_3 = text_thread.update(cx, |context, cx| { context .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..2), (message_2.id, Role::Assistant, 2..4), @@ -103,13 +103,13 @@ fn test_inserting_and_removing_messages(cx: &mut App) { ] ); - let message_4 = context.update(cx, |context, cx| { + let message_4 = text_thread.update(cx, |context, cx| { context .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..2), (message_2.id, Role::Assistant, 2..4), @@ -122,7 +122,7 @@ fn test_inserting_and_removing_messages(cx: &mut App) { buffer.edit([(4..4, "C"), (5..5, "D")], None, cx) }); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..2), (message_2.id, Role::Assistant, 2..4), @@ -134,7 +134,7 @@ fn test_inserting_and_removing_messages(cx: &mut App) { // Deleting across message boundaries merges the messages. buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx)); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..3), (message_3.id, Role::User, 3..4), @@ -144,7 +144,7 @@ fn test_inserting_and_removing_messages(cx: &mut App) { // Undoing the deletion should also undo the merge. buffer.update(cx, |buffer, cx| buffer.undo(cx)); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..2), (message_2.id, Role::Assistant, 2..4), @@ -156,7 +156,7 @@ fn test_inserting_and_removing_messages(cx: &mut App) { // Redoing the deletion should also redo the merge. buffer.update(cx, |buffer, cx| buffer.redo(cx)); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..3), (message_3.id, Role::User, 3..4), @@ -164,13 +164,13 @@ fn test_inserting_and_removing_messages(cx: &mut App) { ); // Ensure we can still insert after a merged message. - let message_5 = context.update(cx, |context, cx| { + let message_5 = text_thread.update(cx, |context, cx| { context .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) .unwrap() }); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..3), (message_5.id, Role::System, 3..4), @@ -186,8 +186,8 @@ fn test_message_splitting(cx: &mut App) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new(|cx| { - AssistantContext::local( + let text_thread = cx.new(|cx| { + TextThread::local( registry.clone(), None, None, @@ -196,11 +196,11 @@ fn test_message_splitting(cx: &mut App) { cx, ) }); - let buffer = context.read(cx).buffer.clone(); + let buffer = text_thread.read(cx).buffer().clone(); - let message_1 = context.read(cx).message_anchors[0].clone(); + let message_1 = text_thread.read(cx).message_anchors[0].clone(); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![(message_1.id, Role::User, 0..0)] ); @@ -208,26 +208,28 @@ fn test_message_splitting(cx: &mut App) { buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) }); - let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx)); + let (_, message_2) = + text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx)); let message_2 = message_2.unwrap(); // We recycle newlines in the middle of a split message assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..4), (message_2.id, Role::User, 4..16), ] ); - let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx)); + let (_, message_3) = + text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx)); let message_3 = message_3.unwrap(); // We don't recycle newlines at the end of a split message assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..4), (message_3.id, Role::User, 4..5), @@ -235,11 +237,12 @@ fn test_message_splitting(cx: &mut App) { ] ); - let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx)); + let (_, message_4) = + text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx)); let message_4 = message_4.unwrap(); assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..4), (message_3.id, Role::User, 4..5), @@ -248,11 +251,12 @@ fn test_message_splitting(cx: &mut App) { ] ); - let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx)); + let (_, message_5) = + text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx)); let message_5 = message_5.unwrap(); assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..4), (message_3.id, Role::User, 4..5), @@ -263,12 +267,12 @@ fn test_message_splitting(cx: &mut App) { ); let (message_6, message_7) = - context.update(cx, |context, cx| context.split_message(14..16, cx)); + text_thread.update(cx, |text_thread, cx| text_thread.split_message(14..16, cx)); let message_6 = message_6.unwrap(); let message_7 = message_7.unwrap(); assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..4), (message_3.id, Role::User, 4..5), @@ -287,8 +291,8 @@ fn test_messages_for_offsets(cx: &mut App) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new(|cx| { - AssistantContext::local( + let text_thread = cx.new(|cx| { + TextThread::local( registry, None, None, @@ -297,32 +301,32 @@ fn test_messages_for_offsets(cx: &mut App) { cx, ) }); - let buffer = context.read(cx).buffer.clone(); + let buffer = text_thread.read(cx).buffer().clone(); - let message_1 = context.read(cx).message_anchors[0].clone(); + let message_1 = text_thread.read(cx).message_anchors[0].clone(); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![(message_1.id, Role::User, 0..0)] ); buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); - let message_2 = context - .update(cx, |context, cx| { - context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) + let message_2 = text_thread + .update(cx, |text_thread, cx| { + text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); - let message_3 = context - .update(cx, |context, cx| { - context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + let message_3 = text_thread + .update(cx, |text_thread, cx| { + text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..4), (message_2.id, Role::User, 4..8), @@ -331,22 +335,22 @@ fn test_messages_for_offsets(cx: &mut App) { ); assert_eq!( - message_ids_for_offsets(&context, &[0, 4, 9], cx), + message_ids_for_offsets(&text_thread, &[0, 4, 9], cx), [message_1.id, message_2.id, message_3.id] ); assert_eq!( - message_ids_for_offsets(&context, &[0, 1, 11], cx), + message_ids_for_offsets(&text_thread, &[0, 1, 11], cx), [message_1.id, message_3.id] ); - let message_4 = context - .update(cx, |context, cx| { - context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx) + let message_4 = text_thread + .update(cx, |text_thread, cx| { + text_thread.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx) }) .unwrap(); assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n"); assert_eq!( - messages(&context, cx), + messages(&text_thread, cx), vec![ (message_1.id, Role::User, 0..4), (message_2.id, Role::User, 4..8), @@ -355,12 +359,12 @@ fn test_messages_for_offsets(cx: &mut App) { ] ); assert_eq!( - message_ids_for_offsets(&context, &[0, 4, 8, 12], cx), + message_ids_for_offsets(&text_thread, &[0, 4, 8, 12], cx), [message_1.id, message_2.id, message_3.id, message_4.id] ); fn message_ids_for_offsets( - context: &Entity, + context: &Entity, offsets: &[usize], cx: &App, ) -> Vec { @@ -398,8 +402,8 @@ async fn test_slash_commands(cx: &mut TestAppContext) { let registry = Arc::new(LanguageRegistry::test(cx.executor())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new(|cx| { - AssistantContext::local( + let text_thread = cx.new(|cx| { + TextThread::local( registry.clone(), None, None, @@ -417,19 +421,19 @@ async fn test_slash_commands(cx: &mut TestAppContext) { } let context_ranges = Rc::new(RefCell::new(ContextRanges::default())); - context.update(cx, |_, cx| { - cx.subscribe(&context, { + text_thread.update(cx, |_, cx| { + cx.subscribe(&text_thread, { let context_ranges = context_ranges.clone(); - move |context, _, event, _| { + move |text_thread, _, event, _| { let mut context_ranges = context_ranges.borrow_mut(); match event { - ContextEvent::InvokedSlashCommandChanged { command_id } => { - let command = context.invoked_slash_command(command_id).unwrap(); + TextThreadEvent::InvokedSlashCommandChanged { command_id } => { + let command = text_thread.invoked_slash_command(command_id).unwrap(); context_ranges .command_outputs .insert(*command_id, command.range.clone()); } - ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => { + TextThreadEvent::ParsedSlashCommandsUpdated { removed, updated } => { for range in removed { context_ranges.parsed_commands.remove(range); } @@ -439,7 +443,7 @@ async fn test_slash_commands(cx: &mut TestAppContext) { .insert(command.source_range.clone()); } } - ContextEvent::SlashCommandOutputSectionAdded { section } => { + TextThreadEvent::SlashCommandOutputSectionAdded { section } => { context_ranges.output_sections.insert(section.range.clone()); } _ => {} @@ -449,7 +453,7 @@ async fn test_slash_commands(cx: &mut TestAppContext) { .detach(); }); - let buffer = context.read_with(cx, |context, _| context.buffer.clone()); + let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone()); // Insert a slash command buffer.update(cx, |buffer, cx| { @@ -508,9 +512,9 @@ async fn test_slash_commands(cx: &mut TestAppContext) { ); let (command_output_tx, command_output_rx) = mpsc::unbounded(); - context.update(cx, |context, cx| { - let command_source_range = context.parsed_slash_commands[0].source_range.clone(); - context.insert_command_output( + text_thread.update(cx, |text_thread, cx| { + let command_source_range = text_thread.parsed_slash_commands[0].source_range.clone(); + text_thread.insert_command_output( command_source_range, "file", Task::ready(Ok(command_output_rx.boxed())), @@ -670,8 +674,8 @@ async fn test_serialization(cx: &mut TestAppContext) { let registry = Arc::new(LanguageRegistry::test(cx.executor())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new(|cx| { - AssistantContext::local( + let text_thread = cx.new(|cx| { + TextThread::local( registry.clone(), None, None, @@ -680,15 +684,15 @@ async fn test_serialization(cx: &mut TestAppContext) { cx, ) }); - let buffer = context.read_with(cx, |context, _| context.buffer.clone()); - let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); - let message_1 = context.update(cx, |context, cx| { - context + let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone()); + let message_0 = text_thread.read_with(cx, |text_thread, _| text_thread.message_anchors[0].id); + let message_1 = text_thread.update(cx, |text_thread, cx| { + text_thread .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx) .unwrap() }); - let message_2 = context.update(cx, |context, cx| { - context + let message_2 = text_thread.update(cx, |text_thread, cx| { + text_thread .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) .unwrap() }); @@ -696,15 +700,15 @@ async fn test_serialization(cx: &mut TestAppContext) { buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx); buffer.finalize_last_transaction(); }); - let _message_3 = context.update(cx, |context, cx| { - context + let _message_3 = text_thread.update(cx, |text_thread, cx| { + text_thread .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx) .unwrap() }); buffer.update(cx, |buffer, cx| buffer.undo(cx)); assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n"); assert_eq!( - cx.read(|cx| messages(&context, cx)), + cx.read(|cx| messages(&text_thread, cx)), [ (message_0, Role::User, 0..2), (message_1.id, Role::Assistant, 2..6), @@ -712,9 +716,9 @@ async fn test_serialization(cx: &mut TestAppContext) { ] ); - let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx)); + let serialized_context = text_thread.read_with(cx, |text_thread, cx| text_thread.serialize(cx)); let deserialized_context = cx.new(|cx| { - AssistantContext::deserialize( + TextThread::deserialize( serialized_context, Path::new("").into(), registry.clone(), @@ -726,7 +730,7 @@ async fn test_serialization(cx: &mut TestAppContext) { ) }); let deserialized_buffer = - deserialized_context.read_with(cx, |context, _| context.buffer.clone()); + deserialized_context.read_with(cx, |text_thread, _| text_thread.buffer().clone()); assert_eq!( deserialized_buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n" @@ -762,14 +766,14 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone())); let network = Arc::new(Mutex::new(Network::new(rng.clone()))); - let mut contexts = Vec::new(); + let mut text_threads = Vec::new(); let num_peers = rng.random_range(min_peers..=max_peers); - let context_id = ContextId::new(); + let context_id = TextThreadId::new(); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); for i in 0..num_peers { let context = cx.new(|cx| { - AssistantContext::new( + TextThread::new( context_id.clone(), ReplicaId::new(i as u16), language::Capability::ReadWrite, @@ -786,7 +790,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std cx.subscribe(&context, { let network = network.clone(); move |_, event, _| { - if let ContextEvent::Operation(op) = event { + if let TextThreadEvent::Operation(op) = event { network .lock() .broadcast(ReplicaId::new(i as u16), vec![op.to_proto()]); @@ -796,7 +800,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std .detach(); }); - contexts.push(context); + text_threads.push(context); network.lock().add_peer(ReplicaId::new(i as u16)); } @@ -806,30 +810,30 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std || !network.lock().is_idle() || network.lock().contains_disconnected_peers() { - let context_index = rng.random_range(0..contexts.len()); - let context = &contexts[context_index]; + let context_index = rng.random_range(0..text_threads.len()); + let text_thread = &text_threads[context_index]; match rng.random_range(0..100) { 0..=29 if mutation_count > 0 => { log::info!("Context {}: edit buffer", context_index); - context.update(cx, |context, cx| { - context - .buffer + text_thread.update(cx, |text_thread, cx| { + text_thread + .buffer() .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx)); }); mutation_count -= 1; } 30..=44 if mutation_count > 0 => { - context.update(cx, |context, cx| { - let range = context.buffer.read(cx).random_byte_range(0, &mut rng); + text_thread.update(cx, |text_thread, cx| { + let range = text_thread.buffer().read(cx).random_byte_range(0, &mut rng); log::info!("Context {}: split message at {:?}", context_index, range); - context.split_message(range, cx); + text_thread.split_message(range, cx); }); mutation_count -= 1; } 45..=59 if mutation_count > 0 => { - context.update(cx, |context, cx| { - if let Some(message) = context.messages(cx).choose(&mut rng) { + text_thread.update(cx, |text_thread, cx| { + if let Some(message) = text_thread.messages(cx).choose(&mut rng) { let role = *[Role::User, Role::Assistant, Role::System] .choose(&mut rng) .unwrap(); @@ -839,13 +843,13 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std message.id, role ); - context.insert_message_after(message.id, role, MessageStatus::Done, cx); + text_thread.insert_message_after(message.id, role, MessageStatus::Done, cx); } }); mutation_count -= 1; } 60..=74 if mutation_count > 0 => { - context.update(cx, |context, cx| { + text_thread.update(cx, |text_thread, cx| { let command_text = "/".to_string() + slash_commands .command_names() @@ -854,7 +858,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std .clone() .as_ref(); - let command_range = context.buffer.update(cx, |buffer, cx| { + let command_range = text_thread.buffer().update(cx, |buffer, cx| { let offset = buffer.random_byte_range(0, &mut rng).start; buffer.edit( [(offset..offset, format!("\n{}\n", command_text))], @@ -908,9 +912,15 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std events.len() ); - let command_range = context.buffer.read(cx).anchor_after(command_range.start) - ..context.buffer.read(cx).anchor_after(command_range.end); - context.insert_command_output( + let command_range = text_thread + .buffer() + .read(cx) + .anchor_after(command_range.start) + ..text_thread + .buffer() + .read(cx) + .anchor_after(command_range.end); + text_thread.insert_command_output( command_range, "/command", Task::ready(Ok(stream::iter(events).boxed())), @@ -922,8 +932,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std mutation_count -= 1; } 75..=84 if mutation_count > 0 => { - context.update(cx, |context, cx| { - if let Some(message) = context.messages(cx).choose(&mut rng) { + text_thread.update(cx, |text_thread, cx| { + if let Some(message) = text_thread.messages(cx).choose(&mut rng) { let new_status = match rng.random_range(0..3) { 0 => MessageStatus::Done, 1 => MessageStatus::Pending, @@ -935,7 +945,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std message.id, new_status ); - context.update_metadata(message.id, cx, |metadata| { + text_thread.update_metadata(message.id, cx, |metadata| { metadata.status = new_status; }); } @@ -948,8 +958,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std network.lock().reconnect_peer(replica_id, ReplicaId::new(0)); let (ops_to_send, ops_to_receive) = cx.read(|cx| { - let host_context = &contexts[0].read(cx); - let guest_context = context.read(cx); + let host_context = &text_threads[0].read(cx); + let guest_context = text_thread.read(cx); ( guest_context.serialize_ops(&host_context.version(cx), cx), host_context.serialize_ops(&guest_context.version(cx), cx), @@ -959,7 +969,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std let ops_to_receive = ops_to_receive .await .into_iter() - .map(ContextOperation::from_proto) + .map(TextThreadOperation::from_proto) .collect::>>() .unwrap(); log::info!( @@ -970,7 +980,9 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std ); network.lock().broadcast(replica_id, ops_to_send); - context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx)); + text_thread.update(cx, |text_thread, cx| { + text_thread.apply_ops(ops_to_receive, cx) + }); } else if rng.random_bool(0.1) && replica_id != ReplicaId::new(0) { log::info!("Context {}: disconnecting", context_index); network.lock().disconnect_peer(replica_id); @@ -979,43 +991,43 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std let ops = network.lock().receive(replica_id); let ops = ops .into_iter() - .map(ContextOperation::from_proto) + .map(TextThreadOperation::from_proto) .collect::>>() .unwrap(); - context.update(cx, |context, cx| context.apply_ops(ops, cx)); + text_thread.update(cx, |text_thread, cx| text_thread.apply_ops(ops, cx)); } } } } cx.read(|cx| { - let first_context = contexts[0].read(cx); - for context in &contexts[1..] { - let context = context.read(cx); - assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops); + let first_context = text_threads[0].read(cx); + for text_thread in &text_threads[1..] { + let text_thread = text_thread.read(cx); + assert!(text_thread.pending_ops.is_empty(), "pending ops: {:?}", text_thread.pending_ops); assert_eq!( - context.buffer.read(cx).text(), - first_context.buffer.read(cx).text(), + text_thread.buffer().read(cx).text(), + first_context.buffer().read(cx).text(), "Context {:?} text != Context 0 text", - context.buffer.read(cx).replica_id() + text_thread.buffer().read(cx).replica_id() ); assert_eq!( - context.message_anchors, + text_thread.message_anchors, first_context.message_anchors, "Context {:?} messages != Context 0 messages", - context.buffer.read(cx).replica_id() + text_thread.buffer().read(cx).replica_id() ); assert_eq!( - context.messages_metadata, + text_thread.messages_metadata, first_context.messages_metadata, "Context {:?} message metadata != Context 0 message metadata", - context.buffer.read(cx).replica_id() + text_thread.buffer().read(cx).replica_id() ); assert_eq!( - context.slash_command_output_sections, + text_thread.slash_command_output_sections, first_context.slash_command_output_sections, "Context {:?} slash command output sections != Context 0 slash command output sections", - context.buffer.read(cx).replica_id() + text_thread.buffer().read(cx).replica_id() ); } }); @@ -1027,8 +1039,8 @@ fn test_mark_cache_anchors(cx: &mut App) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new(|cx| { - AssistantContext::local( + let text_thread = cx.new(|cx| { + TextThread::local( registry, None, None, @@ -1037,7 +1049,7 @@ fn test_mark_cache_anchors(cx: &mut App) { cx, ) }); - let buffer = context.read(cx).buffer.clone(); + let buffer = text_thread.read(cx).buffer().clone(); // Create a test cache configuration let cache_configuration = &Some(LanguageModelCacheConfiguration { @@ -1046,14 +1058,14 @@ fn test_mark_cache_anchors(cx: &mut App) { min_total_token: 10, }); - let message_1 = context.read(cx).message_anchors[0].clone(); + let message_1 = text_thread.read(cx).message_anchors[0].clone(); - context.update(cx, |context, cx| { - context.mark_cache_anchors(cache_configuration, false, cx) + text_thread.update(cx, |text_thread, cx| { + text_thread.mark_cache_anchors(cache_configuration, false, cx) }); assert_eq!( - messages_cache(&context, cx) + messages_cache(&text_thread, cx) .iter() .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .count(), @@ -1062,41 +1074,41 @@ fn test_mark_cache_anchors(cx: &mut App) { ); buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); - let message_2 = context - .update(cx, |context, cx| { - context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx) + let message_2 = text_thread + .update(cx, |text_thread, cx| { + text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx)); - let message_3 = context - .update(cx, |context, cx| { - context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx) + let message_3 = text_thread + .update(cx, |text_thread, cx| { + text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx)); - context.update(cx, |context, cx| { - context.mark_cache_anchors(cache_configuration, false, cx) + text_thread.update(cx, |text_thread, cx| { + text_thread.mark_cache_anchors(cache_configuration, false, cx) }); assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc"); assert_eq!( - messages_cache(&context, cx) + messages_cache(&text_thread, cx) .iter() .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .count(), 0, "Messages should not be marked for cache before going over the token minimum." ); - context.update(cx, |context, _| { - context.token_count = Some(20); + text_thread.update(cx, |text_thread, _| { + text_thread.token_count = Some(20); }); - context.update(cx, |context, cx| { - context.mark_cache_anchors(cache_configuration, true, cx) + text_thread.update(cx, |text_thread, cx| { + text_thread.mark_cache_anchors(cache_configuration, true, cx) }); assert_eq!( - messages_cache(&context, cx) + messages_cache(&text_thread, cx) .iter() .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .collect::>(), @@ -1104,28 +1116,33 @@ fn test_mark_cache_anchors(cx: &mut App) { "Last message should not be an anchor on speculative request." ); - context - .update(cx, |context, cx| { - context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx) + text_thread + .update(cx, |text_thread, cx| { + text_thread.insert_message_after( + message_3.id, + Role::Assistant, + MessageStatus::Pending, + cx, + ) }) .unwrap(); - context.update(cx, |context, cx| { - context.mark_cache_anchors(cache_configuration, false, cx) + text_thread.update(cx, |text_thread, cx| { + text_thread.mark_cache_anchors(cache_configuration, false, cx) }); assert_eq!( - messages_cache(&context, cx) + messages_cache(&text_thread, cx) .iter() .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .collect::>(), vec![false, true, true, false], "Most recent message should also be cached if not a speculative request." ); - context.update(cx, |context, cx| { - context.update_cache_status_for_completion(cx) + text_thread.update(cx, |text_thread, cx| { + text_thread.update_cache_status_for_completion(cx) }); assert_eq!( - messages_cache(&context, cx) + messages_cache(&text_thread, cx) .iter() .map(|(_, cache)| cache .as_ref() @@ -1141,11 +1158,11 @@ fn test_mark_cache_anchors(cx: &mut App) { ); buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx)); - context.update(cx, |context, cx| { - context.mark_cache_anchors(cache_configuration, false, cx) + text_thread.update(cx, |text_thread, cx| { + text_thread.mark_cache_anchors(cache_configuration, false, cx) }); assert_eq!( - messages_cache(&context, cx) + messages_cache(&text_thread, cx) .iter() .map(|(_, cache)| cache .as_ref() @@ -1160,11 +1177,11 @@ fn test_mark_cache_anchors(cx: &mut App) { "Modifying a message should invalidate it's cache but leave previous messages." ); buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx)); - context.update(cx, |context, cx| { - context.mark_cache_anchors(cache_configuration, false, cx) + text_thread.update(cx, |text_thread, cx| { + text_thread.mark_cache_anchors(cache_configuration, false, cx) }); assert_eq!( - messages_cache(&context, cx) + messages_cache(&text_thread, cx) .iter() .map(|(_, cache)| cache .as_ref() @@ -1182,31 +1199,36 @@ fn test_mark_cache_anchors(cx: &mut App) { #[gpui::test] async fn test_summarization(cx: &mut TestAppContext) { - let (context, fake_model) = setup_context_editor_with_fake_model(cx); + let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx); // Initial state should be pending - context.read_with(cx, |context, _| { - assert!(matches!(context.summary(), ContextSummary::Pending)); - assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + text_thread.read_with(cx, |text_thread, _| { + assert!(matches!(text_thread.summary(), TextThreadSummary::Pending)); + assert_eq!( + text_thread.summary().or_default(), + TextThreadSummary::DEFAULT + ); }); - let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); - context.update(cx, |context, cx| { + let message_1 = text_thread.read_with(cx, |text_thread, _cx| { + text_thread.message_anchors[0].clone() + }); + text_thread.update(cx, |context, cx| { context .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) .unwrap(); }); // Send a message - context.update(cx, |context, cx| { - context.assist(cx); + text_thread.update(cx, |text_thread, cx| { + text_thread.assist(cx); }); simulate_successful_response(&fake_model, cx); // Should start generating summary when there are >= 2 messages - context.read_with(cx, |context, _| { - assert!(!context.summary().content().unwrap().done); + text_thread.read_with(cx, |text_thread, _| { + assert!(!text_thread.summary().content().unwrap().done); }); cx.run_until_parked(); @@ -1216,61 +1238,61 @@ async fn test_summarization(cx: &mut TestAppContext) { cx.run_until_parked(); // Summary should be set - context.read_with(cx, |context, _| { - assert_eq!(context.summary().or_default(), "Brief Introduction"); + text_thread.read_with(cx, |text_thread, _| { + assert_eq!(text_thread.summary().or_default(), "Brief Introduction"); }); // We should be able to manually set a summary - context.update(cx, |context, cx| { - context.set_custom_summary("Brief Intro".into(), cx); + text_thread.update(cx, |text_thread, cx| { + text_thread.set_custom_summary("Brief Intro".into(), cx); }); - context.read_with(cx, |context, _| { - assert_eq!(context.summary().or_default(), "Brief Intro"); + text_thread.read_with(cx, |text_thread, _| { + assert_eq!(text_thread.summary().or_default(), "Brief Intro"); }); } #[gpui::test] async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { - let (context, fake_model) = setup_context_editor_with_fake_model(cx); + let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx); - test_summarize_error(&fake_model, &context, cx); + test_summarize_error(&fake_model, &text_thread, cx); // Now we should be able to set a summary - context.update(cx, |context, cx| { - context.set_custom_summary("Brief Intro".into(), cx); + text_thread.update(cx, |text_thread, cx| { + text_thread.set_custom_summary("Brief Intro".into(), cx); }); - context.read_with(cx, |context, _| { - assert_eq!(context.summary().or_default(), "Brief Intro"); + text_thread.read_with(cx, |text_thread, _| { + assert_eq!(text_thread.summary().or_default(), "Brief Intro"); }); } #[gpui::test] async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { - let (context, fake_model) = setup_context_editor_with_fake_model(cx); + let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx); - test_summarize_error(&fake_model, &context, cx); + test_summarize_error(&fake_model, &text_thread, cx); // Sending another message should not trigger another summarize request - context.update(cx, |context, cx| { - context.assist(cx); + text_thread.update(cx, |text_thread, cx| { + text_thread.assist(cx); }); simulate_successful_response(&fake_model, cx); - context.read_with(cx, |context, _| { + text_thread.read_with(cx, |text_thread, _| { // State is still Error, not Generating - assert!(matches!(context.summary(), ContextSummary::Error)); + assert!(matches!(text_thread.summary(), TextThreadSummary::Error)); }); // But the summarize request can be invoked manually - context.update(cx, |context, cx| { - context.summarize(true, cx); + text_thread.update(cx, |text_thread, cx| { + text_thread.summarize(true, cx); }); - context.read_with(cx, |context, _| { - assert!(!context.summary().content().unwrap().done); + text_thread.read_with(cx, |text_thread, _| { + assert!(!text_thread.summary().content().unwrap().done); }); cx.run_until_parked(); @@ -1278,32 +1300,34 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { fake_model.end_last_completion_stream(); cx.run_until_parked(); - context.read_with(cx, |context, _| { - assert_eq!(context.summary().or_default(), "A successful summary"); + text_thread.read_with(cx, |text_thread, _| { + assert_eq!(text_thread.summary().or_default(), "A successful summary"); }); } fn test_summarize_error( model: &Arc, - context: &Entity, + text_thread: &Entity, cx: &mut TestAppContext, ) { - let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); - context.update(cx, |context, cx| { - context + let message_1 = text_thread.read_with(cx, |text_thread, _cx| { + text_thread.message_anchors[0].clone() + }); + text_thread.update(cx, |text_thread, cx| { + text_thread .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) .unwrap(); }); // Send a message - context.update(cx, |context, cx| { - context.assist(cx); + text_thread.update(cx, |text_thread, cx| { + text_thread.assist(cx); }); simulate_successful_response(model, cx); - context.read_with(cx, |context, _| { - assert!(!context.summary().content().unwrap().done); + text_thread.read_with(cx, |text_thread, _| { + assert!(!text_thread.summary().content().unwrap().done); }); // Simulate summary request ending @@ -1312,15 +1336,18 @@ fn test_summarize_error( cx.run_until_parked(); // State is set to Error and default message - context.read_with(cx, |context, _| { - assert_eq!(*context.summary(), ContextSummary::Error); - assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + text_thread.read_with(cx, |text_thread, _| { + assert_eq!(*text_thread.summary(), TextThreadSummary::Error); + assert_eq!( + text_thread.summary().or_default(), + TextThreadSummary::DEFAULT + ); }); } fn setup_context_editor_with_fake_model( cx: &mut TestAppContext, -) -> (Entity, Arc) { +) -> (Entity, Arc) { let registry = Arc::new(LanguageRegistry::test(cx.executor())); let fake_provider = Arc::new(FakeLanguageModelProvider::default()); @@ -1340,7 +1367,7 @@ fn setup_context_editor_with_fake_model( let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let context = cx.new(|cx| { - AssistantContext::local( + TextThread::local( registry, None, None, @@ -1360,7 +1387,7 @@ fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestApp cx.run_until_parked(); } -fn messages(context: &Entity, cx: &App) -> Vec<(MessageId, Role, Range)> { +fn messages(context: &Entity, cx: &App) -> Vec<(MessageId, Role, Range)> { context .read(cx) .messages(cx) @@ -1369,7 +1396,7 @@ fn messages(context: &Entity, cx: &App) -> Vec<(MessageId, Rol } fn messages_cache( - context: &Entity, + context: &Entity, cx: &App, ) -> Vec<(MessageId, Option)> { context diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_text_thread/src/text_thread.rs similarity index 92% rename from crates/assistant_context/src/assistant_context.rs rename to crates/assistant_text_thread/src/text_thread.rs index 5a1fa707ff04ac3b0cd719c3d0a5e67dfeb3e625..9ad383cdfd43eed236268349e2ff97c34a0178c0 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_text_thread/src/text_thread.rs @@ -1,7 +1,3 @@ -#[cfg(test)] -mod assistant_context_tests; -mod context_store; - use agent_settings::{AgentSettings, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, bail}; use assistant_slash_command::{ @@ -9,7 +5,7 @@ use assistant_slash_command::{ SlashCommandResult, SlashCommandWorkingSet, }; use assistant_slash_commands::FileCommandMetadata; -use client::{self, Client, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry}; +use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry}; use clock::ReplicaId; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use collections::{HashMap, HashSet}; @@ -27,7 +23,7 @@ use language_model::{ report_assistant_event, }; use open_ai::Model as OpenAiModel; -use paths::contexts_dir; +use paths::text_threads_dir; use project::Project; use prompt_store::PromptBuilder; use serde::{Deserialize, Serialize}; @@ -48,16 +44,10 @@ use ui::IconName; use util::{ResultExt, TryFutureExt, post_inc}; use uuid::Uuid; -pub use crate::context_store::*; - -pub fn init(client: Arc, _: &mut App) { - context_store::init(&client.into()); -} - #[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub struct ContextId(String); +pub struct TextThreadId(String); -impl ContextId { +impl TextThreadId { pub fn new() -> Self { Self(Uuid::new_v4().to_string()) } @@ -130,7 +120,7 @@ impl MessageStatus { } #[derive(Clone, Debug)] -pub enum ContextOperation { +pub enum TextThreadOperation { InsertMessage { anchor: MessageAnchor, metadata: MessageMetadata, @@ -142,7 +132,7 @@ pub enum ContextOperation { version: clock::Global, }, UpdateSummary { - summary: ContextSummaryContent, + summary: TextThreadSummaryContent, version: clock::Global, }, SlashCommandStarted { @@ -170,7 +160,7 @@ pub enum ContextOperation { BufferOperation(language::Operation), } -impl ContextOperation { +impl TextThreadOperation { pub fn from_proto(op: proto::ContextOperation) -> Result { match op.variant.context("invalid variant")? { proto::context_operation::Variant::InsertMessage(insert) => { @@ -212,7 +202,7 @@ impl ContextOperation { version: language::proto::deserialize_version(&update.version), }), proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary { - summary: ContextSummaryContent { + summary: TextThreadSummaryContent { text: update.summary, done: update.done, timestamp: language::proto::deserialize_timestamp( @@ -453,7 +443,7 @@ impl ContextOperation { } #[derive(Debug, Clone)] -pub enum ContextEvent { +pub enum TextThreadEvent { ShowAssistError(SharedString), ShowPaymentRequiredError, MessagesEdited, @@ -476,24 +466,24 @@ pub enum ContextEvent { SlashCommandOutputSectionAdded { section: SlashCommandOutputSection, }, - Operation(ContextOperation), + Operation(TextThreadOperation), } #[derive(Clone, Debug, Eq, PartialEq)] -pub enum ContextSummary { +pub enum TextThreadSummary { Pending, - Content(ContextSummaryContent), + Content(TextThreadSummaryContent), Error, } #[derive(Clone, Debug, Eq, PartialEq)] -pub struct ContextSummaryContent { +pub struct TextThreadSummaryContent { pub text: String, pub done: bool, pub timestamp: clock::Lamport, } -impl ContextSummary { +impl TextThreadSummary { pub const DEFAULT: &str = "New Text Thread"; pub fn or_default(&self) -> SharedString { @@ -505,48 +495,48 @@ impl ContextSummary { .map_or_else(|| message.into(), |content| content.text.clone().into()) } - pub fn content(&self) -> Option<&ContextSummaryContent> { + pub fn content(&self) -> Option<&TextThreadSummaryContent> { match self { - ContextSummary::Content(content) => Some(content), - ContextSummary::Pending | ContextSummary::Error => None, + TextThreadSummary::Content(content) => Some(content), + TextThreadSummary::Pending | TextThreadSummary::Error => None, } } - fn content_as_mut(&mut self) -> Option<&mut ContextSummaryContent> { + fn content_as_mut(&mut self) -> Option<&mut TextThreadSummaryContent> { match self { - ContextSummary::Content(content) => Some(content), - ContextSummary::Pending | ContextSummary::Error => None, + TextThreadSummary::Content(content) => Some(content), + TextThreadSummary::Pending | TextThreadSummary::Error => None, } } - fn content_or_set_empty(&mut self) -> &mut ContextSummaryContent { + fn content_or_set_empty(&mut self) -> &mut TextThreadSummaryContent { match self { - ContextSummary::Content(content) => content, - ContextSummary::Pending | ContextSummary::Error => { - let content = ContextSummaryContent { + TextThreadSummary::Content(content) => content, + TextThreadSummary::Pending | TextThreadSummary::Error => { + let content = TextThreadSummaryContent { text: "".to_string(), done: false, timestamp: clock::Lamport::MIN, }; - *self = ContextSummary::Content(content); + *self = TextThreadSummary::Content(content); self.content_as_mut().unwrap() } } } pub fn is_pending(&self) -> bool { - matches!(self, ContextSummary::Pending) + matches!(self, TextThreadSummary::Pending) } fn timestamp(&self) -> Option { match self { - ContextSummary::Content(content) => Some(content.timestamp), - ContextSummary::Pending | ContextSummary::Error => None, + TextThreadSummary::Content(content) => Some(content.timestamp), + TextThreadSummary::Pending | TextThreadSummary::Error => None, } } } -impl PartialOrd for ContextSummary { +impl PartialOrd for TextThreadSummary { fn partial_cmp(&self, other: &Self) -> Option { self.timestamp().partial_cmp(&other.timestamp()) } @@ -668,27 +658,27 @@ struct PendingCompletion { #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct InvokedSlashCommandId(clock::Lamport); -pub struct AssistantContext { - id: ContextId, +pub struct TextThread { + id: TextThreadId, timestamp: clock::Lamport, version: clock::Global, - pending_ops: Vec, - operations: Vec, + pub(crate) pending_ops: Vec, + operations: Vec, buffer: Entity, - parsed_slash_commands: Vec, + pub(crate) parsed_slash_commands: Vec, invoked_slash_commands: HashMap, edits_since_last_parse: language::Subscription, slash_commands: Arc, - slash_command_output_sections: Vec>, + pub(crate) slash_command_output_sections: Vec>, thought_process_output_sections: Vec>, - message_anchors: Vec, + pub(crate) message_anchors: Vec, contents: Vec, - messages_metadata: HashMap, - summary: ContextSummary, + pub(crate) messages_metadata: HashMap, + summary: TextThreadSummary, summary_task: Task>, completion_count: usize, pending_completions: Vec, - token_count: Option, + pub(crate) token_count: Option, pending_token_count: Task>, pending_save: Task>, pending_cache_warming_task: Task>, @@ -711,9 +701,9 @@ impl ContextAnnotation for ParsedSlashCommand { } } -impl EventEmitter for AssistantContext {} +impl EventEmitter for TextThread {} -impl AssistantContext { +impl TextThread { pub fn local( language_registry: Arc, project: Option>, @@ -723,7 +713,7 @@ impl AssistantContext { cx: &mut Context, ) -> Self { Self::new( - ContextId::new(), + TextThreadId::new(), ReplicaId::default(), language::Capability::ReadWrite, language_registry, @@ -744,7 +734,7 @@ impl AssistantContext { } pub fn new( - id: ContextId, + id: TextThreadId, replica_id: ReplicaId, capability: language::Capability, language_registry: Arc, @@ -780,7 +770,7 @@ impl AssistantContext { slash_command_output_sections: Vec::new(), thought_process_output_sections: Vec::new(), edits_since_last_parse: edits_since_last_slash_command_parse, - summary: ContextSummary::Pending, + summary: TextThreadSummary::Pending, summary_task: Task::ready(None), completion_count: Default::default(), pending_completions: Default::default(), @@ -823,12 +813,12 @@ impl AssistantContext { this } - pub(crate) fn serialize(&self, cx: &App) -> SavedContext { + pub(crate) fn serialize(&self, cx: &App) -> SavedTextThread { let buffer = self.buffer.read(cx); - SavedContext { + SavedTextThread { id: Some(self.id.clone()), zed: "context".into(), - version: SavedContext::VERSION.into(), + version: SavedTextThread::VERSION.into(), text: buffer.text(), messages: self .messages(cx) @@ -876,7 +866,7 @@ impl AssistantContext { } pub fn deserialize( - saved_context: SavedContext, + saved_context: SavedTextThread, path: Arc, language_registry: Arc, prompt_builder: Arc, @@ -885,7 +875,7 @@ impl AssistantContext { telemetry: Option>, cx: &mut Context, ) -> Self { - let id = saved_context.id.clone().unwrap_or_else(ContextId::new); + let id = saved_context.id.clone().unwrap_or_else(TextThreadId::new); let mut this = Self::new( id, ReplicaId::default(), @@ -906,7 +896,7 @@ impl AssistantContext { this } - pub fn id(&self) -> &ContextId { + pub fn id(&self) -> &TextThreadId { &self.id } @@ -914,9 +904,9 @@ impl AssistantContext { self.timestamp.replica_id } - pub fn version(&self, cx: &App) -> ContextVersion { - ContextVersion { - context: self.version.clone(), + pub fn version(&self, cx: &App) -> TextThreadVersion { + TextThreadVersion { + text_thread: self.version.clone(), buffer: self.buffer.read(cx).version(), } } @@ -938,7 +928,7 @@ impl AssistantContext { pub fn serialize_ops( &self, - since: &ContextVersion, + since: &TextThreadVersion, cx: &App, ) -> Task> { let buffer_ops = self @@ -949,7 +939,7 @@ impl AssistantContext { let mut context_ops = self .operations .iter() - .filter(|op| !since.context.observed(op.timestamp())) + .filter(|op| !since.text_thread.observed(op.timestamp())) .cloned() .collect::>(); context_ops.extend(self.pending_ops.iter().cloned()); @@ -973,13 +963,13 @@ impl AssistantContext { pub fn apply_ops( &mut self, - ops: impl IntoIterator, + ops: impl IntoIterator, cx: &mut Context, ) { let mut buffer_ops = Vec::new(); for op in ops { match op { - ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op), + TextThreadOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op), op @ _ => self.pending_ops.push(op), } } @@ -988,7 +978,7 @@ impl AssistantContext { self.flush_ops(cx); } - fn flush_ops(&mut self, cx: &mut Context) { + fn flush_ops(&mut self, cx: &mut Context) { let mut changed_messages = HashSet::default(); let mut summary_generated = false; @@ -1001,7 +991,7 @@ impl AssistantContext { let timestamp = op.timestamp(); match op.clone() { - ContextOperation::InsertMessage { + TextThreadOperation::InsertMessage { anchor, metadata, .. } => { if self.messages_metadata.contains_key(&anchor.id) { @@ -1011,7 +1001,7 @@ impl AssistantContext { self.insert_message(anchor, metadata, cx); } } - ContextOperation::UpdateMessage { + TextThreadOperation::UpdateMessage { message_id, metadata: new_metadata, .. @@ -1022,7 +1012,7 @@ impl AssistantContext { changed_messages.insert(message_id); } } - ContextOperation::UpdateSummary { + TextThreadOperation::UpdateSummary { summary: new_summary, .. } => { @@ -1031,11 +1021,11 @@ impl AssistantContext { .timestamp() .is_none_or(|current_timestamp| new_summary.timestamp > current_timestamp) { - self.summary = ContextSummary::Content(new_summary); + self.summary = TextThreadSummary::Content(new_summary); summary_generated = true; } } - ContextOperation::SlashCommandStarted { + TextThreadOperation::SlashCommandStarted { id, output_range, name, @@ -1052,9 +1042,9 @@ impl AssistantContext { timestamp: id.0, }, ); - cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id }); + cx.emit(TextThreadEvent::InvokedSlashCommandChanged { command_id: id }); } - ContextOperation::SlashCommandOutputSectionAdded { section, .. } => { + TextThreadOperation::SlashCommandOutputSectionAdded { section, .. } => { let buffer = self.buffer.read(cx); if let Err(ix) = self .slash_command_output_sections @@ -1062,10 +1052,10 @@ impl AssistantContext { { self.slash_command_output_sections .insert(ix, section.clone()); - cx.emit(ContextEvent::SlashCommandOutputSectionAdded { section }); + cx.emit(TextThreadEvent::SlashCommandOutputSectionAdded { section }); } } - ContextOperation::ThoughtProcessOutputSectionAdded { section, .. } => { + TextThreadOperation::ThoughtProcessOutputSectionAdded { section, .. } => { let buffer = self.buffer.read(cx); if let Err(ix) = self .thought_process_output_sections @@ -1075,7 +1065,7 @@ impl AssistantContext { .insert(ix, section.clone()); } } - ContextOperation::SlashCommandFinished { + TextThreadOperation::SlashCommandFinished { id, error_message, timestamp, @@ -1094,10 +1084,10 @@ impl AssistantContext { slash_command.status = InvokedSlashCommandStatus::Finished; } } - cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id }); + cx.emit(TextThreadEvent::InvokedSlashCommandChanged { command_id: id }); } } - ContextOperation::BufferOperation(_) => unreachable!(), + TextThreadOperation::BufferOperation(_) => unreachable!(), } self.version.observe(timestamp); @@ -1107,43 +1097,43 @@ impl AssistantContext { if !changed_messages.is_empty() { self.message_roles_updated(changed_messages, cx); - cx.emit(ContextEvent::MessagesEdited); + cx.emit(TextThreadEvent::MessagesEdited); cx.notify(); } if summary_generated { - cx.emit(ContextEvent::SummaryChanged); - cx.emit(ContextEvent::SummaryGenerated); + cx.emit(TextThreadEvent::SummaryChanged); + cx.emit(TextThreadEvent::SummaryGenerated); cx.notify(); } } - fn can_apply_op(&self, op: &ContextOperation, cx: &App) -> bool { + fn can_apply_op(&self, op: &TextThreadOperation, cx: &App) -> bool { if !self.version.observed_all(op.version()) { return false; } match op { - ContextOperation::InsertMessage { anchor, .. } => self + TextThreadOperation::InsertMessage { anchor, .. } => self .buffer .read(cx) .version .observed(anchor.start.timestamp), - ContextOperation::UpdateMessage { message_id, .. } => { + TextThreadOperation::UpdateMessage { message_id, .. } => { self.messages_metadata.contains_key(message_id) } - ContextOperation::UpdateSummary { .. } => true, - ContextOperation::SlashCommandStarted { output_range, .. } => { + TextThreadOperation::UpdateSummary { .. } => true, + TextThreadOperation::SlashCommandStarted { output_range, .. } => { self.has_received_operations_for_anchor_range(output_range.clone(), cx) } - ContextOperation::SlashCommandOutputSectionAdded { section, .. } => { + TextThreadOperation::SlashCommandOutputSectionAdded { section, .. } => { self.has_received_operations_for_anchor_range(section.range.clone(), cx) } - ContextOperation::ThoughtProcessOutputSectionAdded { section, .. } => { + TextThreadOperation::ThoughtProcessOutputSectionAdded { section, .. } => { self.has_received_operations_for_anchor_range(section.range.clone(), cx) } - ContextOperation::SlashCommandFinished { .. } => true, - ContextOperation::BufferOperation(_) => { + TextThreadOperation::SlashCommandFinished { .. } => true, + TextThreadOperation::BufferOperation(_) => { panic!("buffer operations should always be applied") } } @@ -1164,9 +1154,9 @@ impl AssistantContext { observed_start && observed_end } - fn push_op(&mut self, op: ContextOperation, cx: &mut Context) { + fn push_op(&mut self, op: TextThreadOperation, cx: &mut Context) { self.operations.push(op.clone()); - cx.emit(ContextEvent::Operation(op)); + cx.emit(TextThreadEvent::Operation(op)); } pub fn buffer(&self) -> &Entity { @@ -1189,7 +1179,7 @@ impl AssistantContext { self.path.as_ref() } - pub fn summary(&self) -> &ContextSummary { + pub fn summary(&self) -> &TextThreadSummary { &self.summary } @@ -1250,13 +1240,13 @@ impl AssistantContext { language::BufferEvent::Operation { operation, is_local: true, - } => cx.emit(ContextEvent::Operation(ContextOperation::BufferOperation( - operation.clone(), - ))), + } => cx.emit(TextThreadEvent::Operation( + TextThreadOperation::BufferOperation(operation.clone()), + )), language::BufferEvent::Edited => { self.count_remaining_tokens(cx); self.reparse(cx); - cx.emit(ContextEvent::MessagesEdited); + cx.emit(TextThreadEvent::MessagesEdited); } _ => {} } @@ -1522,7 +1512,7 @@ impl AssistantContext { if !updated_parsed_slash_commands.is_empty() || !removed_parsed_slash_command_ranges.is_empty() { - cx.emit(ContextEvent::ParsedSlashCommandsUpdated { + cx.emit(TextThreadEvent::ParsedSlashCommandsUpdated { removed: removed_parsed_slash_command_ranges, updated: updated_parsed_slash_commands, }); @@ -1596,7 +1586,7 @@ impl AssistantContext { && (!command.range.start.is_valid(buffer) || !command.range.end.is_valid(buffer)) { command.status = InvokedSlashCommandStatus::Finished; - cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id }); + cx.emit(TextThreadEvent::InvokedSlashCommandChanged { command_id }); invalidated_command_ids.push(command_id); } } @@ -1605,7 +1595,7 @@ impl AssistantContext { let version = self.version.clone(); let timestamp = self.next_timestamp(); self.push_op( - ContextOperation::SlashCommandFinished { + TextThreadOperation::SlashCommandFinished { id: command_id, timestamp, error_message: None, @@ -1910,9 +1900,9 @@ impl AssistantContext { } } - cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id }); + cx.emit(TextThreadEvent::InvokedSlashCommandChanged { command_id }); this.push_op( - ContextOperation::SlashCommandFinished { + TextThreadOperation::SlashCommandFinished { id: command_id, timestamp, error_message, @@ -1935,9 +1925,9 @@ impl AssistantContext { timestamp: command_id.0, }, ); - cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id }); + cx.emit(TextThreadEvent::InvokedSlashCommandChanged { command_id }); self.push_op( - ContextOperation::SlashCommandStarted { + TextThreadOperation::SlashCommandStarted { id: command_id, output_range: command_range, name: name.to_string(), @@ -1961,13 +1951,13 @@ impl AssistantContext { }; self.slash_command_output_sections .insert(insertion_ix, section.clone()); - cx.emit(ContextEvent::SlashCommandOutputSectionAdded { + cx.emit(TextThreadEvent::SlashCommandOutputSectionAdded { section: section.clone(), }); let version = self.version.clone(); let timestamp = self.next_timestamp(); self.push_op( - ContextOperation::SlashCommandOutputSectionAdded { + TextThreadOperation::SlashCommandOutputSectionAdded { timestamp, section, version, @@ -1996,7 +1986,7 @@ impl AssistantContext { let version = self.version.clone(); let timestamp = self.next_timestamp(); self.push_op( - ContextOperation::ThoughtProcessOutputSectionAdded { + TextThreadOperation::ThoughtProcessOutputSectionAdded { timestamp, section, version, @@ -2115,7 +2105,7 @@ impl AssistantContext { let end = buffer .anchor_before(message_old_end_offset + chunk_len); context_event = Some( - ContextEvent::StartedThoughtProcess(start..end), + TextThreadEvent::StartedThoughtProcess(start..end), ); } else { // This ensures that all the thinking chunks are inserted inside the thinking tag @@ -2133,7 +2123,7 @@ impl AssistantContext { if let Some(start) = thought_process_stack.pop() { let end = buffer.anchor_before(message_old_end_offset); context_event = - Some(ContextEvent::EndedThoughtProcess(end)); + Some(TextThreadEvent::EndedThoughtProcess(end)); thought_process_output_section = Some(ThoughtProcessOutputSection { range: start..end, @@ -2163,7 +2153,7 @@ impl AssistantContext { cx.emit(context_event); } - cx.emit(ContextEvent::StreamedCompletion); + cx.emit(TextThreadEvent::StreamedCompletion); Some(()) })?; @@ -2184,7 +2174,7 @@ impl AssistantContext { this.update(cx, |this, cx| { let error_message = if let Some(error) = result.as_ref().err() { if error.is::() { - cx.emit(ContextEvent::ShowPaymentRequiredError); + cx.emit(TextThreadEvent::ShowPaymentRequiredError); this.update_metadata(assistant_message_id, cx, |metadata| { metadata.status = MessageStatus::Canceled; }); @@ -2195,7 +2185,7 @@ impl AssistantContext { .map(|err| err.to_string()) .collect::>() .join("\n"); - cx.emit(ContextEvent::ShowAssistError(SharedString::from( + cx.emit(TextThreadEvent::ShowAssistError(SharedString::from( error_message.clone(), ))); this.update_metadata(assistant_message_id, cx, |metadata| { @@ -2412,13 +2402,13 @@ impl AssistantContext { if let Some(metadata) = self.messages_metadata.get_mut(&id) { f(metadata); metadata.timestamp = timestamp; - let operation = ContextOperation::UpdateMessage { + let operation = TextThreadOperation::UpdateMessage { message_id: id, metadata: metadata.clone(), version, }; self.push_op(operation, cx); - cx.emit(ContextEvent::MessagesEdited); + cx.emit(TextThreadEvent::MessagesEdited); cx.notify(); } } @@ -2482,7 +2472,7 @@ impl AssistantContext { }; self.insert_message(anchor.clone(), metadata.clone(), cx); self.push_op( - ContextOperation::InsertMessage { + TextThreadOperation::InsertMessage { anchor: anchor.clone(), metadata, version, @@ -2505,7 +2495,7 @@ impl AssistantContext { Err(ix) => ix, }; self.contents.insert(insertion_ix, content); - cx.emit(ContextEvent::MessagesEdited); + cx.emit(TextThreadEvent::MessagesEdited); } pub fn contents<'a>(&'a self, cx: &'a App) -> impl 'a + Iterator { @@ -2580,7 +2570,7 @@ impl AssistantContext { }; self.insert_message(suffix.clone(), suffix_metadata.clone(), cx); self.push_op( - ContextOperation::InsertMessage { + TextThreadOperation::InsertMessage { anchor: suffix.clone(), metadata: suffix_metadata, version, @@ -2630,7 +2620,7 @@ impl AssistantContext { }; self.insert_message(selection.clone(), selection_metadata.clone(), cx); self.push_op( - ContextOperation::InsertMessage { + TextThreadOperation::InsertMessage { anchor: selection.clone(), metadata: selection_metadata, version, @@ -2642,7 +2632,7 @@ impl AssistantContext { }; if !edited_buffer { - cx.emit(ContextEvent::MessagesEdited); + cx.emit(TextThreadEvent::MessagesEdited); } new_messages } else { @@ -2656,7 +2646,7 @@ impl AssistantContext { new_metadata: MessageMetadata, cx: &mut Context, ) { - cx.emit(ContextEvent::MessagesEdited); + cx.emit(TextThreadEvent::MessagesEdited); self.messages_metadata.insert(new_anchor.id, new_metadata); @@ -2692,15 +2682,15 @@ impl AssistantContext { // If there is no summary, it is set with `done: false` so that "Loading Summary…" can // be displayed. match self.summary { - ContextSummary::Pending | ContextSummary::Error => { - self.summary = ContextSummary::Content(ContextSummaryContent { + TextThreadSummary::Pending | TextThreadSummary::Error => { + self.summary = TextThreadSummary::Content(TextThreadSummaryContent { text: "".to_string(), done: false, timestamp: clock::Lamport::MIN, }); replace_old = true; } - ContextSummary::Content(_) => {} + TextThreadSummary::Content(_) => {} } self.summary_task = cx.spawn(async move |this, cx| { @@ -2722,13 +2712,13 @@ impl AssistantContext { } summary.text.extend(lines.next()); summary.timestamp = timestamp; - let operation = ContextOperation::UpdateSummary { + let operation = TextThreadOperation::UpdateSummary { summary: summary.clone(), version, }; this.push_op(operation, cx); - cx.emit(ContextEvent::SummaryChanged); - cx.emit(ContextEvent::SummaryGenerated); + cx.emit(TextThreadEvent::SummaryChanged); + cx.emit(TextThreadEvent::SummaryGenerated); })?; // Stop if the LLM generated multiple lines. @@ -2752,13 +2742,13 @@ impl AssistantContext { if let Some(summary) = this.summary.content_as_mut() { summary.done = true; summary.timestamp = timestamp; - let operation = ContextOperation::UpdateSummary { + let operation = TextThreadOperation::UpdateSummary { summary: summary.clone(), version, }; this.push_op(operation, cx); - cx.emit(ContextEvent::SummaryChanged); - cx.emit(ContextEvent::SummaryGenerated); + cx.emit(TextThreadEvent::SummaryChanged); + cx.emit(TextThreadEvent::SummaryGenerated); } })?; @@ -2768,8 +2758,8 @@ impl AssistantContext { if let Err(err) = result { this.update(cx, |this, cx| { - this.summary = ContextSummary::Error; - cx.emit(ContextEvent::SummaryChanged); + this.summary = TextThreadSummary::Error; + cx.emit(TextThreadEvent::SummaryChanged); }) .log_err(); log::error!("Error generating context summary: {}", err); @@ -2875,7 +2865,7 @@ impl AssistantContext { &mut self, debounce: Option, fs: Arc, - cx: &mut Context, + cx: &mut Context, ) { if self.replica_id() != ReplicaId::default() { // Prevent saving a remote context for now. @@ -2906,7 +2896,7 @@ impl AssistantContext { let mut discriminant = 1; let mut new_path; loop { - new_path = contexts_dir().join(&format!( + new_path = text_threads_dir().join(&format!( "{} - {}.zed.json", summary.trim(), discriminant @@ -2918,7 +2908,7 @@ impl AssistantContext { } } - fs.create_dir(contexts_dir().as_ref()).await?; + fs.create_dir(text_threads_dir().as_ref()).await?; // rename before write ensures that only one file exists if let Some(old_path) = old_path.as_ref() @@ -2940,7 +2930,7 @@ impl AssistantContext { let new_path: Arc = new_path.clone().into(); move |this, cx| { this.path = Some(new_path.clone()); - cx.emit(ContextEvent::PathChanged { old_path, new_path }); + cx.emit(TextThreadEvent::PathChanged { old_path, new_path }); } }) .ok(); @@ -2959,7 +2949,7 @@ impl AssistantContext { summary.timestamp = timestamp; summary.done = true; summary.text = custom_summary; - cx.emit(ContextEvent::SummaryChanged); + cx.emit(TextThreadEvent::SummaryChanged); } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) { @@ -2979,23 +2969,23 @@ impl AssistantContext { } #[derive(Debug, Default)] -pub struct ContextVersion { - context: clock::Global, +pub struct TextThreadVersion { + text_thread: clock::Global, buffer: clock::Global, } -impl ContextVersion { +impl TextThreadVersion { pub fn from_proto(proto: &proto::ContextVersion) -> Self { Self { - context: language::proto::deserialize_version(&proto.context_version), + text_thread: language::proto::deserialize_version(&proto.context_version), buffer: language::proto::deserialize_version(&proto.buffer_version), } } - pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion { + pub fn to_proto(&self, context_id: TextThreadId) -> proto::ContextVersion { proto::ContextVersion { context_id: context_id.to_proto(), - context_version: language::proto::serialize_version(&self.context), + context_version: language::proto::serialize_version(&self.text_thread), buffer_version: language::proto::serialize_version(&self.buffer), } } @@ -3063,8 +3053,8 @@ pub struct SavedMessage { } #[derive(Serialize, Deserialize)] -pub struct SavedContext { - pub id: Option, +pub struct SavedTextThread { + pub id: Option, pub zed: String, pub version: String, pub text: String, @@ -3076,7 +3066,7 @@ pub struct SavedContext { pub thought_process_output_sections: Vec>, } -impl SavedContext { +impl SavedTextThread { pub const VERSION: &'static str = "0.4.0"; pub fn from_json(json: &str) -> Result { @@ -3086,9 +3076,9 @@ impl SavedContext { .context("version not found")? { serde_json::Value::String(version) => match version.as_str() { - SavedContext::VERSION => { - Ok(serde_json::from_value::(saved_context_json)?) - } + SavedTextThread::VERSION => Ok(serde_json::from_value::( + saved_context_json, + )?), SavedContextV0_3_0::VERSION => { let saved_context = serde_json::from_value::(saved_context_json)?; @@ -3113,8 +3103,8 @@ impl SavedContext { fn into_ops( self, buffer: &Entity, - cx: &mut Context, - ) -> Vec { + cx: &mut Context, + ) -> Vec { let mut operations = Vec::new(); let mut version = clock::Global::new(); let mut next_timestamp = clock::Lamport::new(ReplicaId::default()); @@ -3124,7 +3114,7 @@ impl SavedContext { if message.id == MessageId(clock::Lamport::MIN) { first_message_metadata = Some(message.metadata); } else { - operations.push(ContextOperation::InsertMessage { + operations.push(TextThreadOperation::InsertMessage { anchor: MessageAnchor { id: message.id, start: buffer.read(cx).anchor_before(message.start), @@ -3144,7 +3134,7 @@ impl SavedContext { if let Some(metadata) = first_message_metadata { let timestamp = next_timestamp.tick(); - operations.push(ContextOperation::UpdateMessage { + operations.push(TextThreadOperation::UpdateMessage { message_id: MessageId(clock::Lamport::MIN), metadata: MessageMetadata { role: metadata.role, @@ -3160,7 +3150,7 @@ impl SavedContext { let buffer = buffer.read(cx); for section in self.slash_command_output_sections { let timestamp = next_timestamp.tick(); - operations.push(ContextOperation::SlashCommandOutputSectionAdded { + operations.push(TextThreadOperation::SlashCommandOutputSectionAdded { timestamp, section: SlashCommandOutputSection { range: buffer.anchor_after(section.range.start) @@ -3177,7 +3167,7 @@ impl SavedContext { for section in self.thought_process_output_sections { let timestamp = next_timestamp.tick(); - operations.push(ContextOperation::ThoughtProcessOutputSectionAdded { + operations.push(TextThreadOperation::ThoughtProcessOutputSectionAdded { timestamp, section: ThoughtProcessOutputSection { range: buffer.anchor_after(section.range.start) @@ -3190,8 +3180,8 @@ impl SavedContext { } let timestamp = next_timestamp.tick(); - operations.push(ContextOperation::UpdateSummary { - summary: ContextSummaryContent { + operations.push(TextThreadOperation::UpdateSummary { + summary: TextThreadSummaryContent { text: self.summary, done: true, timestamp, @@ -3221,7 +3211,7 @@ struct SavedMessageMetadataPreV0_4_0 { #[derive(Serialize, Deserialize)] struct SavedContextV0_3_0 { - id: Option, + id: Option, zed: String, version: String, text: String, @@ -3234,11 +3224,11 @@ struct SavedContextV0_3_0 { impl SavedContextV0_3_0 { const VERSION: &'static str = "0.3.0"; - fn upgrade(self) -> SavedContext { - SavedContext { + fn upgrade(self) -> SavedTextThread { + SavedTextThread { id: self.id, zed: self.zed, - version: SavedContext::VERSION.into(), + version: SavedTextThread::VERSION.into(), text: self.text, messages: self .messages @@ -3270,7 +3260,7 @@ impl SavedContextV0_3_0 { #[derive(Serialize, Deserialize)] struct SavedContextV0_2_0 { - id: Option, + id: Option, zed: String, version: String, text: String, @@ -3282,7 +3272,7 @@ struct SavedContextV0_2_0 { impl SavedContextV0_2_0 { const VERSION: &'static str = "0.2.0"; - fn upgrade(self) -> SavedContext { + fn upgrade(self) -> SavedTextThread { SavedContextV0_3_0 { id: self.id, zed: self.zed, @@ -3299,7 +3289,7 @@ impl SavedContextV0_2_0 { #[derive(Serialize, Deserialize)] struct SavedContextV0_1_0 { - id: Option, + id: Option, zed: String, version: String, text: String, @@ -3313,7 +3303,7 @@ struct SavedContextV0_1_0 { impl SavedContextV0_1_0 { const VERSION: &'static str = "0.1.0"; - fn upgrade(self) -> SavedContext { + fn upgrade(self) -> SavedTextThread { SavedContextV0_2_0 { id: self.id, zed: self.zed, @@ -3328,7 +3318,7 @@ impl SavedContextV0_1_0 { } #[derive(Debug, Clone)] -pub struct SavedContextMetadata { +pub struct SavedTextThreadMetadata { pub title: SharedString, pub path: Arc, pub mtime: chrono::DateTime, diff --git a/crates/assistant_context/src/context_store.rs b/crates/assistant_text_thread/src/text_thread_store.rs similarity index 71% rename from crates/assistant_context/src/context_store.rs rename to crates/assistant_text_thread/src/text_thread_store.rs index 5fac44e31f4cc073af8fe6bbb57f75fc03b27f45..19c317baf0fa728c77faebc388b5e36008aa39b3 100644 --- a/crates/assistant_context/src/context_store.rs +++ b/crates/assistant_text_thread/src/text_thread_store.rs @@ -1,6 +1,6 @@ use crate::{ - AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, - SavedContextMetadata, + SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId, + TextThreadOperation, TextThreadVersion, }; use anyhow::{Context as _, Result}; use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet}; @@ -11,9 +11,9 @@ use context_server::ContextServerId; use fs::{Fs, RemoveOptions}; use futures::StreamExt; use fuzzy::StringMatchCandidate; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; +use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity}; use language::LanguageRegistry; -use paths::contexts_dir; +use paths::text_threads_dir; use project::{ Project, context_server_store::{ContextServerStatus, ContextServerStore}, @@ -27,24 +27,24 @@ use util::{ResultExt, TryFutureExt}; use zed_env_vars::ZED_STATELESS; pub(crate) fn init(client: &AnyProtoClient) { - client.add_entity_message_handler(ContextStore::handle_advertise_contexts); - client.add_entity_request_handler(ContextStore::handle_open_context); - client.add_entity_request_handler(ContextStore::handle_create_context); - client.add_entity_message_handler(ContextStore::handle_update_context); - client.add_entity_request_handler(ContextStore::handle_synchronize_contexts); + client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts); + client.add_entity_request_handler(TextThreadStore::handle_open_context); + client.add_entity_request_handler(TextThreadStore::handle_create_context); + client.add_entity_message_handler(TextThreadStore::handle_update_context); + client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts); } #[derive(Clone)] -pub struct RemoteContextMetadata { - pub id: ContextId, +pub struct RemoteTextThreadMetadata { + pub id: TextThreadId, pub summary: Option, } -pub struct ContextStore { - contexts: Vec, - contexts_metadata: Vec, +pub struct TextThreadStore { + text_threads: Vec, + text_threads_metadata: Vec, context_server_slash_command_ids: HashMap>, - host_contexts: Vec, + host_text_threads: Vec, fs: Arc, languages: Arc, slash_commands: Arc, @@ -58,34 +58,28 @@ pub struct ContextStore { prompt_builder: Arc, } -pub enum ContextStoreEvent { - ContextCreated(ContextId), +enum TextThreadHandle { + Weak(WeakEntity), + Strong(Entity), } -impl EventEmitter for ContextStore {} - -enum ContextHandle { - Weak(WeakEntity), - Strong(Entity), -} - -impl ContextHandle { - fn upgrade(&self) -> Option> { +impl TextThreadHandle { + fn upgrade(&self) -> Option> { match self { - ContextHandle::Weak(weak) => weak.upgrade(), - ContextHandle::Strong(strong) => Some(strong.clone()), + TextThreadHandle::Weak(weak) => weak.upgrade(), + TextThreadHandle::Strong(strong) => Some(strong.clone()), } } - fn downgrade(&self) -> WeakEntity { + fn downgrade(&self) -> WeakEntity { match self { - ContextHandle::Weak(weak) => weak.clone(), - ContextHandle::Strong(strong) => strong.downgrade(), + TextThreadHandle::Weak(weak) => weak.clone(), + TextThreadHandle::Strong(strong) => strong.downgrade(), } } } -impl ContextStore { +impl TextThreadStore { pub fn new( project: Entity, prompt_builder: Arc, @@ -97,14 +91,14 @@ impl ContextStore { let telemetry = project.read(cx).client().telemetry().clone(); cx.spawn(async move |cx| { const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100); - let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; + let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await; let this = cx.new(|cx: &mut Context| { let mut this = Self { - contexts: Vec::new(), - contexts_metadata: Vec::new(), + text_threads: Vec::new(), + text_threads_metadata: Vec::new(), context_server_slash_command_ids: HashMap::default(), - host_contexts: Vec::new(), + host_text_threads: Vec::new(), fs, languages, slash_commands, @@ -142,10 +136,10 @@ impl ContextStore { #[cfg(any(test, feature = "test-support"))] pub fn fake(project: Entity, cx: &mut Context) -> Self { Self { - contexts: Default::default(), - contexts_metadata: Default::default(), + text_threads: Default::default(), + text_threads_metadata: Default::default(), context_server_slash_command_ids: Default::default(), - host_contexts: Default::default(), + host_text_threads: Default::default(), fs: project.read(cx).fs().clone(), languages: project.read(cx).languages().clone(), slash_commands: Arc::default(), @@ -166,13 +160,13 @@ impl ContextStore { mut cx: AsyncApp, ) -> Result<()> { this.update(&mut cx, |this, cx| { - this.host_contexts = envelope + this.host_text_threads = envelope .payload .contexts .into_iter() - .map(|context| RemoteContextMetadata { - id: ContextId::from_proto(context.context_id), - summary: context.summary, + .map(|text_thread| RemoteTextThreadMetadata { + id: TextThreadId::from_proto(text_thread.context_id), + summary: text_thread.summary, }) .collect(); cx.notify(); @@ -184,25 +178,25 @@ impl ContextStore { envelope: TypedEnvelope, mut cx: AsyncApp, ) -> Result { - let context_id = ContextId::from_proto(envelope.payload.context_id); + let context_id = TextThreadId::from_proto(envelope.payload.context_id); let operations = this.update(&mut cx, |this, cx| { anyhow::ensure!( !this.project.read(cx).is_via_collab(), "only the host contexts can be opened" ); - let context = this - .loaded_context_for_id(&context_id, cx) + let text_thread = this + .loaded_text_thread_for_id(&context_id, cx) .context("context not found")?; anyhow::ensure!( - context.read(cx).replica_id() == ReplicaId::default(), + text_thread.read(cx).replica_id() == ReplicaId::default(), "context must be opened via the host" ); anyhow::Ok( - context + text_thread .read(cx) - .serialize_ops(&ContextVersion::default(), cx), + .serialize_ops(&TextThreadVersion::default(), cx), ) })??; let operations = operations.await; @@ -222,15 +216,14 @@ impl ContextStore { "can only create contexts as the host" ); - let context = this.create(cx); - let context_id = context.read(cx).id().clone(); - cx.emit(ContextStoreEvent::ContextCreated(context_id.clone())); + let text_thread = this.create(cx); + let context_id = text_thread.read(cx).id().clone(); anyhow::Ok(( context_id, - context + text_thread .read(cx) - .serialize_ops(&ContextVersion::default(), cx), + .serialize_ops(&TextThreadVersion::default(), cx), )) })??; let operations = operations.await; @@ -246,11 +239,11 @@ impl ContextStore { mut cx: AsyncApp, ) -> Result<()> { this.update(&mut cx, |this, cx| { - let context_id = ContextId::from_proto(envelope.payload.context_id); - if let Some(context) = this.loaded_context_for_id(&context_id, cx) { + let context_id = TextThreadId::from_proto(envelope.payload.context_id); + if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) { let operation_proto = envelope.payload.operation.context("invalid operation")?; - let operation = ContextOperation::from_proto(operation_proto)?; - context.update(cx, |context, cx| context.apply_ops([operation], cx)); + let operation = TextThreadOperation::from_proto(operation_proto)?; + text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx)); } Ok(()) })? @@ -269,12 +262,12 @@ impl ContextStore { let mut local_versions = Vec::new(); for remote_version_proto in envelope.payload.contexts { - let remote_version = ContextVersion::from_proto(&remote_version_proto); - let context_id = ContextId::from_proto(remote_version_proto.context_id); - if let Some(context) = this.loaded_context_for_id(&context_id, cx) { - let context = context.read(cx); - let operations = context.serialize_ops(&remote_version, cx); - local_versions.push(context.version(cx).to_proto(context_id.clone())); + let remote_version = TextThreadVersion::from_proto(&remote_version_proto); + let context_id = TextThreadId::from_proto(remote_version_proto.context_id); + if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) { + let text_thread = text_thread.read(cx); + let operations = text_thread.serialize_ops(&remote_version, cx); + local_versions.push(text_thread.version(cx).to_proto(context_id.clone())); let client = this.client.clone(); let project_id = envelope.payload.project_id; cx.background_spawn(async move { @@ -308,9 +301,9 @@ impl ContextStore { } if is_shared { - self.contexts.retain_mut(|context| { - if let Some(strong_context) = context.upgrade() { - *context = ContextHandle::Strong(strong_context); + self.text_threads.retain_mut(|text_thread| { + if let Some(strong_context) = text_thread.upgrade() { + *text_thread = TextThreadHandle::Strong(strong_context); true } else { false @@ -345,12 +338,12 @@ impl ContextStore { self.synchronize_contexts(cx); } project::Event::DisconnectedFromHost => { - self.contexts.retain_mut(|context| { - if let Some(strong_context) = context.upgrade() { - *context = ContextHandle::Weak(context.downgrade()); - strong_context.update(cx, |context, cx| { - if context.replica_id() != ReplicaId::default() { - context.set_capability(language::Capability::ReadOnly, cx); + self.text_threads.retain_mut(|text_thread| { + if let Some(strong_context) = text_thread.upgrade() { + *text_thread = TextThreadHandle::Weak(text_thread.downgrade()); + strong_context.update(cx, |text_thread, cx| { + if text_thread.replica_id() != ReplicaId::default() { + text_thread.set_capability(language::Capability::ReadOnly, cx); } }); true @@ -358,20 +351,24 @@ impl ContextStore { false } }); - self.host_contexts.clear(); + self.host_text_threads.clear(); cx.notify(); } _ => {} } } - pub fn unordered_contexts(&self) -> impl Iterator { - self.contexts_metadata.iter() + pub fn unordered_text_threads(&self) -> impl Iterator { + self.text_threads_metadata.iter() } - pub fn create(&mut self, cx: &mut Context) -> Entity { + pub fn host_text_threads(&self) -> impl Iterator { + self.host_text_threads.iter() + } + + pub fn create(&mut self, cx: &mut Context) -> Entity { let context = cx.new(|cx| { - AssistantContext::local( + TextThread::local( self.languages.clone(), Some(self.project.clone()), Some(self.telemetry.clone()), @@ -380,14 +377,11 @@ impl ContextStore { cx, ) }); - self.register_context(&context, cx); + self.register_text_thread(&context, cx); context } - pub fn create_remote_context( - &mut self, - cx: &mut Context, - ) -> Task>> { + pub fn create_remote(&mut self, cx: &mut Context) -> Task>> { let project = self.project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow::anyhow!("project was not remote"))); @@ -403,10 +397,10 @@ impl ContextStore { let request = self.client.request(proto::CreateContext { project_id }); cx.spawn(async move |this, cx| { let response = request.await?; - let context_id = ContextId::from_proto(response.context_id); + let context_id = TextThreadId::from_proto(response.context_id); let context_proto = response.context.context("invalid context")?; - let context = cx.new(|cx| { - AssistantContext::new( + let text_thread = cx.new(|cx| { + TextThread::new( context_id.clone(), replica_id, capability, @@ -423,29 +417,29 @@ impl ContextStore { context_proto .operations .into_iter() - .map(ContextOperation::from_proto) + .map(TextThreadOperation::from_proto) .collect::>>() }) .await?; - context.update(cx, |context, cx| context.apply_ops(operations, cx))?; + text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?; this.update(cx, |this, cx| { - if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) { + if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) { existing_context } else { - this.register_context(&context, cx); + this.register_text_thread(&text_thread, cx); this.synchronize_contexts(cx); - context + text_thread } }) }) } - pub fn open_local_context( + pub fn open_local( &mut self, path: Arc, cx: &Context, - ) -> Task>> { - if let Some(existing_context) = self.loaded_context_for_path(&path, cx) { + ) -> Task>> { + if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) { return Task::ready(Ok(existing_context)); } @@ -457,7 +451,7 @@ impl ContextStore { let path = path.clone(); async move { let saved_context = fs.load(&path).await?; - SavedContext::from_json(&saved_context) + SavedTextThread::from_json(&saved_context) } }); let prompt_builder = self.prompt_builder.clone(); @@ -466,7 +460,7 @@ impl ContextStore { cx.spawn(async move |this, cx| { let saved_context = load.await?; let context = cx.new(|cx| { - AssistantContext::deserialize( + TextThread::deserialize( saved_context, path.clone(), languages, @@ -478,21 +472,17 @@ impl ContextStore { ) })?; this.update(cx, |this, cx| { - if let Some(existing_context) = this.loaded_context_for_path(&path, cx) { + if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) { existing_context } else { - this.register_context(&context, cx); + this.register_text_thread(&context, cx); context } }) }) } - pub fn delete_local_context( - &mut self, - path: Arc, - cx: &mut Context, - ) -> Task> { + pub fn delete_local(&mut self, path: Arc, cx: &mut Context) -> Task> { let fs = self.fs.clone(); cx.spawn(async move |this, cx| { @@ -506,57 +496,57 @@ impl ContextStore { .await?; this.update(cx, |this, cx| { - this.contexts.retain(|context| { - context + this.text_threads.retain(|text_thread| { + text_thread .upgrade() - .and_then(|context| context.read(cx).path()) + .and_then(|text_thread| text_thread.read(cx).path()) != Some(&path) }); - this.contexts_metadata - .retain(|context| context.path.as_ref() != path.as_ref()); + this.text_threads_metadata + .retain(|text_thread| text_thread.path.as_ref() != path.as_ref()); })?; Ok(()) }) } - fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option> { - self.contexts.iter().find_map(|context| { - let context = context.upgrade()?; - if context.read(cx).path().map(Arc::as_ref) == Some(path) { - Some(context) + fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option> { + self.text_threads.iter().find_map(|text_thread| { + let text_thread = text_thread.upgrade()?; + if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) { + Some(text_thread) } else { None } }) } - pub fn loaded_context_for_id( + pub fn loaded_text_thread_for_id( &self, - id: &ContextId, + id: &TextThreadId, cx: &App, - ) -> Option> { - self.contexts.iter().find_map(|context| { - let context = context.upgrade()?; - if context.read(cx).id() == id { - Some(context) + ) -> Option> { + self.text_threads.iter().find_map(|text_thread| { + let text_thread = text_thread.upgrade()?; + if text_thread.read(cx).id() == id { + Some(text_thread) } else { None } }) } - pub fn open_remote_context( + pub fn open_remote( &mut self, - context_id: ContextId, + text_thread_id: TextThreadId, cx: &mut Context, - ) -> Task>> { + ) -> Task>> { let project = self.project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow::anyhow!("project was not remote"))); }; - if let Some(context) = self.loaded_context_for_id(&context_id, cx) { + if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) { return Task::ready(Ok(context)); } @@ -567,16 +557,16 @@ impl ContextStore { let telemetry = self.telemetry.clone(); let request = self.client.request(proto::OpenContext { project_id, - context_id: context_id.to_proto(), + context_id: text_thread_id.to_proto(), }); let prompt_builder = self.prompt_builder.clone(); let slash_commands = self.slash_commands.clone(); cx.spawn(async move |this, cx| { let response = request.await?; let context_proto = response.context.context("invalid context")?; - let context = cx.new(|cx| { - AssistantContext::new( - context_id.clone(), + let text_thread = cx.new(|cx| { + TextThread::new( + text_thread_id.clone(), replica_id, capability, language_registry, @@ -592,38 +582,40 @@ impl ContextStore { context_proto .operations .into_iter() - .map(ContextOperation::from_proto) + .map(TextThreadOperation::from_proto) .collect::>>() }) .await?; - context.update(cx, |context, cx| context.apply_ops(operations, cx))?; + text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?; this.update(cx, |this, cx| { - if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) { + if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx) + { existing_context } else { - this.register_context(&context, cx); + this.register_text_thread(&text_thread, cx); this.synchronize_contexts(cx); - context + text_thread } }) }) } - fn register_context(&mut self, context: &Entity, cx: &mut Context) { + fn register_text_thread(&mut self, text_thread: &Entity, cx: &mut Context) { let handle = if self.project_is_shared { - ContextHandle::Strong(context.clone()) + TextThreadHandle::Strong(text_thread.clone()) } else { - ContextHandle::Weak(context.downgrade()) + TextThreadHandle::Weak(text_thread.downgrade()) }; - self.contexts.push(handle); + self.text_threads.push(handle); self.advertise_contexts(cx); - cx.subscribe(context, Self::handle_context_event).detach(); + cx.subscribe(text_thread, Self::handle_context_event) + .detach(); } fn handle_context_event( &mut self, - context: Entity, - event: &ContextEvent, + text_thread: Entity, + event: &TextThreadEvent, cx: &mut Context, ) { let Some(project_id) = self.project.read(cx).remote_id() else { @@ -631,12 +623,12 @@ impl ContextStore { }; match event { - ContextEvent::SummaryChanged => { + TextThreadEvent::SummaryChanged => { self.advertise_contexts(cx); } - ContextEvent::PathChanged { old_path, new_path } => { + TextThreadEvent::PathChanged { old_path, new_path } => { if let Some(old_path) = old_path.as_ref() { - for metadata in &mut self.contexts_metadata { + for metadata in &mut self.text_threads_metadata { if &metadata.path == old_path { metadata.path = new_path.clone(); break; @@ -644,8 +636,8 @@ impl ContextStore { } } } - ContextEvent::Operation(operation) => { - let context_id = context.read(cx).id().to_proto(); + TextThreadEvent::Operation(operation) => { + let context_id = text_thread.read(cx).id().to_proto(); let operation = operation.to_proto(); self.client .send(proto::UpdateContext { @@ -670,15 +662,15 @@ impl ContextStore { } let contexts = self - .contexts + .text_threads .iter() .rev() - .filter_map(|context| { - let context = context.upgrade()?.read(cx); - if context.replica_id() == ReplicaId::default() { + .filter_map(|text_thread| { + let text_thread = text_thread.upgrade()?.read(cx); + if text_thread.replica_id() == ReplicaId::default() { Some(proto::ContextMetadata { - context_id: context.id().to_proto(), - summary: context + context_id: text_thread.id().to_proto(), + summary: text_thread .summary() .content() .map(|summary| summary.text.clone()), @@ -701,13 +693,13 @@ impl ContextStore { return; }; - let contexts = self - .contexts + let text_threads = self + .text_threads .iter() - .filter_map(|context| { - let context = context.upgrade()?.read(cx); - if context.replica_id() != ReplicaId::default() { - Some(context.version(cx).to_proto(context.id().clone())) + .filter_map(|text_thread| { + let text_thread = text_thread.upgrade()?.read(cx); + if text_thread.replica_id() != ReplicaId::default() { + Some(text_thread.version(cx).to_proto(text_thread.id().clone())) } else { None } @@ -717,26 +709,27 @@ impl ContextStore { let client = self.client.clone(); let request = self.client.request(proto::SynchronizeContexts { project_id, - contexts, + contexts: text_threads, }); cx.spawn(async move |this, cx| { let response = request.await?; - let mut context_ids = Vec::new(); + let mut text_thread_ids = Vec::new(); let mut operations = Vec::new(); this.read_with(cx, |this, cx| { for context_version_proto in response.contexts { - let context_version = ContextVersion::from_proto(&context_version_proto); - let context_id = ContextId::from_proto(context_version_proto.context_id); - if let Some(context) = this.loaded_context_for_id(&context_id, cx) { - context_ids.push(context_id); - operations.push(context.read(cx).serialize_ops(&context_version, cx)); + let text_thread_version = TextThreadVersion::from_proto(&context_version_proto); + let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id); + if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) { + text_thread_ids.push(text_thread_id); + operations + .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx)); } } })?; let operations = futures::future::join_all(operations).await; - for (context_id, operations) in context_ids.into_iter().zip(operations) { + for (context_id, operations) in text_thread_ids.into_iter().zip(operations) { for operation in operations { client.send(proto::UpdateContext { project_id, @@ -751,8 +744,8 @@ impl ContextStore { .detach_and_log_err(cx); } - pub fn search(&self, query: String, cx: &App) -> Task> { - let metadata = self.contexts_metadata.clone(); + pub fn search(&self, query: String, cx: &App) -> Task> { + let metadata = self.text_threads_metadata.clone(); let executor = cx.background_executor().clone(); cx.background_spawn(async move { if query.is_empty() { @@ -782,20 +775,16 @@ impl ContextStore { }) } - pub fn host_contexts(&self) -> &[RemoteContextMetadata] { - &self.host_contexts - } - fn reload(&mut self, cx: &mut Context) -> Task> { let fs = self.fs.clone(); cx.spawn(async move |this, cx| { if *ZED_STATELESS { return Ok(()); } - fs.create_dir(contexts_dir()).await?; + fs.create_dir(text_threads_dir()).await?; - let mut paths = fs.read_dir(contexts_dir()).await?; - let mut contexts = Vec::::new(); + let mut paths = fs.read_dir(text_threads_dir()).await?; + let mut contexts = Vec::::new(); while let Some(path) = paths.next().await { let path = path?; if path.extension() != Some(OsStr::new("json")) { @@ -821,7 +810,7 @@ impl ContextStore { .lines() .next() { - contexts.push(SavedContextMetadata { + contexts.push(SavedTextThreadMetadata { title: title.to_string().into(), path: path.into(), mtime: metadata.mtime.timestamp_for_user().into(), @@ -829,10 +818,10 @@ impl ContextStore { } } } - contexts.sort_unstable_by_key(|context| Reverse(context.mtime)); + contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime)); this.update(cx, |this, cx| { - this.contexts_metadata = contexts; + this.text_threads_metadata = contexts; cx.notify(); }) }) diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 52dbe46107501325e305a7e8e6e7bd9bb483affb..c8467da7954b195c0eef09ce1bed8361d7fa2c7b 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -73,7 +73,7 @@ uuid.workspace = true [dev-dependencies] agent_settings.workspace = true -assistant_context.workspace = true +assistant_text_thread.workspace = true assistant_slash_command.workspace = true async-trait.workspace = true audio.workspace = true diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index cc2c01b7857a1efefd88b47d2ea199fc571051ea..4fa32b6c9ba55e6962547510f52251f16fc9be81 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6,8 +6,8 @@ use crate::{ }, }; use anyhow::{Result, anyhow}; -use assistant_context::ContextStore; use assistant_slash_command::SlashCommandWorkingSet; +use assistant_text_thread::TextThreadStore; use buffer_diff::{DiffHunkSecondaryStatus, DiffHunkStatus, assert_hunks}; use call::{ActiveCall, ParticipantLocation, Room, room}; use client::{RECEIVE_TIMEOUT, User}; @@ -6877,9 +6877,9 @@ async fn test_context_collaboration_with_reconnect( }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context_store_a = cx_a + let text_thread_store_a = cx_a .update(|cx| { - ContextStore::new( + TextThreadStore::new( project_a.clone(), prompt_builder.clone(), Arc::new(SlashCommandWorkingSet::default()), @@ -6888,9 +6888,9 @@ async fn test_context_collaboration_with_reconnect( }) .await .unwrap(); - let context_store_b = cx_b + let text_thread_store_b = cx_b .update(|cx| { - ContextStore::new( + TextThreadStore::new( project_b.clone(), prompt_builder.clone(), Arc::new(SlashCommandWorkingSet::default()), @@ -6901,60 +6901,60 @@ async fn test_context_collaboration_with_reconnect( .unwrap(); // Client A creates a new chats. - let context_a = context_store_a.update(cx_a, |store, cx| store.create(cx)); + let text_thread_a = text_thread_store_a.update(cx_a, |store, cx| store.create(cx)); executor.run_until_parked(); // Client B retrieves host's contexts and joins one. - let context_b = context_store_b + let text_thread_b = text_thread_store_b .update(cx_b, |store, cx| { - let host_contexts = store.host_contexts().to_vec(); - assert_eq!(host_contexts.len(), 1); - store.open_remote_context(host_contexts[0].id.clone(), cx) + let host_text_threads = store.host_text_threads().collect::>(); + assert_eq!(host_text_threads.len(), 1); + store.open_remote(host_text_threads[0].id.clone(), cx) }) .await .unwrap(); // Host and guest make changes - context_a.update(cx_a, |context, cx| { - context.buffer().update(cx, |buffer, cx| { + text_thread_a.update(cx_a, |text_thread, cx| { + text_thread.buffer().update(cx, |buffer, cx| { buffer.edit([(0..0, "Host change\n")], None, cx) }) }); - context_b.update(cx_b, |context, cx| { - context.buffer().update(cx, |buffer, cx| { + text_thread_b.update(cx_b, |text_thread, cx| { + text_thread.buffer().update(cx, |buffer, cx| { buffer.edit([(0..0, "Guest change\n")], None, cx) }) }); executor.run_until_parked(); assert_eq!( - context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()), + text_thread_a.read_with(cx_a, |text_thread, cx| text_thread.buffer().read(cx).text()), "Guest change\nHost change\n" ); assert_eq!( - context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()), + text_thread_b.read_with(cx_b, |text_thread, cx| text_thread.buffer().read(cx).text()), "Guest change\nHost change\n" ); // Disconnect client A and make some changes while disconnected. server.disconnect_client(client_a.peer_id().unwrap()); server.forbid_connections(); - context_a.update(cx_a, |context, cx| { - context.buffer().update(cx, |buffer, cx| { + text_thread_a.update(cx_a, |text_thread, cx| { + text_thread.buffer().update(cx, |buffer, cx| { buffer.edit([(0..0, "Host offline change\n")], None, cx) }) }); - context_b.update(cx_b, |context, cx| { - context.buffer().update(cx, |buffer, cx| { + text_thread_b.update(cx_b, |text_thread, cx| { + text_thread.buffer().update(cx, |buffer, cx| { buffer.edit([(0..0, "Guest offline change\n")], None, cx) }) }); executor.run_until_parked(); assert_eq!( - context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()), + text_thread_a.read_with(cx_a, |text_thread, cx| text_thread.buffer().read(cx).text()), "Host offline change\nGuest change\nHost change\n" ); assert_eq!( - context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()), + text_thread_b.read_with(cx_b, |text_thread, cx| text_thread.buffer().read(cx).text()), "Guest offline change\nGuest change\nHost change\n" ); @@ -6962,11 +6962,11 @@ async fn test_context_collaboration_with_reconnect( server.allow_connections(); executor.advance_clock(RECEIVE_TIMEOUT); assert_eq!( - context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()), + text_thread_a.read_with(cx_a, |text_thread, cx| text_thread.buffer().read(cx).text()), "Guest offline change\nHost offline change\nGuest change\nHost change\n" ); assert_eq!( - context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()), + text_thread_b.read_with(cx_b, |text_thread, cx| text_thread.buffer().read(cx).text()), "Guest offline change\nHost offline change\nGuest change\nHost change\n" ); @@ -6974,8 +6974,8 @@ async fn test_context_collaboration_with_reconnect( server.forbid_connections(); server.disconnect_client(client_a.peer_id().unwrap()); executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - context_b.read_with(cx_b, |context, cx| { - assert!(context.buffer().read(cx).read_only()); + text_thread_b.read_with(cx_b, |text_thread, cx| { + assert!(text_thread.buffer().read(cx).read_only()); }); } diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 528253f0dc2e9d4dc8b88a7d8d8c2926be2b2652..fbff269494f3f1ae5fb48d124ad090e61a558f31 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -358,7 +358,7 @@ impl TestServer { settings::KeymapFile::load_asset_allow_partial_failure(os_keymap, cx).unwrap(), ); language_model::LanguageModelRegistry::test(cx); - assistant_context::init(client.clone(), cx); + assistant_text_thread::init(client.clone(), cx); agent_settings::init(cx); }); diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index bbb6ddb976312b7baca5a11ace863b4a3be8d2bc..207e1f3bb4324d17784b1d8df53ba4bfbc4adddb 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -288,7 +288,7 @@ pub fn snippets_dir() -> &'static PathBuf { /// Returns the path to the contexts directory. /// /// This is where the saved contexts from the Assistant are stored. -pub fn contexts_dir() -> &'static PathBuf { +pub fn text_threads_dir() -> &'static PathBuf { static CONTEXTS_DIR: OnceLock = OnceLock::new(); CONTEXTS_DIR.get_or_init(|| { if cfg!(target_os = "macos") {