Detailed changes
@@ -377,6 +377,7 @@ dependencies = [
"cargo_toml",
"chrono",
"client",
+ "clock",
"collections",
"command_palette_hooks",
"ctor",
@@ -419,6 +420,7 @@ dependencies = [
"telemetry_events",
"terminal",
"terminal_view",
+ "text",
"theme",
"tiktoken-rs",
"toml 0.8.10",
@@ -2405,6 +2407,7 @@ version = "0.1.0"
dependencies = [
"chrono",
"parking_lot",
+ "serde",
"smallvec",
]
@@ -2463,6 +2466,7 @@ version = "0.44.0"
dependencies = [
"anthropic",
"anyhow",
+ "assistant",
"async-trait",
"async-tungstenite",
"audio",
@@ -12,6 +12,14 @@ workspace = true
path = "src/assistant.rs"
doctest = false
+[features]
+test-support = [
+ "editor/test-support",
+ "language/test-support",
+ "project/test-support",
+ "text/test-support",
+]
+
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
@@ -21,6 +29,7 @@ breadcrumbs.workspace = true
cargo_toml.workspace = true
chrono.workspace = true
client.workspace = true
+clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
editor.workspace = true
@@ -72,7 +81,9 @@ picker.workspace = true
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
+language = { workspace = true, features = ["test-support"] }
log.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
+text = { workspace = true, features = ["test-support"] }
unindent.workspace = true
@@ -1,7 +1,8 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod completion_provider;
-mod context_store;
+mod context;
+pub mod context_store;
mod inline_assistant;
mod model_selector;
mod prompt_library;
@@ -16,8 +17,9 @@ use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaMo
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
-pub(crate) use completion_provider::*;
-pub(crate) use context_store::*;
+pub use completion_provider::*;
+pub use context::*;
+pub use context_store::*;
use fs::Fs;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
@@ -57,10 +59,14 @@ actions!(
]
);
-#[derive(
- Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
-)]
-struct MessageId(usize);
+#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+pub struct MessageId(clock::Lamport);
+
+impl MessageId {
+ pub fn as_u64(self) -> u64 {
+ self.0.as_u64()
+ }
+}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
@@ -71,8 +77,26 @@ pub enum Role {
}
impl Role {
- pub fn cycle(&mut self) {
- *self = match self {
+ pub fn from_proto(role: i32) -> Role {
+ match proto::LanguageModelRole::from_i32(role) {
+ Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
+ Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
+ Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
+ Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
+ None => Role::User,
+ }
+ }
+
+ pub fn to_proto(&self) -> proto::LanguageModelRole {
+ match self {
+ Role::User => proto::LanguageModelRole::LanguageModelUser,
+ Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+ Role::System => proto::LanguageModelRole::LanguageModelSystem,
+ }
+ }
+
+ pub fn cycle(self) -> Role {
+ match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
@@ -151,11 +175,7 @@ pub struct LanguageModelRequestMessage {
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
- role: match self.role {
- Role::User => proto::LanguageModelRole::LanguageModelUser,
- Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
- Role::System => proto::LanguageModelRole::LanguageModelSystem,
- } as i32,
+ role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
@@ -222,19 +242,48 @@ pub struct LanguageModelChoiceDelta {
pub finish_reason: Option<String>,
}
-#[derive(Clone, Debug, Serialize, Deserialize)]
-struct MessageMetadata {
- role: Role,
- status: MessageStatus,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-enum MessageStatus {
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub enum MessageStatus {
Pending,
Done,
Error(SharedString),
}
+impl MessageStatus {
+ pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
+ match status.variant {
+ Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
+ Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
+ Some(proto::context_message_status::Variant::Error(error)) => {
+ MessageStatus::Error(error.message.into())
+ }
+ None => MessageStatus::Pending,
+ }
+ }
+
+ pub fn to_proto(&self) -> proto::ContextMessageStatus {
+ match self {
+ MessageStatus::Pending => proto::ContextMessageStatus {
+ variant: Some(proto::context_message_status::Variant::Pending(
+ proto::context_message_status::Pending {},
+ )),
+ },
+ MessageStatus::Done => proto::ContextMessageStatus {
+ variant: Some(proto::context_message_status::Variant::Done(
+ proto::context_message_status::Done {},
+ )),
+ },
+ MessageStatus::Error(message) => proto::ContextMessageStatus {
+ variant: Some(proto::context_message_status::Variant::Error(
+ proto::context_message_status::Error {
+ message: message.to_string(),
+ },
+ )),
+ },
+ }
+ }
+}
+
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
@@ -287,6 +336,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
})
.detach();
+ context_store::init(&client);
prompt_library::init(cx);
completion_provider::init(client.clone(), cx);
assistant_slash_command::init(cx);
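
The hunks above give `Role` and `MessageStatus` explicit proto conversions (and make `MessageStatus` public), presumably so message state can travel over the wire for shared contexts. A minimal sketch of the round-trip behavior they imply; this snippet is illustrative and not part of the diff:

```rust
use assistant::{MessageStatus, Role};

#[test]
fn proto_round_trips() {
    // Role serializes to proto::LanguageModelRole and back; unknown wire
    // values fall back to Role::User, and `cycle` now returns a Role by value.
    let role = Role::Assistant;
    let wire = role.to_proto() as i32;
    assert_eq!(Role::from_proto(wire), Role::Assistant);
    assert_eq!(role.cycle(), Role::System);

    // MessageStatus round-trips losslessly, including the error text;
    // a status missing from the wire is treated as Pending.
    let status = MessageStatus::Error("rate limited".into());
    assert_eq!(MessageStatus::from_proto(status.to_proto()), status);
}
```
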
@@ -1,24 +1,23 @@
-use crate::slash_command::docs_command::{DocsSlashCommand, DocsSlashCommandArgs};
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings},
- humanize_token_count,
+ humanize_token_count, parse_next_edit_suggestion,
prompt_library::open_prompt_library,
search::*,
slash_command::{
- default_command::DefaultSlashCommand, SlashCommandCompletionProvider, SlashCommandLine,
- SlashCommandRegistry,
+ default_command::DefaultSlashCommand,
+ docs_command::{DocsSlashCommand, DocsSlashCommandArgs},
+ SlashCommandCompletionProvider, SlashCommandRegistry,
},
terminal_inline_assistant::TerminalInlineAssistant,
- ApplyEdit, Assist, CompletionProvider, ConfirmCommand, ContextStore, CycleMessageRole,
- DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistant, InsertIntoEditor,
- LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
- ModelSelector, QuoteSelection, ResetKey, Role, SavedContext, SavedContextMetadata,
- SavedMessage, Split, ToggleFocus, ToggleModelSelector,
+ ApplyEdit, Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId,
+ ContextStore, CycleMessageRole, DeployHistory, DeployPromptLibrary, EditSuggestion,
+ InlineAssist, InlineAssistant, InsertIntoEditor, MessageStatus, ModelSelector,
+ PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata,
+ ResetKey, Role, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
};
use anyhow::{anyhow, Result};
-use assistant_slash_command::{SlashCommand, SlashCommandOutput, SlashCommandOutputSection};
+use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use breadcrumbs::Breadcrumbs;
-use client::telemetry::Telemetry;
use collections::{BTreeSet, HashMap, HashSet};
use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
@@ -30,44 +29,33 @@ use editor::{
};
use editor::{display_map::CreaseId, FoldPlaceholder};
use fs::Fs;
-use futures::future::Shared;
-use futures::{FutureExt, StreamExt};
use gpui::{
div, percentage, point, Action, Animation, AnimationExt, AnyElement, AnyView, AppContext,
- AsyncAppContext, AsyncWindowContext, ClipboardItem, Context as _, DismissEvent, Empty,
- EventEmitter, FocusHandle, FocusableView, InteractiveElement, IntoElement, Model, ModelContext,
- ParentElement, Pixels, Render, SharedString, StatefulInteractiveElement, Styled, Subscription,
- Task, Transformation, UpdateGlobal, View, ViewContext, VisualContext, WeakView, WindowContext,
+ AsyncWindowContext, ClipboardItem, DismissEvent, Empty, EventEmitter, FocusHandle,
+ FocusableView, InteractiveElement, IntoElement, Model, ParentElement, Pixels, Render,
+ SharedString, StatefulInteractiveElement, Styled, Subscription, Task, Transformation,
+ UpdateGlobal, View, ViewContext, VisualContext, WeakView, WindowContext,
};
use indexed_docs::IndexedDocsStore;
use language::{
- language_settings::SoftWrap, AnchorRangeExt as _, AutoindentMode, Buffer, LanguageRegistry,
- LspAdapterDelegate, OffsetRangeExt as _, Point, ToOffset as _,
+ language_settings::SoftWrap, AutoindentMode, Buffer, LanguageRegistry, LspAdapterDelegate,
+ OffsetRangeExt as _, Point, ToOffset,
};
use multi_buffer::MultiBufferRow;
-use paths::contexts_dir;
use picker::{Picker, PickerDelegate};
use project::{Project, ProjectLspAdapterDelegate, ProjectTransaction};
use search::{buffer_search::DivRegistrar, BufferSearchBar};
use settings::Settings;
-use std::{
- cmp::{self, Ordering},
- fmt::Write,
- iter,
- ops::Range,
- path::PathBuf,
- sync::Arc,
- time::{Duration, Instant},
-};
-use telemetry_events::AssistantKind;
+use std::{cmp, fmt::Write, ops::Range, path::PathBuf, sync::Arc, time::Duration};
use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
use theme::ThemeSettings;
use ui::{
- prelude::*, ButtonLike, ContextMenu, Disclosure, ElevationIndex, KeyBinding, ListItem,
+ prelude::*,
+ utils::{format_distance_from_now, DateTimeType},
+ Avatar, AvatarShape, ButtonLike, ContextMenu, Disclosure, ElevationIndex, KeyBinding, ListItem,
ListItemSpacing, PopoverMenu, PopoverMenuHandle, Tooltip,
};
-use util::{post_inc, ResultExt, TryFutureExt};
-use uuid::Uuid;
+use util::ResultExt;
use workspace::{
dock::{DockPosition, Panel, PanelEvent},
item::{BreadcrumbText, Item, ItemHandle},
@@ -106,24 +94,30 @@ pub struct AssistantPanel {
workspace: WeakView<Workspace>,
width: Option<Pixels>,
height: Option<Pixels>,
+ project: Model<Project>,
context_store: Model<ContextStore>,
languages: Arc<LanguageRegistry>,
- slash_commands: Arc<SlashCommandRegistry>,
fs: Arc<dyn Fs>,
- telemetry: Arc<Telemetry>,
subscriptions: Vec<Subscription>,
authentication_prompt: Option<AnyView>,
model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
}
+#[derive(Clone)]
+enum ContextMetadata {
+ Remote(RemoteContextMetadata),
+ Saved(SavedContextMetadata),
+}
+
struct SavedContextPickerDelegate {
store: Model<ContextStore>,
- matches: Vec<SavedContextMetadata>,
+ project: Model<Project>,
+ matches: Vec<ContextMetadata>,
selected_index: usize,
}
enum SavedContextPickerEvent {
- Confirmed { path: PathBuf },
+ Confirmed(ContextMetadata),
}
enum InlineAssistTarget {
@@ -134,8 +128,9 @@ enum InlineAssistTarget {
impl EventEmitter<SavedContextPickerEvent> for Picker<SavedContextPickerDelegate> {}
impl SavedContextPickerDelegate {
- fn new(store: Model<ContextStore>) -> Self {
+ fn new(project: Model<Project>, store: Model<ContextStore>) -> Self {
Self {
+ project,
store,
matches: Vec::new(),
selected_index: 0,
@@ -167,7 +162,13 @@ impl PickerDelegate for SavedContextPickerDelegate {
cx.spawn(|this, mut cx| async move {
let matches = search.await;
this.update(&mut cx, |this, cx| {
- this.delegate.matches = matches;
+ let host_contexts = this.delegate.store.read(cx).host_contexts();
+ this.delegate.matches = host_contexts
+ .iter()
+ .cloned()
+ .map(ContextMetadata::Remote)
+ .chain(matches.into_iter().map(ContextMetadata::Saved))
+ .collect();
this.delegate.selected_index = 0;
cx.notify();
})
@@ -177,9 +178,7 @@ impl PickerDelegate for SavedContextPickerDelegate {
fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
if let Some(metadata) = self.matches.get(self.selected_index) {
- cx.emit(SavedContextPickerEvent::Confirmed {
- path: metadata.path.clone(),
- })
+ cx.emit(SavedContextPickerEvent::Confirmed(metadata.clone()));
}
}
@@ -189,26 +188,78 @@ impl PickerDelegate for SavedContextPickerDelegate {
&self,
ix: usize,
selected: bool,
- _cx: &mut ViewContext<Picker<Self>>,
+ cx: &mut ViewContext<Picker<Self>>,
) -> Option<Self::ListItem> {
let context = self.matches.get(ix)?;
+ let item = match context {
+ ContextMetadata::Remote(context) => {
+ let host_user = self.project.read(cx).host().and_then(|collaborator| {
+ self.project
+ .read(cx)
+ .user_store()
+ .read(cx)
+ .get_cached_user(collaborator.user_id)
+ });
+ div()
+ .flex()
+ .w_full()
+ .justify_between()
+ .gap_2()
+ .child(
+ h_flex().flex_1().overflow_x_hidden().child(
+ Label::new(context.summary.clone().unwrap_or("New Context".into()))
+ .size(LabelSize::Small),
+ ),
+ )
+ .child(
+ h_flex()
+ .gap_2()
+ .children(if let Some(host_user) = host_user {
+ vec![
+ Avatar::new(host_user.avatar_uri.clone())
+ .shape(AvatarShape::Circle)
+ .into_any_element(),
+ Label::new(format!("Shared by @{}", host_user.github_login))
+ .color(Color::Muted)
+ .size(LabelSize::Small)
+ .into_any_element(),
+ ]
+ } else {
+ vec![Label::new("Shared by host")
+ .color(Color::Muted)
+ .size(LabelSize::Small)
+ .into_any_element()]
+ }),
+ )
+ }
+ ContextMetadata::Saved(context) => div()
+ .flex()
+ .w_full()
+ .justify_between()
+ .gap_2()
+ .child(
+ h_flex()
+ .flex_1()
+ .child(Label::new(context.title.clone()).size(LabelSize::Small))
+ .overflow_x_hidden(),
+ )
+ .child(
+ Label::new(format_distance_from_now(
+ DateTimeType::Local(context.mtime),
+ false,
+ true,
+ true,
+ ))
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ ),
+ };
Some(
ListItem::new(ix)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.selected(selected)
- .child(
- div()
- .flex()
- .w_full()
- .gap_2()
- .child(
- Label::new(context.mtime.format("%F %I:%M%p").to_string())
- .color(Color::Muted)
- .size(LabelSize::Small),
- )
- .child(Label::new(context.title.clone()).size(LabelSize::Small)),
- ),
+ .child(item),
)
}
}
@@ -219,11 +270,14 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
- // TODO: deserialize state.
- let fs = workspace.update(&mut cx, |workspace, _| workspace.app_state().fs.clone())?;
- let context_store = cx.update(|cx| ContextStore::new(fs.clone(), cx))?.await?;
+ let context_store = workspace
+ .update(&mut cx, |workspace, cx| {
+ ContextStore::new(workspace.project().clone(), cx)
+ })?
+ .await?;
workspace.update(&mut cx, |workspace, cx| {
- cx.new_view(|cx| Self::new(workspace, context_store.clone(), cx))
+ // TODO: deserialize state.
+ cx.new_view(|cx| Self::new(workspace, context_store, cx))
})
})
}
@@ -308,11 +362,10 @@ impl AssistantPanel {
workspace: workspace.weak_handle(),
width: None,
height: None,
+ project: workspace.project().clone(),
context_store,
languages: workspace.app_state().languages.clone(),
- slash_commands: SlashCommandRegistry::global(cx),
fs: workspace.app_state().fs.clone(),
- telemetry: workspace.client().telemetry().clone(),
subscriptions,
authentication_prompt: None,
model_selector_menu_handle,
@@ -519,16 +572,22 @@ impl AssistantPanel {
}
fn new_context(&mut self, cx: &mut ViewContext<Self>) -> Option<View<ContextEditor>> {
+ let context = self.context_store.update(cx, |store, cx| store.create(cx));
let workspace = self.workspace.upgrade()?;
+ let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| {
+ make_lsp_adapter_delegate(workspace.project(), cx).log_err()
+ });
let editor = cx.new_view(|cx| {
- ContextEditor::new(
- self.languages.clone(),
- self.slash_commands.clone(),
+ let mut editor = ContextEditor::for_context(
+ context,
self.fs.clone(),
workspace,
+ lsp_adapter_delegate,
cx,
- )
+ );
+ editor.insert_default_prompt(cx);
+ editor
});
self.show_context(editor.clone(), cx);
@@ -577,7 +636,12 @@ impl AssistantPanel {
} else {
let assistant_panel = cx.view().downgrade();
let history = cx.new_view(|cx| {
- ContextHistory::new(self.context_store.clone(), assistant_panel, cx)
+ ContextHistory::new(
+ self.project.clone(),
+ self.context_store.clone(),
+ assistant_panel,
+ cx,
+ )
});
self.pane.update(cx, |pane, cx| {
pane.add_item(Box::new(history), true, true, None, cx);
@@ -610,10 +674,14 @@ impl AssistantPanel {
Some(self.active_context_editor(cx)?.read(cx).context.clone())
}
- fn open_context(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+ fn open_saved_context(
+ &mut self,
+ path: PathBuf,
+ cx: &mut ViewContext<Self>,
+ ) -> Task<Result<()>> {
let existing_context = self.pane.read(cx).items().find_map(|item| {
item.downcast::<ContextEditor>()
- .filter(|editor| editor.read(cx).context.read(cx).path.as_ref() == Some(&path))
+ .filter(|editor| editor.read(cx).context.read(cx).path() == Some(&path))
});
if let Some(existing_context) = existing_context {
return cx.spawn(|this, mut cx| async move {
@@ -621,12 +689,11 @@ impl AssistantPanel {
});
}
- let saved_context = self.context_store.read(cx).load(path.clone(), cx);
+ let context = self
+ .context_store
+ .update(cx, |store, cx| store.open_local_context(path.clone(), cx));
let fs = self.fs.clone();
let workspace = self.workspace.clone();
- let slash_commands = self.slash_commands.clone();
- let languages = self.languages.clone();
- let telemetry = self.telemetry.clone();
let lsp_adapter_delegate = workspace
.update(cx, |workspace, cx| {
@@ -636,17 +703,51 @@ impl AssistantPanel {
.flatten();
cx.spawn(|this, mut cx| async move {
- let saved_context = saved_context.await?;
- let context = Context::deserialize(
- saved_context,
- path,
- languages,
- slash_commands,
- Some(telemetry),
- &mut cx,
- )
- .await?;
+ let context = context.await?;
+ this.update(&mut cx, |this, cx| {
+ let workspace = workspace
+ .upgrade()
+ .ok_or_else(|| anyhow!("workspace dropped"))?;
+ let editor = cx.new_view(|cx| {
+ ContextEditor::for_context(context, fs, workspace, lsp_adapter_delegate, cx)
+ });
+ this.show_context(editor, cx);
+ anyhow::Ok(())
+ })??;
+ Ok(())
+ })
+ }
+ fn open_remote_context(
+ &mut self,
+ id: ContextId,
+ cx: &mut ViewContext<Self>,
+ ) -> Task<Result<()>> {
+ let existing_context = self.pane.read(cx).items().find_map(|item| {
+ item.downcast::<ContextEditor>()
+ .filter(|editor| *editor.read(cx).context.read(cx).id() == id)
+ });
+ if let Some(existing_context) = existing_context {
+ return cx.spawn(|this, mut cx| async move {
+ this.update(&mut cx, |this, cx| this.show_context(existing_context, cx))
+ });
+ }
+
+ let context = self
+ .context_store
+ .update(cx, |store, cx| store.open_remote_context(id, cx));
+ let fs = self.fs.clone();
+ let workspace = self.workspace.clone();
+
+ let lsp_adapter_delegate = workspace
+ .update(cx, |workspace, cx| {
+ make_lsp_adapter_delegate(workspace.project(), cx).log_err()
+ })
+ .log_err()
+ .flatten();
+
+ cx.spawn(|this, mut cx| async move {
+ let context = context.await?;
this.update(&mut cx, |this, cx| {
let workspace = workspace
.upgrade()
@@ -730,1272 +831,78 @@ impl Panel for AssistantPanel {
let dock = match position {
DockPosition::Left => AssistantDockPosition::Left,
DockPosition::Bottom => AssistantDockPosition::Bottom,
- DockPosition::Right => AssistantDockPosition::Right,
- };
- settings.set_dock(dock);
- });
- }
-
- fn size(&self, cx: &WindowContext) -> Pixels {
- let settings = AssistantSettings::get_global(cx);
- match self.position(cx) {
- DockPosition::Left | DockPosition::Right => {
- self.width.unwrap_or(settings.default_width)
- }
- DockPosition::Bottom => self.height.unwrap_or(settings.default_height),
- }
- }
-
- fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
- match self.position(cx) {
- DockPosition::Left | DockPosition::Right => self.width = size,
- DockPosition::Bottom => self.height = size,
- }
- cx.notify();
- }
-
- fn is_zoomed(&self, cx: &WindowContext) -> bool {
- self.pane.read(cx).is_zoomed()
- }
-
- fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
- self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
- }
-
- fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
- if active {
- let load_credentials = self.authenticate(cx);
- cx.spawn(|this, mut cx| async move {
- load_credentials.await?;
- this.update(&mut cx, |this, cx| {
- if this.is_authenticated(cx) && this.active_context_editor(cx).is_none() {
- this.new_context(cx);
- }
- })
- })
- .detach_and_log_err(cx);
- }
- }
-
- fn icon(&self, cx: &WindowContext) -> Option<IconName> {
- let settings = AssistantSettings::get_global(cx);
- if !settings.enabled || !settings.button {
- return None;
- }
-
- Some(IconName::ZedAssistant)
- }
-
- fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
- Some("Assistant Panel")
- }
-
- fn toggle_action(&self) -> Box<dyn Action> {
- Box::new(ToggleFocus)
- }
-}
-
-impl EventEmitter<PanelEvent> for AssistantPanel {}
-impl EventEmitter<AssistantPanelEvent> for AssistantPanel {}
-
-impl FocusableView for AssistantPanel {
- fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
- self.pane.focus_handle(cx)
- }
-}
-
-#[derive(Clone)]
-enum ContextEvent {
- MessagesEdited,
- SummaryChanged,
- EditSuggestionsChanged,
- StreamedCompletion,
- PendingSlashCommandsUpdated {
- removed: Vec<Range<language::Anchor>>,
- updated: Vec<PendingSlashCommand>,
- },
- SlashCommandFinished {
- output_range: Range<language::Anchor>,
- sections: Vec<SlashCommandOutputSection<language::Anchor>>,
- run_commands_in_output: bool,
- },
-}
-
-#[derive(Default)]
-struct Summary {
- text: String,
- done: bool,
-}
-
-pub struct Context {
- id: Option<String>,
- buffer: Model<Buffer>,
- edit_suggestions: Vec<EditSuggestion>,
- pending_slash_commands: Vec<PendingSlashCommand>,
- edits_since_last_slash_command_parse: language::Subscription,
- slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
- message_anchors: Vec<MessageAnchor>,
- messages_metadata: HashMap<MessageId, MessageMetadata>,
- next_message_id: MessageId,
- summary: Option<Summary>,
- pending_summary: Task<Option<()>>,
- completion_count: usize,
- pending_completions: Vec<PendingCompletion>,
- token_count: Option<usize>,
- pending_token_count: Task<Option<()>>,
- pending_edit_suggestion_parse: Option<Task<()>>,
- pending_save: Task<Result<()>>,
- path: Option<PathBuf>,
- _subscriptions: Vec<Subscription>,
- telemetry: Option<Arc<Telemetry>>,
- slash_command_registry: Arc<SlashCommandRegistry>,
- language_registry: Arc<LanguageRegistry>,
-}
-
-impl EventEmitter<ContextEvent> for Context {}
-
-impl Context {
- fn new(
- language_registry: Arc<LanguageRegistry>,
- slash_command_registry: Arc<SlashCommandRegistry>,
- telemetry: Option<Arc<Telemetry>>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
- let buffer = cx.new_model(|cx| {
- let mut buffer = Buffer::local("", cx);
- buffer.set_language_registry(language_registry.clone());
- buffer
- });
- let edits_since_last_slash_command_parse =
- buffer.update(cx, |buffer, _| buffer.subscribe());
- let mut this = Self {
- id: Some(Uuid::new_v4().to_string()),
- message_anchors: Default::default(),
- messages_metadata: Default::default(),
- next_message_id: Default::default(),
- edit_suggestions: Vec::new(),
- pending_slash_commands: Vec::new(),
- slash_command_output_sections: Vec::new(),
- edits_since_last_slash_command_parse,
- summary: None,
- pending_summary: Task::ready(None),
- completion_count: Default::default(),
- pending_completions: Default::default(),
- token_count: None,
- pending_token_count: Task::ready(None),
- pending_edit_suggestion_parse: None,
- _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
- pending_save: Task::ready(Ok(())),
- path: None,
- buffer,
- telemetry,
- language_registry,
- slash_command_registry,
- };
-
- let message = MessageAnchor {
- id: MessageId(post_inc(&mut this.next_message_id.0)),
- start: language::Anchor::MIN,
- };
- this.message_anchors.push(message.clone());
- this.messages_metadata.insert(
- message.id,
- MessageMetadata {
- role: Role::User,
- status: MessageStatus::Done,
- },
- );
-
- this.set_language(cx);
- this.count_remaining_tokens(cx);
- this
- }
-
- fn serialize(&self, cx: &AppContext) -> SavedContext {
- let buffer = self.buffer.read(cx);
- SavedContext {
- id: self.id.clone(),
- zed: "context".into(),
- version: SavedContext::VERSION.into(),
- text: buffer.text(),
- message_metadata: self.messages_metadata.clone(),
- messages: self
- .messages(cx)
- .map(|message| SavedMessage {
- id: message.id,
- start: message.offset_range.start,
- })
- .collect(),
- summary: self
- .summary
- .as_ref()
- .map(|summary| summary.text.clone())
- .unwrap_or_default(),
- slash_command_output_sections: self
- .slash_command_output_sections
- .iter()
- .filter_map(|section| {
- let range = section.range.to_offset(buffer);
- if section.range.start.is_valid(buffer) && !range.is_empty() {
- Some(SlashCommandOutputSection {
- range,
- icon: section.icon,
- label: section.label.clone(),
- })
- } else {
- None
- }
- })
- .collect(),
- }
- }
-
- #[allow(clippy::too_many_arguments)]
- async fn deserialize(
- saved_context: SavedContext,
- path: PathBuf,
- language_registry: Arc<LanguageRegistry>,
- slash_command_registry: Arc<SlashCommandRegistry>,
- telemetry: Option<Arc<Telemetry>>,
- cx: &mut AsyncAppContext,
- ) -> Result<Model<Self>> {
- let id = match saved_context.id {
- Some(id) => Some(id),
- None => Some(Uuid::new_v4().to_string()),
- };
-
- let markdown = language_registry.language_for_name("Markdown");
- let mut message_anchors = Vec::new();
- let mut next_message_id = MessageId(0);
- let buffer = cx.new_model(|cx| {
- let mut buffer = Buffer::local(saved_context.text, cx);
- for message in saved_context.messages {
- message_anchors.push(MessageAnchor {
- id: message.id,
- start: buffer.anchor_before(message.start),
- });
- next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
- }
- buffer.set_language_registry(language_registry.clone());
- cx.spawn(|buffer, mut cx| async move {
- let markdown = markdown.await?;
- buffer.update(&mut cx, |buffer: &mut Buffer, cx| {
- buffer.set_language(Some(markdown), cx)
- })?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- buffer
- })?;
-
- cx.new_model(move |cx| {
- let edits_since_last_slash_command_parse =
- buffer.update(cx, |buffer, _| buffer.subscribe());
- let mut this = Self {
- id,
- message_anchors,
- messages_metadata: saved_context.message_metadata,
- next_message_id,
- edit_suggestions: Vec::new(),
- pending_slash_commands: Vec::new(),
- slash_command_output_sections: saved_context
- .slash_command_output_sections
- .into_iter()
- .map(|section| {
- let buffer = buffer.read(cx);
- SlashCommandOutputSection {
- range: buffer.anchor_after(section.range.start)
- ..buffer.anchor_before(section.range.end),
- icon: section.icon,
- label: section.label,
- }
- })
- .collect(),
- edits_since_last_slash_command_parse,
- summary: Some(Summary {
- text: saved_context.summary,
- done: true,
- }),
- pending_summary: Task::ready(None),
- completion_count: Default::default(),
- pending_completions: Default::default(),
- token_count: None,
- pending_edit_suggestion_parse: None,
- pending_token_count: Task::ready(None),
- _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
- pending_save: Task::ready(Ok(())),
- path: Some(path),
- buffer,
- telemetry,
- language_registry,
- slash_command_registry,
- };
- this.set_language(cx);
- this.reparse_edit_suggestions(cx);
- this.count_remaining_tokens(cx);
- this
- })
- }
-
- fn set_language(&mut self, cx: &mut ModelContext<Self>) {
- let markdown = self.language_registry.language_for_name("Markdown");
- cx.spawn(|this, mut cx| async move {
- let markdown = markdown.await?;
- this.update(&mut cx, |this, cx| {
- this.buffer
- .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
- })
- })
- .detach_and_log_err(cx);
- }
-
- fn handle_buffer_event(
- &mut self,
- _: Model<Buffer>,
- event: &language::Event,
- cx: &mut ModelContext<Self>,
- ) {
- if *event == language::Event::Edited {
- self.count_remaining_tokens(cx);
- self.reparse_edit_suggestions(cx);
- self.reparse_slash_commands(cx);
- cx.emit(ContextEvent::MessagesEdited);
- }
- }
-
- pub(crate) fn token_count(&self) -> Option<usize> {
- self.token_count
- }
-
- pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
- let request = self.to_completion_request(cx);
- self.pending_token_count = cx.spawn(|this, mut cx| {
- async move {
- cx.background_executor()
- .timer(Duration::from_millis(200))
- .await;
-
- let token_count = cx
- .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
- .await?;
-
- this.update(&mut cx, |this, cx| {
- this.token_count = Some(token_count);
- cx.notify()
- })?;
- anyhow::Ok(())
- }
- .log_err()
- });
- }
-
- fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
- let buffer = self.buffer.read(cx);
- let mut row_ranges = self
- .edits_since_last_slash_command_parse
- .consume()
- .into_iter()
- .map(|edit| {
- let start_row = buffer.offset_to_point(edit.new.start).row;
- let end_row = buffer.offset_to_point(edit.new.end).row + 1;
- start_row..end_row
- })
- .peekable();
-
- let mut removed = Vec::new();
- let mut updated = Vec::new();
- while let Some(mut row_range) = row_ranges.next() {
- while let Some(next_row_range) = row_ranges.peek() {
- if row_range.end >= next_row_range.start {
- row_range.end = next_row_range.end;
- row_ranges.next();
- } else {
- break;
- }
- }
-
- let start = buffer.anchor_before(Point::new(row_range.start, 0));
- let end = buffer.anchor_after(Point::new(
- row_range.end - 1,
- buffer.line_len(row_range.end - 1),
- ));
-
- let old_range = self.pending_command_indices_for_range(start..end, cx);
-
- let mut new_commands = Vec::new();
- let mut lines = buffer.text_for_range(start..end).lines();
- let mut offset = lines.offset();
- while let Some(line) = lines.next() {
- if let Some(command_line) = SlashCommandLine::parse(line) {
- let name = &line[command_line.name.clone()];
- let argument = command_line.argument.as_ref().and_then(|argument| {
- (!argument.is_empty()).then_some(&line[argument.clone()])
- });
- if let Some(command) = self.slash_command_registry.command(name) {
- if !command.requires_argument() || argument.is_some() {
- let start_ix = offset + command_line.name.start - 1;
- let end_ix = offset
- + command_line
- .argument
- .map_or(command_line.name.end, |argument| argument.end);
- let source_range =
- buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
- let pending_command = PendingSlashCommand {
- name: name.to_string(),
- argument: argument.map(ToString::to_string),
- source_range,
- status: PendingSlashCommandStatus::Idle,
- };
- updated.push(pending_command.clone());
- new_commands.push(pending_command);
- }
- }
- }
-
- offset = lines.offset();
- }
-
- let removed_commands = self.pending_slash_commands.splice(old_range, new_commands);
- removed.extend(removed_commands.map(|command| command.source_range));
- }
-
- if !updated.is_empty() || !removed.is_empty() {
- cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated });
- }
- }
-
- fn reparse_edit_suggestions(&mut self, cx: &mut ModelContext<Self>) {
- self.pending_edit_suggestion_parse = Some(cx.spawn(|this, mut cx| async move {
- cx.background_executor()
- .timer(Duration::from_millis(200))
- .await;
-
- this.update(&mut cx, |this, cx| {
- this.reparse_edit_suggestions_in_range(0..this.buffer.read(cx).len(), cx);
- })
- .ok();
- }));
- }
-
- fn reparse_edit_suggestions_in_range(
- &mut self,
- range: Range<usize>,
- cx: &mut ModelContext<Self>,
- ) {
- self.buffer.update(cx, |buffer, _| {
- let range_start = buffer.anchor_before(range.start);
- let range_end = buffer.anchor_after(range.end);
- let start_ix = self
- .edit_suggestions
- .binary_search_by(|probe| {
- probe
- .source_range
- .end
- .cmp(&range_start, buffer)
- .then(Ordering::Greater)
- })
- .unwrap_err();
- let end_ix = self
- .edit_suggestions
- .binary_search_by(|probe| {
- probe
- .source_range
- .start
- .cmp(&range_end, buffer)
- .then(Ordering::Less)
- })
- .unwrap_err();
-
- let mut new_edit_suggestions = Vec::new();
- let mut message_lines = buffer.as_rope().chunks_in_range(range).lines();
- while let Some(suggestion) = parse_next_edit_suggestion(&mut message_lines) {
- let start_anchor = buffer.anchor_after(suggestion.outer_range.start);
- let end_anchor = buffer.anchor_before(suggestion.outer_range.end);
- new_edit_suggestions.push(EditSuggestion {
- source_range: start_anchor..end_anchor,
- full_path: suggestion.path,
- });
- }
- self.edit_suggestions
- .splice(start_ix..end_ix, new_edit_suggestions);
- });
- cx.emit(ContextEvent::EditSuggestionsChanged);
- cx.notify();
- }
-
- fn pending_command_for_position(
- &mut self,
- position: language::Anchor,
- cx: &mut ModelContext<Self>,
- ) -> Option<&mut PendingSlashCommand> {
- let buffer = self.buffer.read(cx);
- match self
- .pending_slash_commands
- .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer))
- {
- Ok(ix) => Some(&mut self.pending_slash_commands[ix]),
- Err(ix) => {
- let cmd = self.pending_slash_commands.get_mut(ix)?;
- if position.cmp(&cmd.source_range.start, buffer).is_ge()
- && position.cmp(&cmd.source_range.end, buffer).is_le()
- {
- Some(cmd)
- } else {
- None
- }
- }
- }
- }
-
- fn pending_commands_for_range(
- &self,
- range: Range<language::Anchor>,
- cx: &AppContext,
- ) -> &[PendingSlashCommand] {
- let range = self.pending_command_indices_for_range(range, cx);
- &self.pending_slash_commands[range]
- }
-
- fn pending_command_indices_for_range(
- &self,
- range: Range<language::Anchor>,
- cx: &AppContext,
- ) -> Range<usize> {
- let buffer = self.buffer.read(cx);
- let start_ix = match self
- .pending_slash_commands
- .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer))
- {
- Ok(ix) | Err(ix) => ix,
- };
- let end_ix = match self
- .pending_slash_commands
- .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer))
- {
- Ok(ix) => ix + 1,
- Err(ix) => ix,
- };
- start_ix..end_ix
- }
-
- fn insert_command_output(
- &mut self,
- command_range: Range<language::Anchor>,
- output: Task<Result<SlashCommandOutput>>,
- insert_trailing_newline: bool,
- cx: &mut ModelContext<Self>,
- ) {
- self.reparse_slash_commands(cx);
-
- let insert_output_task = cx.spawn(|this, mut cx| {
- let command_range = command_range.clone();
- async move {
- let output = output.await;
- this.update(&mut cx, |this, cx| match output {
- Ok(mut output) => {
- if insert_trailing_newline {
- output.text.push('\n');
- }
-
- let event = this.buffer.update(cx, |buffer, cx| {
- let start = command_range.start.to_offset(buffer);
- let old_end = command_range.end.to_offset(buffer);
- let new_end = start + output.text.len();
- buffer.edit([(start..old_end, output.text)], None, cx);
-
- let mut sections = output
- .sections
- .into_iter()
- .map(|section| SlashCommandOutputSection {
- range: buffer.anchor_after(start + section.range.start)
- ..buffer.anchor_before(start + section.range.end),
- icon: section.icon,
- label: section.label,
- })
- .collect::<Vec<_>>();
- sections.sort_by(|a, b| a.range.cmp(&b.range, buffer));
-
- this.slash_command_output_sections
- .extend(sections.iter().cloned());
- this.slash_command_output_sections
- .sort_by(|a, b| a.range.cmp(&b.range, buffer));
-
- ContextEvent::SlashCommandFinished {
- output_range: buffer.anchor_after(start)
- ..buffer.anchor_before(new_end),
- sections,
- run_commands_in_output: output.run_commands_in_text,
- }
- });
- cx.emit(event);
- }
- Err(error) => {
- if let Some(pending_command) =
- this.pending_command_for_position(command_range.start, cx)
- {
- pending_command.status =
- PendingSlashCommandStatus::Error(error.to_string());
- cx.emit(ContextEvent::PendingSlashCommandsUpdated {
- removed: vec![pending_command.source_range.clone()],
- updated: vec![pending_command.clone()],
- });
- }
- }
- })
- .ok();
- }
- });
-
- if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) {
- pending_command.status = PendingSlashCommandStatus::Running {
- _task: insert_output_task.shared(),
- };
- cx.emit(ContextEvent::PendingSlashCommandsUpdated {
- removed: vec![pending_command.source_range.clone()],
- updated: vec![pending_command.clone()],
- });
- }
- }
-
- fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
- self.count_remaining_tokens(cx);
- }
-
- fn assist(
- &mut self,
- selected_messages: HashSet<MessageId>,
- cx: &mut ModelContext<Self>,
- ) -> Vec<MessageAnchor> {
- let mut user_messages = Vec::new();
-
- let last_message_id = if let Some(last_message_id) =
- self.message_anchors.iter().rev().find_map(|message| {
- message
- .start
- .is_valid(self.buffer.read(cx))
- .then_some(message.id)
- }) {
- last_message_id
- } else {
- return Default::default();
- };
-
- let mut should_assist = false;
- for selected_message_id in selected_messages {
- let selected_message_role =
- if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
- metadata.role
- } else {
- continue;
- };
-
- if selected_message_role == Role::Assistant {
- if let Some(user_message) = self.insert_message_after(
- selected_message_id,
- Role::User,
- MessageStatus::Done,
- cx,
- ) {
- user_messages.push(user_message);
- }
- } else {
- should_assist = true;
- }
- }
-
- if should_assist {
- if !CompletionProvider::global(cx).is_authenticated() {
- log::info!("completion provider has no credentials");
- return Default::default();
- }
-
- let request = self.to_completion_request(cx);
- let response = CompletionProvider::global(cx).complete(request, cx);
- let assistant_message = self
- .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
- .unwrap();
-
- // Queue up the user's next reply.
- let user_message = self
- .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
- .unwrap();
- user_messages.push(user_message);
-
- let task = cx.spawn({
- |this, mut cx| async move {
- let response = response.await;
- let assistant_message_id = assistant_message.id;
- let mut response_latency = None;
- let stream_completion = async {
- let request_start = Instant::now();
- let mut messages = response.inner.await?;
-
- while let Some(message) = messages.next().await {
- if response_latency.is_none() {
- response_latency = Some(request_start.elapsed());
- }
- let text = message?;
-
- this.update(&mut cx, |this, cx| {
- let message_ix = this
- .message_anchors
- .iter()
- .position(|message| message.id == assistant_message_id)?;
- let message_range = this.buffer.update(cx, |buffer, cx| {
- let message_start_offset =
- this.message_anchors[message_ix].start.to_offset(buffer);
- let message_old_end_offset = this.message_anchors
- [message_ix + 1..]
- .iter()
- .find(|message| message.start.is_valid(buffer))
- .map_or(buffer.len(), |message| {
- message.start.to_offset(buffer).saturating_sub(1)
- });
- let message_new_end_offset =
- message_old_end_offset + text.len();
- buffer.edit(
- [(message_old_end_offset..message_old_end_offset, text)],
- None,
- cx,
- );
- message_start_offset..message_new_end_offset
- });
- this.reparse_edit_suggestions_in_range(message_range, cx);
- cx.emit(ContextEvent::StreamedCompletion);
-
- Some(())
- })?;
- smol::future::yield_now().await;
- }
-
- this.update(&mut cx, |this, cx| {
- this.pending_completions
- .retain(|completion| completion.id != this.completion_count);
- this.summarize(cx);
- })?;
-
- anyhow::Ok(())
- };
-
- let result = stream_completion.await;
-
- this.update(&mut cx, |this, cx| {
- if let Some(metadata) =
- this.messages_metadata.get_mut(&assistant_message.id)
- {
- let error_message = result
- .err()
- .map(|error| error.to_string().trim().to_string());
- if let Some(error_message) = error_message.as_ref() {
- metadata.status =
- MessageStatus::Error(SharedString::from(error_message.clone()));
- } else {
- metadata.status = MessageStatus::Done;
- }
-
- if let Some(telemetry) = this.telemetry.as_ref() {
- let model = CompletionProvider::global(cx).model();
- telemetry.report_assistant_event(
- this.id.clone(),
- AssistantKind::Panel,
- model.telemetry_id(),
- response_latency,
- error_message,
- );
- }
-
- cx.emit(ContextEvent::MessagesEdited);
- }
- })
- .ok();
- }
- });
-
- self.pending_completions.push(PendingCompletion {
- id: post_inc(&mut self.completion_count),
- _task: task,
- });
- }
-
- user_messages
- }
-
- pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
- let messages = self
- .messages(cx)
- .filter(|message| matches!(message.status, MessageStatus::Done))
- .map(|message| message.to_request_message(self.buffer.read(cx)));
-
- LanguageModelRequest {
- model: CompletionProvider::global(cx).model(),
- messages: messages.collect(),
- stop: vec![],
- temperature: 1.0,
- }
- }
-
- fn cancel_last_assist(&mut self) -> bool {
- self.pending_completions.pop().is_some()
- }
-
- fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
- for id in ids {
- if let Some(metadata) = self.messages_metadata.get_mut(&id) {
- metadata.role.cycle();
- cx.emit(ContextEvent::MessagesEdited);
- cx.notify();
- }
- }
- }
-
- fn insert_message_after(
- &mut self,
- message_id: MessageId,
- role: Role,
- status: MessageStatus,
- cx: &mut ModelContext<Self>,
- ) -> Option<MessageAnchor> {
- if let Some(prev_message_ix) = self
- .message_anchors
- .iter()
- .position(|message| message.id == message_id)
- {
- // Find the next valid message after the one we were given.
- let mut next_message_ix = prev_message_ix + 1;
- while let Some(next_message) = self.message_anchors.get(next_message_ix) {
- if next_message.start.is_valid(self.buffer.read(cx)) {
- break;
- }
- next_message_ix += 1;
- }
-
- let start = self.buffer.update(cx, |buffer, cx| {
- let offset = self
- .message_anchors
- .get(next_message_ix)
- .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
- buffer.edit([(offset..offset, "\n")], None, cx);
- buffer.anchor_before(offset + 1)
- });
- let message = MessageAnchor {
- id: MessageId(post_inc(&mut self.next_message_id.0)),
- start,
- };
- self.message_anchors
- .insert(next_message_ix, message.clone());
- self.messages_metadata
- .insert(message.id, MessageMetadata { role, status });
- cx.emit(ContextEvent::MessagesEdited);
- Some(message)
- } else {
- None
- }
- }
-
- fn split_message(
- &mut self,
- range: Range<usize>,
- cx: &mut ModelContext<Self>,
- ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
- let start_message = self.message_for_offset(range.start, cx);
- let end_message = self.message_for_offset(range.end, cx);
- if let Some((start_message, end_message)) = start_message.zip(end_message) {
- // Prevent splitting when range spans multiple messages.
- if start_message.id != end_message.id {
- return (None, None);
- }
-
- let message = start_message;
- let role = message.role;
- let mut edited_buffer = false;
-
- let mut suffix_start = None;
- if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
- {
- if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
- suffix_start = Some(range.end + 1);
- } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
- suffix_start = Some(range.end);
- }
- }
-
- let suffix = if let Some(suffix_start) = suffix_start {
- MessageAnchor {
- id: MessageId(post_inc(&mut self.next_message_id.0)),
- start: self.buffer.read(cx).anchor_before(suffix_start),
- }
- } else {
- self.buffer.update(cx, |buffer, cx| {
- buffer.edit([(range.end..range.end, "\n")], None, cx);
- });
- edited_buffer = true;
- MessageAnchor {
- id: MessageId(post_inc(&mut self.next_message_id.0)),
- start: self.buffer.read(cx).anchor_before(range.end + 1),
- }
- };
-
- self.message_anchors
- .insert(message.index_range.end + 1, suffix.clone());
- self.messages_metadata.insert(
- suffix.id,
- MessageMetadata {
- role,
- status: MessageStatus::Done,
- },
- );
-
- let new_messages =
- if range.start == range.end || range.start == message.offset_range.start {
- (None, Some(suffix))
- } else {
- let mut prefix_end = None;
- if range.start > message.offset_range.start
- && range.end < message.offset_range.end - 1
- {
- if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
- prefix_end = Some(range.start + 1);
- } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
- == Some('\n')
- {
- prefix_end = Some(range.start);
- }
- }
-
- let selection = if let Some(prefix_end) = prefix_end {
- cx.emit(ContextEvent::MessagesEdited);
- MessageAnchor {
- id: MessageId(post_inc(&mut self.next_message_id.0)),
- start: self.buffer.read(cx).anchor_before(prefix_end),
- }
- } else {
- self.buffer.update(cx, |buffer, cx| {
- buffer.edit([(range.start..range.start, "\n")], None, cx)
- });
- edited_buffer = true;
- MessageAnchor {
- id: MessageId(post_inc(&mut self.next_message_id.0)),
- start: self.buffer.read(cx).anchor_before(range.end + 1),
- }
- };
-
- self.message_anchors
- .insert(message.index_range.end + 1, selection.clone());
- self.messages_metadata.insert(
- selection.id,
- MessageMetadata {
- role,
- status: MessageStatus::Done,
- },
- );
- (Some(selection), Some(suffix))
- };
-
- if !edited_buffer {
- cx.emit(ContextEvent::MessagesEdited);
- }
- new_messages
- } else {
- (None, None)
- }
- }
-
- fn summarize(&mut self, cx: &mut ModelContext<Self>) {
- if self.message_anchors.len() >= 2 && self.summary.is_none() {
- if !CompletionProvider::global(cx).is_authenticated() {
- return;
- }
-
- let messages = self
- .messages(cx)
- .map(|message| message.to_request_message(self.buffer.read(cx)))
- .chain(Some(LanguageModelRequestMessage {
- role: Role::User,
- content: "Summarize the context into a short title without punctuation.".into(),
- }));
- let request = LanguageModelRequest {
- model: CompletionProvider::global(cx).model(),
- messages: messages.collect(),
- stop: vec![],
- temperature: 1.0,
- };
-
- let response = CompletionProvider::global(cx).complete(request, cx);
- self.pending_summary = cx.spawn(|this, mut cx| {
- async move {
- let response = response.await;
- let mut messages = response.inner.await?;
-
- while let Some(message) = messages.next().await {
- let text = message?;
- let mut lines = text.lines();
- this.update(&mut cx, |this, cx| {
- let summary = this.summary.get_or_insert(Default::default());
- summary.text.extend(lines.next());
- cx.emit(ContextEvent::SummaryChanged);
- })?;
-
- // Stop if the LLM generated multiple lines.
- if lines.next().is_some() {
- break;
- }
- }
-
- this.update(&mut cx, |this, cx| {
- if let Some(summary) = this.summary.as_mut() {
- summary.done = true;
- cx.emit(ContextEvent::SummaryChanged);
- }
- })?;
-
- anyhow::Ok(())
- }
- .log_err()
- });
- }
- }
-
- fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
- self.messages_for_offsets([offset], cx).pop()
- }
-
- fn messages_for_offsets(
- &self,
- offsets: impl IntoIterator<Item = usize>,
- cx: &AppContext,
- ) -> Vec<Message> {
- let mut result = Vec::new();
-
- let mut messages = self.messages(cx).peekable();
- let mut offsets = offsets.into_iter().peekable();
- let mut current_message = messages.next();
- while let Some(offset) = offsets.next() {
-            // Locate the message that contains the offset.
- while current_message.as_ref().map_or(false, |message| {
- !message.offset_range.contains(&offset) && messages.peek().is_some()
- }) {
- current_message = messages.next();
- }
- let Some(message) = current_message.as_ref() else {
- break;
+ DockPosition::Right => AssistantDockPosition::Right,
};
+ settings.set_dock(dock);
+ });
+ }
-            // Skip offsets that are in the same message.
- while offsets.peek().map_or(false, |offset| {
- message.offset_range.contains(offset) || messages.peek().is_none()
- }) {
- offsets.next();
+ fn size(&self, cx: &WindowContext) -> Pixels {
+ let settings = AssistantSettings::get_global(cx);
+ match self.position(cx) {
+ DockPosition::Left | DockPosition::Right => {
+ self.width.unwrap_or(settings.default_width)
}
-
- result.push(message.clone());
+ DockPosition::Bottom => self.height.unwrap_or(settings.default_height),
}
- result
}
- fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
- let buffer = self.buffer.read(cx);
- let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
- iter::from_fn(move || {
- if let Some((start_ix, message_anchor)) = message_anchors.next() {
- let metadata = self.messages_metadata.get(&message_anchor.id)?;
- let message_start = message_anchor.start.to_offset(buffer);
- let mut message_end = None;
- let mut end_ix = start_ix;
- while let Some((_, next_message)) = message_anchors.peek() {
- if next_message.start.is_valid(buffer) {
- message_end = Some(next_message.start);
- break;
- } else {
- end_ix += 1;
- message_anchors.next();
- }
- }
- let message_end = message_end
- .unwrap_or(language::Anchor::MAX)
- .to_offset(buffer);
-
- return Some(Message {
- index_range: start_ix..end_ix,
- offset_range: message_start..message_end,
- id: message_anchor.id,
- anchor: message_anchor.start,
- role: metadata.role,
- status: metadata.status.clone(),
- });
- }
- None
- })
+ fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
+ match self.position(cx) {
+ DockPosition::Left | DockPosition::Right => self.width = size,
+ DockPosition::Bottom => self.height = size,
+ }
+ cx.notify();
}
- fn save(
- &mut self,
- debounce: Option<Duration>,
- fs: Arc<dyn Fs>,
- cx: &mut ModelContext<Context>,
- ) {
- self.pending_save = cx.spawn(|this, mut cx| async move {
- if let Some(debounce) = debounce {
- cx.background_executor().timer(debounce).await;
- }
+ fn is_zoomed(&self, cx: &WindowContext) -> bool {
+ self.pane.read(cx).is_zoomed()
+ }
- let (old_path, summary) = this.read_with(&cx, |this, _| {
- let path = this.path.clone();
- let summary = if let Some(summary) = this.summary.as_ref() {
- if summary.done {
- Some(summary.text.clone())
- } else {
- None
- }
- } else {
- None
- };
- (path, summary)
- })?;
+ fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
+ self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
+ }
- if let Some(summary) = summary {
- let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
- let path = if let Some(old_path) = old_path {
- old_path
- } else {
- let mut discriminant = 1;
- let mut new_path;
- loop {
- new_path = contexts_dir().join(&format!(
- "{} - {}.zed.json",
- summary.trim(),
- discriminant
- ));
- if fs.is_file(&new_path).await {
- discriminant += 1;
- } else {
- break;
- }
+ fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
+ if active {
+ let load_credentials = self.authenticate(cx);
+ cx.spawn(|this, mut cx| async move {
+ load_credentials.await?;
+ this.update(&mut cx, |this, cx| {
+ if this.is_authenticated(cx) && this.active_context_editor(cx).is_none() {
+ this.new_context(cx);
}
- new_path
- };
-
- fs.create_dir(contexts_dir().as_ref()).await?;
- fs.atomic_write(path.clone(), serde_json::to_string(&context).unwrap())
- .await?;
- this.update(&mut cx, |this, _| this.path = Some(path))?;
- }
-
- Ok(())
- });
+ })
+ })
+ .detach_and_log_err(cx);
+ }
}
-}
-
-#[derive(Debug)]
-enum EditParsingState {
- None,
- InOldText {
- path: PathBuf,
- start_offset: usize,
- old_text_start_offset: usize,
- },
- InNewText {
- path: PathBuf,
- start_offset: usize,
- old_text_range: Range<usize>,
- new_text_start_offset: usize,
- },
-}
-#[derive(Clone, Debug, PartialEq)]
-struct EditSuggestion {
- source_range: Range<language::Anchor>,
- full_path: PathBuf,
-}
+ fn icon(&self, cx: &WindowContext) -> Option<IconName> {
+ let settings = AssistantSettings::get_global(cx);
+ if !settings.enabled || !settings.button {
+ return None;
+ }
-struct ParsedEditSuggestion {
- path: PathBuf,
- outer_range: Range<usize>,
- old_text_range: Range<usize>,
- new_text_range: Range<usize>,
-}
+ Some(IconName::ZedAssistant)
+ }
-fn parse_next_edit_suggestion(lines: &mut rope::Lines) -> Option<ParsedEditSuggestion> {
- let mut state = EditParsingState::None;
- loop {
- let offset = lines.offset();
- let message_line = lines.next()?;
- match state {
- EditParsingState::None => {
- if let Some(rest) = message_line.strip_prefix("```edit ") {
- let path = rest.trim();
- if !path.is_empty() {
- state = EditParsingState::InOldText {
- path: PathBuf::from(path),
- start_offset: offset,
- old_text_start_offset: lines.offset(),
- };
- }
- }
- }
- EditParsingState::InOldText {
- path,
- start_offset,
- old_text_start_offset,
- } => {
- if message_line == "---" {
- state = EditParsingState::InNewText {
- path,
- start_offset,
- old_text_range: old_text_start_offset..offset,
- new_text_start_offset: lines.offset(),
- };
- } else {
- state = EditParsingState::InOldText {
- path,
- start_offset,
- old_text_start_offset,
- };
- }
- }
- EditParsingState::InNewText {
- path,
- start_offset,
- old_text_range,
- new_text_start_offset,
- } => {
- if message_line == "```" {
- return Some(ParsedEditSuggestion {
- path,
- outer_range: start_offset..offset + "```".len(),
- old_text_range,
- new_text_range: new_text_start_offset..offset,
- });
- } else {
- state = EditParsingState::InNewText {
- path,
- start_offset,
- old_text_range,
- new_text_start_offset,
- };
- }
- }
- }
+ fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
+ Some("Assistant Panel")
}
-}
-#[derive(Clone)]
-struct PendingSlashCommand {
- name: String,
- argument: Option<String>,
- status: PendingSlashCommandStatus,
- source_range: Range<language::Anchor>,
+ fn toggle_action(&self) -> Box<dyn Action> {
+ Box::new(ToggleFocus)
+ }
}
-#[derive(Clone)]
-enum PendingSlashCommandStatus {
- Idle,
- Running { _task: Shared<Task<()>> },
- Error(String),
-}
+impl EventEmitter<PanelEvent> for AssistantPanel {}
+impl EventEmitter<AssistantPanelEvent> for AssistantPanel {}
-struct PendingCompletion {
- id: usize,
- _task: Task<()>,
+impl FocusableView for AssistantPanel {
+ fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
+ self.pane.focus_handle(cx)
+ }
}
pub enum ContextEditorEvent {
@@ -1,13 +1,13 @@
mod anthropic;
mod cloud;
-#[cfg(test)]
+#[cfg(any(test, feature = "test-support"))]
mod fake;
mod ollama;
mod open_ai;
pub use anthropic::*;
pub use cloud::*;
-#[cfg(test)]
+#[cfg(any(test, feature = "test-support"))]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;
@@ -13,7 +13,6 @@ pub struct FakeCompletionProvider {
}
impl FakeCompletionProvider {
- #[cfg(test)]
pub fn setup_test(cx: &mut AppContext) -> Self {
use crate::CompletionProvider;
use parking_lot::RwLock;
@@ -0,0 +1,3009 @@
+use crate::{
+ slash_command::SlashCommandLine, CompletionProvider, LanguageModelRequest,
+ LanguageModelRequestMessage, MessageId, MessageStatus, Role,
+};
+use anyhow::{anyhow, Context as _, Result};
+use assistant_slash_command::{
+ SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
+};
+use client::{proto, telemetry::Telemetry};
+use clock::ReplicaId;
+use collections::{HashMap, HashSet};
+use fs::Fs;
+use futures::{future::Shared, FutureExt, StreamExt};
+use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task};
+use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
+use open_ai::Model as OpenAiModel;
+use paths::contexts_dir;
+use serde::{Deserialize, Serialize};
+use std::{
+ cmp::Ordering,
+ iter, mem,
+ ops::Range,
+ path::{Path, PathBuf},
+ sync::Arc,
+ time::{Duration, Instant},
+};
+use telemetry_events::AssistantKind;
+use ui::SharedString;
+use util::{post_inc, TryFutureExt};
+use uuid::Uuid;
+
+#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
+pub struct ContextId(String);
+
+impl ContextId {
+ pub fn new() -> Self {
+ Self(Uuid::new_v4().to_string())
+ }
+
+ pub fn from_proto(id: String) -> Self {
+ Self(id)
+ }
+
+ pub fn to_proto(&self) -> String {
+ self.0.clone()
+ }
+}
+
+#[derive(Clone, Debug)]
+pub enum ContextOperation {
+ InsertMessage {
+ anchor: MessageAnchor,
+ metadata: MessageMetadata,
+ version: clock::Global,
+ },
+ UpdateMessage {
+ message_id: MessageId,
+ metadata: MessageMetadata,
+ version: clock::Global,
+ },
+ UpdateSummary {
+ summary: ContextSummary,
+ version: clock::Global,
+ },
+ SlashCommandFinished {
+ id: SlashCommandId,
+ output_range: Range<language::Anchor>,
+ sections: Vec<SlashCommandOutputSection<language::Anchor>>,
+ version: clock::Global,
+ },
+ BufferOperation(language::Operation),
+}
+
+impl ContextOperation {
+ pub fn from_proto(op: proto::ContextOperation) -> Result<Self> {
+ match op.variant.context("invalid variant")? {
+ proto::context_operation::Variant::InsertMessage(insert) => {
+ let message = insert.message.context("invalid message")?;
+ let id = MessageId(language::proto::deserialize_timestamp(
+ message.id.context("invalid id")?,
+ ));
+ Ok(Self::InsertMessage {
+ anchor: MessageAnchor {
+ id,
+ start: language::proto::deserialize_anchor(
+ message.start.context("invalid anchor")?,
+ )
+ .context("invalid anchor")?,
+ },
+ metadata: MessageMetadata {
+ role: Role::from_proto(message.role),
+ status: MessageStatus::from_proto(
+ message.status.context("invalid status")?,
+ ),
+ timestamp: id.0,
+ },
+ version: language::proto::deserialize_version(&insert.version),
+ })
+ }
+ proto::context_operation::Variant::UpdateMessage(update) => Ok(Self::UpdateMessage {
+ message_id: MessageId(language::proto::deserialize_timestamp(
+ update.message_id.context("invalid message id")?,
+ )),
+ metadata: MessageMetadata {
+ role: Role::from_proto(update.role),
+ status: MessageStatus::from_proto(update.status.context("invalid status")?),
+ timestamp: language::proto::deserialize_timestamp(
+ update.timestamp.context("invalid timestamp")?,
+ ),
+ },
+ version: language::proto::deserialize_version(&update.version),
+ }),
+ proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
+ summary: ContextSummary {
+ text: update.summary,
+ done: update.done,
+ timestamp: language::proto::deserialize_timestamp(
+ update.timestamp.context("invalid timestamp")?,
+ ),
+ },
+ version: language::proto::deserialize_version(&update.version),
+ }),
+ proto::context_operation::Variant::SlashCommandFinished(finished) => {
+ Ok(Self::SlashCommandFinished {
+ id: SlashCommandId(language::proto::deserialize_timestamp(
+ finished.id.context("invalid id")?,
+ )),
+ output_range: language::proto::deserialize_anchor_range(
+ finished.output_range.context("invalid range")?,
+ )?,
+ sections: finished
+ .sections
+ .into_iter()
+ .map(|section| {
+ Ok(SlashCommandOutputSection {
+ range: language::proto::deserialize_anchor_range(
+ section.range.context("invalid range")?,
+ )?,
+ icon: section.icon_name.parse()?,
+ label: section.label.into(),
+ })
+ })
+ .collect::<Result<Vec<_>>>()?,
+ version: language::proto::deserialize_version(&finished.version),
+ })
+ }
+ proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
+ language::proto::deserialize_operation(
+ op.operation.context("invalid buffer operation")?,
+ )?,
+ )),
+ }
+ }
+
+ pub fn to_proto(&self) -> proto::ContextOperation {
+ match self {
+ Self::InsertMessage {
+ anchor,
+ metadata,
+ version,
+ } => proto::ContextOperation {
+ variant: Some(proto::context_operation::Variant::InsertMessage(
+ proto::context_operation::InsertMessage {
+ message: Some(proto::ContextMessage {
+ id: Some(language::proto::serialize_timestamp(anchor.id.0)),
+ start: Some(language::proto::serialize_anchor(&anchor.start)),
+ role: metadata.role.to_proto() as i32,
+ status: Some(metadata.status.to_proto()),
+ }),
+ version: language::proto::serialize_version(version),
+ },
+ )),
+ },
+ Self::UpdateMessage {
+ message_id,
+ metadata,
+ version,
+ } => proto::ContextOperation {
+ variant: Some(proto::context_operation::Variant::UpdateMessage(
+ proto::context_operation::UpdateMessage {
+ message_id: Some(language::proto::serialize_timestamp(message_id.0)),
+ role: metadata.role.to_proto() as i32,
+ status: Some(metadata.status.to_proto()),
+ timestamp: Some(language::proto::serialize_timestamp(metadata.timestamp)),
+ version: language::proto::serialize_version(version),
+ },
+ )),
+ },
+ Self::UpdateSummary { summary, version } => proto::ContextOperation {
+ variant: Some(proto::context_operation::Variant::UpdateSummary(
+ proto::context_operation::UpdateSummary {
+ summary: summary.text.clone(),
+ done: summary.done,
+ timestamp: Some(language::proto::serialize_timestamp(summary.timestamp)),
+ version: language::proto::serialize_version(version),
+ },
+ )),
+ },
+ Self::SlashCommandFinished {
+ id,
+ output_range,
+ sections,
+ version,
+ } => proto::ContextOperation {
+ variant: Some(proto::context_operation::Variant::SlashCommandFinished(
+ proto::context_operation::SlashCommandFinished {
+ id: Some(language::proto::serialize_timestamp(id.0)),
+ output_range: Some(language::proto::serialize_anchor_range(
+ output_range.clone(),
+ )),
+ sections: sections
+ .iter()
+ .map(|section| {
+ let icon_name: &'static str = section.icon.into();
+ proto::SlashCommandOutputSection {
+ range: Some(language::proto::serialize_anchor_range(
+ section.range.clone(),
+ )),
+ icon_name: icon_name.to_string(),
+ label: section.label.to_string(),
+ }
+ })
+ .collect(),
+ version: language::proto::serialize_version(version),
+ },
+ )),
+ },
+ Self::BufferOperation(operation) => proto::ContextOperation {
+ variant: Some(proto::context_operation::Variant::BufferOperation(
+ proto::context_operation::BufferOperation {
+ operation: Some(language::proto::serialize_operation(operation)),
+ },
+ )),
+ },
+ }
+ }
+
+ fn timestamp(&self) -> clock::Lamport {
+ match self {
+ Self::InsertMessage { anchor, .. } => anchor.id.0,
+ Self::UpdateMessage { metadata, .. } => metadata.timestamp,
+ Self::UpdateSummary { summary, .. } => summary.timestamp,
+ Self::SlashCommandFinished { id, .. } => id.0,
+ Self::BufferOperation(_) => {
+ panic!("reading the timestamp of a buffer operation is not supported")
+ }
+ }
+ }
+
+ /// Returns the context version that this operation depends on.
+ pub fn version(&self) -> &clock::Global {
+ match self {
+ Self::InsertMessage { version, .. }
+ | Self::UpdateMessage { version, .. }
+ | Self::UpdateSummary { version, .. }
+ | Self::SlashCommandFinished { version, .. } => version,
+ Self::BufferOperation(_) => {
+ panic!("reading the version of a buffer operation is not supported")
+ }
+ }
+ }
+}
+
+#[derive(Clone)]
+pub enum ContextEvent {
+ MessagesEdited,
+ SummaryChanged,
+ EditSuggestionsChanged,
+ StreamedCompletion,
+ PendingSlashCommandsUpdated {
+ removed: Vec<Range<language::Anchor>>,
+ updated: Vec<PendingSlashCommand>,
+ },
+ SlashCommandFinished {
+ output_range: Range<language::Anchor>,
+ sections: Vec<SlashCommandOutputSection<language::Anchor>>,
+ run_commands_in_output: bool,
+ },
+ Operation(ContextOperation),
+}
+
+#[derive(Clone, Default, Debug)]
+pub struct ContextSummary {
+ pub text: String,
+ done: bool,
+ timestamp: clock::Lamport,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct MessageAnchor {
+ pub id: MessageId,
+ pub start: language::Anchor,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub struct MessageMetadata {
+ pub role: Role,
+ status: MessageStatus,
+ timestamp: clock::Lamport,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Message {
+ pub offset_range: Range<usize>,
+ pub index_range: Range<usize>,
+ pub id: MessageId,
+ pub anchor: language::Anchor,
+ pub role: Role,
+ pub status: MessageStatus,
+}
+
+impl Message {
+ fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage {
+ LanguageModelRequestMessage {
+ role: self.role,
+ content: buffer.text_for_range(self.offset_range.clone()).collect(),
+ }
+ }
+}
+
+struct PendingCompletion {
+ id: usize,
+ _task: Task<()>,
+}
+
+#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
+pub struct SlashCommandId(clock::Lamport);
+
+pub struct Context {
+ id: ContextId,
+ timestamp: clock::Lamport,
+ version: clock::Global,
+ pending_ops: Vec<ContextOperation>,
+ operations: Vec<ContextOperation>,
+ buffer: Model<Buffer>,
+ edit_suggestions: Vec<EditSuggestion>,
+ pending_slash_commands: Vec<PendingSlashCommand>,
+ edits_since_last_slash_command_parse: language::Subscription,
+ finished_slash_commands: HashSet<SlashCommandId>,
+ slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
+ message_anchors: Vec<MessageAnchor>,
+ messages_metadata: HashMap<MessageId, MessageMetadata>,
+ summary: Option<ContextSummary>,
+ pending_summary: Task<Option<()>>,
+ completion_count: usize,
+ pending_completions: Vec<PendingCompletion>,
+ token_count: Option<usize>,
+ pending_token_count: Task<Option<()>>,
+ pending_edit_suggestion_parse: Option<Task<()>>,
+ pending_save: Task<Result<()>>,
+ path: Option<PathBuf>,
+ _subscriptions: Vec<Subscription>,
+ telemetry: Option<Arc<Telemetry>>,
+ language_registry: Arc<LanguageRegistry>,
+}
+
+impl EventEmitter<ContextEvent> for Context {}
+
+impl Context {
+ pub fn local(
+ language_registry: Arc<LanguageRegistry>,
+ telemetry: Option<Arc<Telemetry>>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ Self::new(
+ ContextId::new(),
+ ReplicaId::default(),
+ language::Capability::ReadWrite,
+ language_registry,
+ telemetry,
+ cx,
+ )
+ }
+
+ pub fn new(
+ id: ContextId,
+ replica_id: ReplicaId,
+ capability: language::Capability,
+ language_registry: Arc<LanguageRegistry>,
+ telemetry: Option<Arc<Telemetry>>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let buffer = cx.new_model(|_cx| {
+ let mut buffer = Buffer::remote(
+ language::BufferId::new(1).unwrap(),
+ replica_id,
+ capability,
+ "",
+ );
+ buffer.set_language_registry(language_registry.clone());
+ buffer
+ });
+ let edits_since_last_slash_command_parse =
+ buffer.update(cx, |buffer, _| buffer.subscribe());
+ let mut this = Self {
+ id,
+ timestamp: clock::Lamport::new(replica_id),
+ version: clock::Global::new(),
+ pending_ops: Vec::new(),
+ operations: Vec::new(),
+ message_anchors: Default::default(),
+ messages_metadata: Default::default(),
+ edit_suggestions: Vec::new(),
+ pending_slash_commands: Vec::new(),
+ finished_slash_commands: HashSet::default(),
+ slash_command_output_sections: Vec::new(),
+ edits_since_last_slash_command_parse,
+ summary: None,
+ pending_summary: Task::ready(None),
+ completion_count: Default::default(),
+ pending_completions: Default::default(),
+ token_count: None,
+ pending_token_count: Task::ready(None),
+ pending_edit_suggestion_parse: None,
+ _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+ pending_save: Task::ready(Ok(())),
+ path: None,
+ buffer,
+ telemetry,
+ language_registry,
+ };
+
+ let first_message_id = MessageId(clock::Lamport {
+ replica_id: 0,
+ value: 0,
+ });
+ let message = MessageAnchor {
+ id: first_message_id,
+ start: language::Anchor::MIN,
+ };
+ this.messages_metadata.insert(
+ first_message_id,
+ MessageMetadata {
+ role: Role::User,
+ status: MessageStatus::Done,
+ timestamp: first_message_id.0,
+ },
+ );
+ this.message_anchors.push(message);
+
+ this.set_language(cx);
+ this.count_remaining_tokens(cx);
+ this
+ }
+
+ fn serialize(&self, cx: &AppContext) -> SavedContext {
+ let buffer = self.buffer.read(cx);
+ SavedContext {
+ id: Some(self.id.clone()),
+ zed: "context".into(),
+ version: SavedContext::VERSION.into(),
+ text: buffer.text(),
+ messages: self
+ .messages(cx)
+ .map(|message| SavedMessage {
+ id: message.id,
+ start: message.offset_range.start,
+ metadata: self.messages_metadata[&message.id].clone(),
+ })
+ .collect(),
+ summary: self
+ .summary
+ .as_ref()
+ .map(|summary| summary.text.clone())
+ .unwrap_or_default(),
+ slash_command_output_sections: self
+ .slash_command_output_sections
+ .iter()
+ .filter_map(|section| {
+ let range = section.range.to_offset(buffer);
+ if section.range.start.is_valid(buffer) && !range.is_empty() {
+ Some(assistant_slash_command::SlashCommandOutputSection {
+ range,
+ icon: section.icon,
+ label: section.label.clone(),
+ })
+ } else {
+ None
+ }
+ })
+ .collect(),
+ }
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub fn deserialize(
+ saved_context: SavedContext,
+ path: PathBuf,
+ language_registry: Arc<LanguageRegistry>,
+ telemetry: Option<Arc<Telemetry>>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let id = saved_context.id.clone().unwrap_or_else(|| ContextId::new());
+ let mut this = Self::new(
+ id,
+ ReplicaId::default(),
+ language::Capability::ReadWrite,
+ language_registry,
+ telemetry,
+ cx,
+ );
+ this.path = Some(path);
+ this.buffer.update(cx, |buffer, cx| {
+ buffer.set_text(saved_context.text.as_str(), cx)
+ });
+ let operations = saved_context.into_ops(&this.buffer, cx);
+ this.apply_ops(operations, cx).unwrap();
+ this
+ }
+
+ pub fn id(&self) -> &ContextId {
+ &self.id
+ }
+
+ pub fn replica_id(&self) -> ReplicaId {
+ self.timestamp.replica_id
+ }
+
+ pub fn version(&self, cx: &AppContext) -> ContextVersion {
+ ContextVersion {
+ context: self.version.clone(),
+ buffer: self.buffer.read(cx).version(),
+ }
+ }
+
+ pub fn set_capability(
+ &mut self,
+ capability: language::Capability,
+ cx: &mut ModelContext<Self>,
+ ) {
+ self.buffer
+ .update(cx, |buffer, cx| buffer.set_capability(capability, cx));
+ }
+
+ fn next_timestamp(&mut self) -> clock::Lamport {
+ let timestamp = self.timestamp.tick();
+ self.version.observe(timestamp);
+ timestamp
+ }
+
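+ /// Serializes every buffer and context operation not yet observed by the given version, with buffer operations emitted first.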
+ pub fn serialize_ops(
+ &self,
+ since: &ContextVersion,
+ cx: &AppContext,
+ ) -> Task<Vec<proto::ContextOperation>> {
+ let buffer_ops = self
+ .buffer
+ .read(cx)
+ .serialize_ops(Some(since.buffer.clone()), cx);
+
+ let mut context_ops = self
+ .operations
+ .iter()
+ .filter(|op| !since.context.observed(op.timestamp()))
+ .cloned()
+ .collect::<Vec<_>>();
+ context_ops.extend(self.pending_ops.iter().cloned());
+
+ cx.background_executor().spawn(async move {
+ let buffer_ops = buffer_ops.await;
+ context_ops.sort_unstable_by_key(|op| op.timestamp());
+ buffer_ops
+ .into_iter()
+ .map(|op| proto::ContextOperation {
+ variant: Some(proto::context_operation::Variant::BufferOperation(
+ proto::context_operation::BufferOperation {
+ operation: Some(op),
+ },
+ )),
+ })
+ .chain(context_ops.into_iter().map(|op| op.to_proto()))
+ .collect()
+ })
+ }
+
+ pub fn apply_ops(
+ &mut self,
+ ops: impl IntoIterator<Item = ContextOperation>,
+ cx: &mut ModelContext<Self>,
+ ) -> Result<()> {
+ let mut buffer_ops = Vec::new();
+ for op in ops {
+ match op {
+ ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op),
+ op @ _ => self.pending_ops.push(op),
+ }
+ }
+ self.buffer
+ .update(cx, |buffer, cx| buffer.apply_ops(buffer_ops, cx))?;
+ self.flush_ops(cx);
+
+ Ok(())
+ }
+
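+ /// Applies any pending operations whose dependencies have been observed, leaving the rest queued and emitting events for message and summary changes.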
+ fn flush_ops(&mut self, cx: &mut ModelContext<Context>) {
+ let mut messages_changed = false;
+ let mut summary_changed = false;
+
+ self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
+ for op in mem::take(&mut self.pending_ops) {
+ if !self.can_apply_op(&op, cx) {
+ self.pending_ops.push(op);
+ continue;
+ }
+
+ let timestamp = op.timestamp();
+ match op.clone() {
+ ContextOperation::InsertMessage {
+ anchor, metadata, ..
+ } => {
+ if self.messages_metadata.contains_key(&anchor.id) {
+ // We already applied this operation.
+ } else {
+ self.insert_message(anchor, metadata, cx);
+ messages_changed = true;
+ }
+ }
+ ContextOperation::UpdateMessage {
+ message_id,
+ metadata: new_metadata,
+ ..
+ } => {
+ let metadata = self.messages_metadata.get_mut(&message_id).unwrap();
+ if new_metadata.timestamp > metadata.timestamp {
+ *metadata = new_metadata;
+ messages_changed = true;
+ }
+ }
+ ContextOperation::UpdateSummary {
+ summary: new_summary,
+ ..
+ } => {
+ if self
+ .summary
+ .as_ref()
+ .map_or(true, |summary| new_summary.timestamp > summary.timestamp)
+ {
+ self.summary = Some(new_summary);
+ summary_changed = true;
+ }
+ }
+ ContextOperation::SlashCommandFinished {
+ id,
+ output_range,
+ sections,
+ ..
+ } => {
+ if self.finished_slash_commands.insert(id) {
+ let buffer = self.buffer.read(cx);
+ self.slash_command_output_sections
+ .extend(sections.iter().cloned());
+ self.slash_command_output_sections
+ .sort_by(|a, b| a.range.cmp(&b.range, buffer));
+ cx.emit(ContextEvent::SlashCommandFinished {
+ output_range,
+ sections,
+ run_commands_in_output: false,
+ });
+ }
+ }
+ ContextOperation::BufferOperation(_) => unreachable!(),
+ }
+
+ self.version.observe(timestamp);
+ self.timestamp.observe(timestamp);
+ self.operations.push(op);
+ }
+
+ if messages_changed {
+ cx.emit(ContextEvent::MessagesEdited);
+ cx.notify();
+ }
+
+ if summary_changed {
+ cx.emit(ContextEvent::SummaryChanged);
+ cx.notify();
+ }
+ }
+
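+ /// Returns true if the operation's dependencies (observed versions, buffer anchors, and target messages) are already present locally.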
+ fn can_apply_op(&self, op: &ContextOperation, cx: &AppContext) -> bool {
+ if !self.version.observed_all(op.version()) {
+ return false;
+ }
+
+ match op {
+ ContextOperation::InsertMessage { anchor, .. } => self
+ .buffer
+ .read(cx)
+ .version
+ .observed(anchor.start.timestamp),
+ ContextOperation::UpdateMessage { message_id, .. } => {
+ self.messages_metadata.contains_key(message_id)
+ }
+ ContextOperation::UpdateSummary { .. } => true,
+ ContextOperation::SlashCommandFinished {
+ output_range,
+ sections,
+ ..
+ } => {
+ let version = &self.buffer.read(cx).version;
+ sections
+ .iter()
+ .map(|section| &section.range)
+ .chain([output_range])
+ .all(|range| {
+ let observed_start = range.start == language::Anchor::MIN
+ || range.start == language::Anchor::MAX
+ || version.observed(range.start.timestamp);
+ let observed_end = range.end == language::Anchor::MIN
+ || range.end == language::Anchor::MAX
+ || version.observed(range.end.timestamp);
+ observed_start && observed_end
+ })
+ }
+ ContextOperation::BufferOperation(_) => {
+ panic!("buffer operations should always be applied")
+ }
+ }
+ }
+
+ fn push_op(&mut self, op: ContextOperation, cx: &mut ModelContext<Self>) {
+ self.operations.push(op.clone());
+ cx.emit(ContextEvent::Operation(op));
+ }
+
+ pub fn buffer(&self) -> &Model<Buffer> {
+ &self.buffer
+ }
+
+ pub fn path(&self) -> Option<&Path> {
+ self.path.as_deref()
+ }
+
+ pub fn summary(&self) -> Option<&ContextSummary> {
+ self.summary.as_ref()
+ }
+
+ pub fn edit_suggestions(&self) -> &[EditSuggestion] {
+ &self.edit_suggestions
+ }
+
+ pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] {
+ &self.pending_slash_commands
+ }
+
+ pub fn slash_command_output_sections(&self) -> &[SlashCommandOutputSection<language::Anchor>] {
+ &self.slash_command_output_sections
+ }
+
+ fn set_language(&mut self, cx: &mut ModelContext<Self>) {
+ let markdown = self.language_registry.language_for_name("Markdown");
+ cx.spawn(|this, mut cx| async move {
+ let markdown = markdown.await?;
+ this.update(&mut cx, |this, cx| {
+ this.buffer
+ .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
+ })
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn handle_buffer_event(
+ &mut self,
+ _: Model<Buffer>,
+ event: &language::Event,
+ cx: &mut ModelContext<Self>,
+ ) {
+ match event {
+ language::Event::Operation(operation) => cx.emit(ContextEvent::Operation(
+ ContextOperation::BufferOperation(operation.clone()),
+ )),
+ language::Event::Edited => {
+ self.count_remaining_tokens(cx);
+ self.reparse_edit_suggestions(cx);
+ self.reparse_slash_commands(cx);
+ cx.emit(ContextEvent::MessagesEdited);
+ }
+ _ => {}
+ }
+ }
+
+ pub(crate) fn token_count(&self) -> Option<usize> {
+ self.token_count
+ }
+
+ pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
+ let request = self.to_completion_request(cx);
+ self.pending_token_count = cx.spawn(|this, mut cx| {
+ async move {
+ cx.background_executor()
+ .timer(Duration::from_millis(200))
+ .await;
+
+ let token_count = cx
+ .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
+ .await?;
+
+ this.update(&mut cx, |this, cx| {
+ this.token_count = Some(token_count);
+ cx.notify()
+ })?;
+ anyhow::Ok(())
+ }
+ .log_err()
+ });
+ }
+
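+ /// Re-parses slash command invocations on the rows edited since the last parse and reports removed and updated pending commands.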
+ pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
+ let buffer = self.buffer.read(cx);
+ let mut row_ranges = self
+ .edits_since_last_slash_command_parse
+ .consume()
+ .into_iter()
+ .map(|edit| {
+ let start_row = buffer.offset_to_point(edit.new.start).row;
+ let end_row = buffer.offset_to_point(edit.new.end).row + 1;
+ start_row..end_row
+ })
+ .peekable();
+
+ let mut removed = Vec::new();
+ let mut updated = Vec::new();
+ while let Some(mut row_range) = row_ranges.next() {
+ while let Some(next_row_range) = row_ranges.peek() {
+ if row_range.end >= next_row_range.start {
+ row_range.end = next_row_range.end;
+ row_ranges.next();
+ } else {
+ break;
+ }
+ }
+
+ let start = buffer.anchor_before(Point::new(row_range.start, 0));
+ let end = buffer.anchor_after(Point::new(
+ row_range.end - 1,
+ buffer.line_len(row_range.end - 1),
+ ));
+
+ let old_range = self.pending_command_indices_for_range(start..end, cx);
+
+ let mut new_commands = Vec::new();
+ let mut lines = buffer.text_for_range(start..end).lines();
+ let mut offset = lines.offset();
+ while let Some(line) = lines.next() {
+ if let Some(command_line) = SlashCommandLine::parse(line) {
+ let name = &line[command_line.name.clone()];
+ let argument = command_line.argument.as_ref().and_then(|argument| {
+ (!argument.is_empty()).then_some(&line[argument.clone()])
+ });
+ if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
+ if !command.requires_argument() || argument.is_some() {
+ let start_ix = offset + command_line.name.start - 1;
+ let end_ix = offset
+ + command_line
+ .argument
+ .map_or(command_line.name.end, |argument| argument.end);
+ let source_range =
+ buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
+ let pending_command = PendingSlashCommand {
+ name: name.to_string(),
+ argument: argument.map(ToString::to_string),
+ source_range,
+ status: PendingSlashCommandStatus::Idle,
+ };
+ updated.push(pending_command.clone());
+ new_commands.push(pending_command);
+ }
+ }
+ }
+
+ offset = lines.offset();
+ }
+
+ let removed_commands = self.pending_slash_commands.splice(old_range, new_commands);
+ removed.extend(removed_commands.map(|command| command.source_range));
+ }
+
+ if !updated.is_empty() || !removed.is_empty() {
+ cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated });
+ }
+ }
+
+ fn reparse_edit_suggestions(&mut self, cx: &mut ModelContext<Self>) {
+ self.pending_edit_suggestion_parse = Some(cx.spawn(|this, mut cx| async move {
+ cx.background_executor()
+ .timer(Duration::from_millis(200))
+ .await;
+
+ this.update(&mut cx, |this, cx| {
+ this.reparse_edit_suggestions_in_range(0..this.buffer.read(cx).len(), cx);
+ })
+ .ok();
+ }));
+ }
+
+ fn reparse_edit_suggestions_in_range(
+ &mut self,
+ range: Range<usize>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ self.buffer.update(cx, |buffer, _| {
+ let range_start = buffer.anchor_before(range.start);
+ let range_end = buffer.anchor_after(range.end);
+ let start_ix = self
+ .edit_suggestions
+ .binary_search_by(|probe| {
+ probe
+ .source_range
+ .end
+ .cmp(&range_start, buffer)
+ .then(Ordering::Greater)
+ })
+ .unwrap_err();
+ let end_ix = self
+ .edit_suggestions
+ .binary_search_by(|probe| {
+ probe
+ .source_range
+ .start
+ .cmp(&range_end, buffer)
+ .then(Ordering::Less)
+ })
+ .unwrap_err();
+
+ let mut new_edit_suggestions = Vec::new();
+ let mut message_lines = buffer.as_rope().chunks_in_range(range).lines();
+ while let Some(suggestion) = parse_next_edit_suggestion(&mut message_lines) {
+ let start_anchor = buffer.anchor_after(suggestion.outer_range.start);
+ let end_anchor = buffer.anchor_before(suggestion.outer_range.end);
+ new_edit_suggestions.push(EditSuggestion {
+ source_range: start_anchor..end_anchor,
+ full_path: suggestion.path,
+ });
+ }
+ self.edit_suggestions
+ .splice(start_ix..end_ix, new_edit_suggestions);
+ });
+ cx.emit(ContextEvent::EditSuggestionsChanged);
+ cx.notify();
+ }
+
+ pub fn pending_command_for_position(
+ &mut self,
+ position: language::Anchor,
+ cx: &mut ModelContext<Self>,
+ ) -> Option<&mut PendingSlashCommand> {
+ let buffer = self.buffer.read(cx);
+ match self
+ .pending_slash_commands
+ .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer))
+ {
+ Ok(ix) => Some(&mut self.pending_slash_commands[ix]),
+ Err(ix) => {
+ let cmd = self.pending_slash_commands.get_mut(ix)?;
+ if position.cmp(&cmd.source_range.start, buffer).is_ge()
+ && position.cmp(&cmd.source_range.end, buffer).is_le()
+ {
+ Some(cmd)
+ } else {
+ None
+ }
+ }
+ }
+ }
+
+ pub fn pending_commands_for_range(
+ &self,
+ range: Range<language::Anchor>,
+ cx: &AppContext,
+ ) -> &[PendingSlashCommand] {
+ let range = self.pending_command_indices_for_range(range, cx);
+ &self.pending_slash_commands[range]
+ }
+
+ fn pending_command_indices_for_range(
+ &self,
+ range: Range<language::Anchor>,
+ cx: &AppContext,
+ ) -> Range<usize> {
+ let buffer = self.buffer.read(cx);
+ let start_ix = match self
+ .pending_slash_commands
+ .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer))
+ {
+ Ok(ix) | Err(ix) => ix,
+ };
+ let end_ix = match self
+ .pending_slash_commands
+ .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer))
+ {
+ Ok(ix) => ix + 1,
+ Err(ix) => ix,
+ };
+ start_ix..end_ix
+ }
+
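+ /// Replaces the command's text with its output once the task resolves, recording a SlashCommandFinished operation; meanwhile the command is marked as running (or errored).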
+ pub fn insert_command_output(
+ &mut self,
+ command_range: Range<language::Anchor>,
+ output: Task<Result<SlashCommandOutput>>,
+ insert_trailing_newline: bool,
+ cx: &mut ModelContext<Self>,
+ ) {
+ self.reparse_slash_commands(cx);
+
+ let insert_output_task = cx.spawn(|this, mut cx| {
+ let command_range = command_range.clone();
+ async move {
+ let output = output.await;
+ this.update(&mut cx, |this, cx| match output {
+ Ok(mut output) => {
+ if insert_trailing_newline {
+ output.text.push('\n');
+ }
+
+ let version = this.version.clone();
+ let command_id = SlashCommandId(this.next_timestamp());
+ let (operation, event) = this.buffer.update(cx, |buffer, cx| {
+ let start = command_range.start.to_offset(buffer);
+ let old_end = command_range.end.to_offset(buffer);
+ let new_end = start + output.text.len();
+ buffer.edit([(start..old_end, output.text)], None, cx);
+
+ let mut sections = output
+ .sections
+ .into_iter()
+ .map(|section| SlashCommandOutputSection {
+ range: buffer.anchor_after(start + section.range.start)
+ ..buffer.anchor_before(start + section.range.end),
+ icon: section.icon,
+ label: section.label,
+ })
+ .collect::<Vec<_>>();
+ sections.sort_by(|a, b| a.range.cmp(&b.range, buffer));
+
+ this.slash_command_output_sections
+ .extend(sections.iter().cloned());
+ this.slash_command_output_sections
+ .sort_by(|a, b| a.range.cmp(&b.range, buffer));
+
+ let output_range =
+ buffer.anchor_after(start)..buffer.anchor_before(new_end);
+ this.finished_slash_commands.insert(command_id);
+
+ (
+ ContextOperation::SlashCommandFinished {
+ id: command_id,
+ output_range: output_range.clone(),
+ sections: sections.clone(),
+ version,
+ },
+ ContextEvent::SlashCommandFinished {
+ output_range,
+ sections,
+ run_commands_in_output: output.run_commands_in_text,
+ },
+ )
+ });
+
+ this.push_op(operation, cx);
+ cx.emit(event);
+ }
+ Err(error) => {
+ if let Some(pending_command) =
+ this.pending_command_for_position(command_range.start, cx)
+ {
+ pending_command.status =
+ PendingSlashCommandStatus::Error(error.to_string());
+ cx.emit(ContextEvent::PendingSlashCommandsUpdated {
+ removed: vec![pending_command.source_range.clone()],
+ updated: vec![pending_command.clone()],
+ });
+ }
+ }
+ })
+ .ok();
+ }
+ });
+
+ if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) {
+ pending_command.status = PendingSlashCommandStatus::Running {
+ _task: insert_output_task.shared(),
+ };
+ cx.emit(ContextEvent::PendingSlashCommandsUpdated {
+ removed: vec![pending_command.source_range.clone()],
+ updated: vec![pending_command.clone()],
+ });
+ }
+ }
+
+ pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
+ self.count_remaining_tokens(cx);
+ }
+
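+ /// Requests a completion when any selected message needs a reply: inserts a pending assistant message and a trailing user message, streams the response into the buffer, and returns the new user message anchors.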
+ pub fn assist(
+ &mut self,
+ selected_messages: HashSet<MessageId>,
+ cx: &mut ModelContext<Self>,
+ ) -> Vec<MessageAnchor> {
+ let mut user_messages = Vec::new();
+
+ let last_message_id = if let Some(last_message_id) =
+ self.message_anchors.iter().rev().find_map(|message| {
+ message
+ .start
+ .is_valid(self.buffer.read(cx))
+ .then_some(message.id)
+ }) {
+ last_message_id
+ } else {
+ return Default::default();
+ };
+
+ let mut should_assist = false;
+ for selected_message_id in selected_messages {
+ let selected_message_role =
+ if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
+ metadata.role
+ } else {
+ continue;
+ };
+
+ if selected_message_role == Role::Assistant {
+ if let Some(user_message) = self.insert_message_after(
+ selected_message_id,
+ Role::User,
+ MessageStatus::Done,
+ cx,
+ ) {
+ user_messages.push(user_message);
+ }
+ } else {
+ should_assist = true;
+ }
+ }
+
+ if should_assist {
+ if !CompletionProvider::global(cx).is_authenticated() {
+ log::info!("completion provider has no credentials");
+ return Default::default();
+ }
+
+ let request = self.to_completion_request(cx);
+ let stream = CompletionProvider::global(cx).complete(request, cx);
+ let assistant_message = self
+ .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
+ .unwrap();
+
+ // Queue up the user's next reply.
+ let user_message = self
+ .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
+ .unwrap();
+ user_messages.push(user_message);
+
+ let task = cx.spawn({
+ |this, mut cx| async move {
+ let assistant_message_id = assistant_message.id;
+ let mut response_latency = None;
+ let stream_completion = async {
+ let request_start = Instant::now();
+ let mut messages = stream.await.inner.await?;
+
+ while let Some(message) = messages.next().await {
+ if response_latency.is_none() {
+ response_latency = Some(request_start.elapsed());
+ }
+ let text = message?;
+
+ this.update(&mut cx, |this, cx| {
+ let message_ix = this
+ .message_anchors
+ .iter()
+ .position(|message| message.id == assistant_message_id)?;
+ let message_range = this.buffer.update(cx, |buffer, cx| {
+ let message_start_offset =
+ this.message_anchors[message_ix].start.to_offset(buffer);
+ let message_old_end_offset = this.message_anchors
+ [message_ix + 1..]
+ .iter()
+ .find(|message| message.start.is_valid(buffer))
+ .map_or(buffer.len(), |message| {
+ message.start.to_offset(buffer).saturating_sub(1)
+ });
+ let message_new_end_offset =
+ message_old_end_offset + text.len();
+ buffer.edit(
+ [(message_old_end_offset..message_old_end_offset, text)],
+ None,
+ cx,
+ );
+ message_start_offset..message_new_end_offset
+ });
+ this.reparse_edit_suggestions_in_range(message_range, cx);
+ cx.emit(ContextEvent::StreamedCompletion);
+
+ Some(())
+ })?;
+ smol::future::yield_now().await;
+ }
+
+ this.update(&mut cx, |this, cx| {
+ this.pending_completions
+ .retain(|completion| completion.id != this.completion_count);
+ this.summarize(cx);
+ })?;
+
+ anyhow::Ok(())
+ };
+
+ let result = stream_completion.await;
+
+ this.update(&mut cx, |this, cx| {
+ let error_message = result
+ .err()
+ .map(|error| error.to_string().trim().to_string());
+
+ this.update_metadata(assistant_message_id, cx, |metadata| {
+ if let Some(error_message) = error_message.as_ref() {
+ metadata.status =
+ MessageStatus::Error(SharedString::from(error_message.clone()));
+ } else {
+ metadata.status = MessageStatus::Done;
+ }
+ });
+
+ if let Some(telemetry) = this.telemetry.as_ref() {
+ let model = CompletionProvider::global(cx).model();
+ telemetry.report_assistant_event(
+ Some(this.id.0.clone()),
+ AssistantKind::Panel,
+ model.telemetry_id(),
+ response_latency,
+ error_message,
+ );
+ }
+ })
+ .ok();
+ }
+ });
+
+ self.pending_completions.push(PendingCompletion {
+ id: post_inc(&mut self.completion_count),
+ _task: task,
+ });
+ }
+
+ user_messages
+ }
+
+ pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
+ let messages = self
+ .messages(cx)
+ .filter(|message| matches!(message.status, MessageStatus::Done))
+ .map(|message| message.to_request_message(self.buffer.read(cx)));
+
+ LanguageModelRequest {
+ model: CompletionProvider::global(cx).model(),
+ messages: messages.collect(),
+ stop: vec![],
+ temperature: 1.0,
+ }
+ }
+
+ pub fn cancel_last_assist(&mut self) -> bool {
+ self.pending_completions.pop().is_some()
+ }
+
+ pub fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
+ for id in ids {
+ if let Some(metadata) = self.messages_metadata.get(&id) {
+ let role = metadata.role.cycle();
+ self.update_metadata(id, cx, |metadata| metadata.role = role);
+ }
+ }
+ }
+
+ pub fn update_metadata(
+ &mut self,
+ id: MessageId,
+ cx: &mut ModelContext<Self>,
+ f: impl FnOnce(&mut MessageMetadata),
+ ) {
+ let version = self.version.clone();
+ let timestamp = self.next_timestamp();
+ if let Some(metadata) = self.messages_metadata.get_mut(&id) {
+ f(metadata);
+ metadata.timestamp = timestamp;
+ let operation = ContextOperation::UpdateMessage {
+ message_id: id,
+ metadata: metadata.clone(),
+ version,
+ };
+ self.push_op(operation, cx);
+ cx.emit(ContextEvent::MessagesEdited);
+ cx.notify();
+ }
+ }
+
+ fn insert_message_after(
+ &mut self,
+ message_id: MessageId,
+ role: Role,
+ status: MessageStatus,
+ cx: &mut ModelContext<Self>,
+ ) -> Option<MessageAnchor> {
+ if let Some(prev_message_ix) = self
+ .message_anchors
+ .iter()
+ .position(|message| message.id == message_id)
+ {
+ // Find the next valid message after the one we were given.
+ let mut next_message_ix = prev_message_ix + 1;
+ while let Some(next_message) = self.message_anchors.get(next_message_ix) {
+ if next_message.start.is_valid(self.buffer.read(cx)) {
+ break;
+ }
+ next_message_ix += 1;
+ }
+
+ let start = self.buffer.update(cx, |buffer, cx| {
+ let offset = self
+ .message_anchors
+ .get(next_message_ix)
+ .map_or(buffer.len(), |message| {
+ buffer.clip_offset(message.start.to_offset(buffer) - 1, Bias::Left)
+ });
+ buffer.edit([(offset..offset, "\n")], None, cx);
+ buffer.anchor_before(offset + 1)
+ });
+
+ let version = self.version.clone();
+ let anchor = MessageAnchor {
+ id: MessageId(self.next_timestamp()),
+ start,
+ };
+ let metadata = MessageMetadata {
+ role,
+ status,
+ timestamp: anchor.id.0,
+ };
+ self.insert_message(anchor.clone(), metadata.clone(), cx);
+ self.push_op(
+ ContextOperation::InsertMessage {
+ anchor: anchor.clone(),
+ metadata,
+ version,
+ },
+ cx,
+ );
+ Some(anchor)
+ } else {
+ None
+ }
+ }
+
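+ /// Splits the message containing the given range in two, returning anchors for the new message covering the selection (if any) and the one covering the suffix.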
+ pub fn split_message(
+ &mut self,
+ range: Range<usize>,
+ cx: &mut ModelContext<Self>,
+ ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
+ let start_message = self.message_for_offset(range.start, cx);
+ let end_message = self.message_for_offset(range.end, cx);
+ if let Some((start_message, end_message)) = start_message.zip(end_message) {
+ // Prevent splitting when range spans multiple messages.
+ if start_message.id != end_message.id {
+ return (None, None);
+ }
+
+ let message = start_message;
+ let role = message.role;
+ let mut edited_buffer = false;
+
+ let mut suffix_start = None;
+ if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
+ {
+ if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
+ suffix_start = Some(range.end + 1);
+ } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
+ suffix_start = Some(range.end);
+ }
+ }
+
+ let version = self.version.clone();
+ let suffix = if let Some(suffix_start) = suffix_start {
+ MessageAnchor {
+ id: MessageId(self.next_timestamp()),
+ start: self.buffer.read(cx).anchor_before(suffix_start),
+ }
+ } else {
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.edit([(range.end..range.end, "\n")], None, cx);
+ });
+ edited_buffer = true;
+ MessageAnchor {
+ id: MessageId(self.next_timestamp()),
+ start: self.buffer.read(cx).anchor_before(range.end + 1),
+ }
+ };
+
+ let suffix_metadata = MessageMetadata {
+ role,
+ status: MessageStatus::Done,
+ timestamp: suffix.id.0,
+ };
+ self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
+ self.push_op(
+ ContextOperation::InsertMessage {
+ anchor: suffix.clone(),
+ metadata: suffix_metadata,
+ version,
+ },
+ cx,
+ );
+
+ let new_messages =
+ if range.start == range.end || range.start == message.offset_range.start {
+ (None, Some(suffix))
+ } else {
+ let mut prefix_end = None;
+ if range.start > message.offset_range.start
+ && range.end < message.offset_range.end - 1
+ {
+ if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
+ prefix_end = Some(range.start + 1);
+ } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
+ == Some('\n')
+ {
+ prefix_end = Some(range.start);
+ }
+ }
+
+ let version = self.version.clone();
+ let selection = if let Some(prefix_end) = prefix_end {
+ MessageAnchor {
+ id: MessageId(self.next_timestamp()),
+ start: self.buffer.read(cx).anchor_before(prefix_end),
+ }
+ } else {
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.edit([(range.start..range.start, "\n")], None, cx)
+ });
+ edited_buffer = true;
+ MessageAnchor {
+ id: MessageId(self.next_timestamp()),
+ start: self.buffer.read(cx).anchor_before(range.end + 1),
+ }
+ };
+
+ let selection_metadata = MessageMetadata {
+ role,
+ status: MessageStatus::Done,
+ timestamp: selection.id.0,
+ };
+ self.insert_message(selection.clone(), selection_metadata.clone(), cx);
+ self.push_op(
+ ContextOperation::InsertMessage {
+ anchor: selection.clone(),
+ metadata: selection_metadata,
+ version,
+ },
+ cx,
+ );
+
+ (Some(selection), Some(suffix))
+ };
+
+ if !edited_buffer {
+ cx.emit(ContextEvent::MessagesEdited);
+ }
+ new_messages
+ } else {
+ (None, None)
+ }
+ }
+
+ fn insert_message(
+ &mut self,
+ new_anchor: MessageAnchor,
+ new_metadata: MessageMetadata,
+ cx: &mut ModelContext<Self>,
+ ) {
+ cx.emit(ContextEvent::MessagesEdited);
+
+ self.messages_metadata.insert(new_anchor.id, new_metadata);
+
+ let buffer = self.buffer.read(cx);
+ let insertion_ix = self
+ .message_anchors
+ .iter()
+ .position(|anchor| {
+ let comparison = new_anchor.start.cmp(&anchor.start, buffer);
+ comparison.is_lt() || (comparison.is_eq() && new_anchor.id > anchor.id)
+ })
+ .unwrap_or(self.message_anchors.len());
+ self.message_anchors.insert(insertion_ix, new_anchor);
+ }
+
+ fn summarize(&mut self, cx: &mut ModelContext<Self>) {
+ if self.message_anchors.len() >= 2 && self.summary.is_none() {
+ if !CompletionProvider::global(cx).is_authenticated() {
+ return;
+ }
+
+ let messages = self
+ .messages(cx)
+ .map(|message| message.to_request_message(self.buffer.read(cx)))
+ .chain(Some(LanguageModelRequestMessage {
+ role: Role::User,
+ content: "Summarize the context into a short title without punctuation.".into(),
+ }));
+ let request = LanguageModelRequest {
+ model: CompletionProvider::global(cx).model(),
+ messages: messages.collect(),
+ stop: vec![],
+ temperature: 1.0,
+ };
+
+ let stream = CompletionProvider::global(cx).complete(request, cx);
+ self.pending_summary = cx.spawn(|this, mut cx| {
+ async move {
+ let mut messages = stream.await.inner.await?;
+
+ while let Some(message) = messages.next().await {
+ let text = message?;
+ let mut lines = text.lines();
+ this.update(&mut cx, |this, cx| {
+ let version = this.version.clone();
+ let timestamp = this.next_timestamp();
+ let summary = this.summary.get_or_insert(Default::default());
+ summary.text.extend(lines.next());
+ summary.timestamp = timestamp;
+ let operation = ContextOperation::UpdateSummary {
+ summary: summary.clone(),
+ version,
+ };
+ this.push_op(operation, cx);
+ cx.emit(ContextEvent::SummaryChanged);
+ })?;
+
+ // Stop if the LLM generated multiple lines.
+ if lines.next().is_some() {
+ break;
+ }
+ }
+
+ this.update(&mut cx, |this, cx| {
+ let version = this.version.clone();
+ let timestamp = this.next_timestamp();
+ if let Some(summary) = this.summary.as_mut() {
+ summary.done = true;
+ summary.timestamp = timestamp;
+ let operation = ContextOperation::UpdateSummary {
+ summary: summary.clone(),
+ version,
+ };
+ this.push_op(operation, cx);
+ cx.emit(ContextEvent::SummaryChanged);
+ }
+ })?;
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ });
+ }
+ }
+
+ fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
+ self.messages_for_offsets([offset], cx).pop()
+ }
+
+ pub fn messages_for_offsets(
+ &self,
+ offsets: impl IntoIterator<Item = usize>,
+ cx: &AppContext,
+ ) -> Vec<Message> {
+ let mut result = Vec::new();
+
+ let mut messages = self.messages(cx).peekable();
+ let mut offsets = offsets.into_iter().peekable();
+ let mut current_message = messages.next();
+ while let Some(offset) = offsets.next() {
+ // Locate the message that contains the offset.
+ while current_message.as_ref().map_or(false, |message| {
+ !message.offset_range.contains(&offset) && messages.peek().is_some()
+ }) {
+ current_message = messages.next();
+ }
+ let Some(message) = current_message.as_ref() else {
+ break;
+ };
+
+ // Skip offsets that are in the same message.
+ while offsets.peek().map_or(false, |offset| {
+ message.offset_range.contains(offset) || messages.peek().is_none()
+ }) {
+ offsets.next();
+ }
+
+ result.push(message.clone());
+ }
+ result
+ }
+
+ pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
+ let buffer = self.buffer.read(cx);
+ let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
+ iter::from_fn(move || {
+ if let Some((start_ix, message_anchor)) = message_anchors.next() {
+ let metadata = self.messages_metadata.get(&message_anchor.id)?;
+ let message_start = message_anchor.start.to_offset(buffer);
+ let mut message_end = None;
+ let mut end_ix = start_ix;
+ while let Some((_, next_message)) = message_anchors.peek() {
+ if next_message.start.is_valid(buffer) {
+ message_end = Some(next_message.start);
+ break;
+ } else {
+ end_ix += 1;
+ message_anchors.next();
+ }
+ }
+ let message_end = message_end
+ .unwrap_or(language::Anchor::MAX)
+ .to_offset(buffer);
+
+ return Some(Message {
+ index_range: start_ix..end_ix,
+ offset_range: message_start..message_end,
+ id: message_anchor.id,
+ anchor: message_anchor.start,
+ role: metadata.role,
+ status: metadata.status.clone(),
+ });
+ }
+ None
+ })
+ }
+
+ pub fn save(
+ &mut self,
+ debounce: Option<Duration>,
+ fs: Arc<dyn Fs>,
+ cx: &mut ModelContext<Context>,
+ ) {
+ if self.replica_id() != ReplicaId::default() {
+ // Prevent saving a remote context for now.
+ return;
+ }
+
+ self.pending_save = cx.spawn(|this, mut cx| async move {
+ if let Some(debounce) = debounce {
+ cx.background_executor().timer(debounce).await;
+ }
+
+ let (old_path, summary) = this.read_with(&cx, |this, _| {
+ let path = this.path.clone();
+ let summary = if let Some(summary) = this.summary.as_ref() {
+ if summary.done {
+ Some(summary.text.clone())
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+ (path, summary)
+ })?;
+
+ if let Some(summary) = summary {
+ let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
+ let path = if let Some(old_path) = old_path {
+ old_path
+ } else {
+ let mut discriminant = 1;
+ let mut new_path;
+ loop {
+ new_path = contexts_dir().join(&format!(
+ "{} - {}.zed.json",
+ summary.trim(),
+ discriminant
+ ));
+ if fs.is_file(&new_path).await {
+ discriminant += 1;
+ } else {
+ break;
+ }
+ }
+ new_path
+ };
+
+ fs.create_dir(contexts_dir().as_ref()).await?;
+ fs.atomic_write(path.clone(), serde_json::to_string(&context).unwrap())
+ .await?;
+ this.update(&mut cx, |this, _| this.path = Some(path))?;
+ }
+
+ Ok(())
+ });
+ }
+}
+
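+/// A pair of version vectors: one for the context's own operations and one for its underlying buffer.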
+#[derive(Debug, Default)]
+pub struct ContextVersion {
+ context: clock::Global,
+ buffer: clock::Global,
+}
+
+impl ContextVersion {
+ pub fn from_proto(proto: &proto::ContextVersion) -> Self {
+ Self {
+ context: language::proto::deserialize_version(&proto.context_version),
+ buffer: language::proto::deserialize_version(&proto.buffer_version),
+ }
+ }
+
+ pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion {
+ proto::ContextVersion {
+ context_id: context_id.to_proto(),
+ context_version: language::proto::serialize_version(&self.context),
+ buffer_version: language::proto::serialize_version(&self.buffer),
+ }
+ }
+}
+
+#[derive(Debug)]
+enum EditParsingState {
+ None,
+ InOldText {
+ path: PathBuf,
+ start_offset: usize,
+ old_text_start_offset: usize,
+ },
+ InNewText {
+ path: PathBuf,
+ start_offset: usize,
+ old_text_range: Range<usize>,
+ new_text_start_offset: usize,
+ },
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub struct EditSuggestion {
+ pub source_range: Range<language::Anchor>,
+ pub full_path: PathBuf,
+}
+
+pub struct ParsedEditSuggestion {
+ pub path: PathBuf,
+ pub outer_range: Range<usize>,
+ pub old_text_range: Range<usize>,
+ pub new_text_range: Range<usize>,
+}
+
+pub fn parse_next_edit_suggestion(lines: &mut rope::Lines) -> Option<ParsedEditSuggestion> {
+ let mut state = EditParsingState::None;
+ loop {
+ let offset = lines.offset();
+ let message_line = lines.next()?;
+ match state {
+ EditParsingState::None => {
+ if let Some(rest) = message_line.strip_prefix("```edit ") {
+ let path = rest.trim();
+ if !path.is_empty() {
+ state = EditParsingState::InOldText {
+ path: PathBuf::from(path),
+ start_offset: offset,
+ old_text_start_offset: lines.offset(),
+ };
+ }
+ }
+ }
+ EditParsingState::InOldText {
+ path,
+ start_offset,
+ old_text_start_offset,
+ } => {
+ if message_line == "---" {
+ state = EditParsingState::InNewText {
+ path,
+ start_offset,
+ old_text_range: old_text_start_offset..offset,
+ new_text_start_offset: lines.offset(),
+ };
+ } else {
+ state = EditParsingState::InOldText {
+ path,
+ start_offset,
+ old_text_start_offset,
+ };
+ }
+ }
+ EditParsingState::InNewText {
+ path,
+ start_offset,
+ old_text_range,
+ new_text_start_offset,
+ } => {
+ if message_line == "```" {
+ return Some(ParsedEditSuggestion {
+ path,
+ outer_range: start_offset..offset + "```".len(),
+ old_text_range,
+ new_text_range: new_text_start_offset..offset,
+ });
+ } else {
+ state = EditParsingState::InNewText {
+ path,
+ start_offset,
+ old_text_range,
+ new_text_start_offset,
+ };
+ }
+ }
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct PendingSlashCommand {
+ pub name: String,
+ pub argument: Option<String>,
+ pub status: PendingSlashCommandStatus,
+ pub source_range: Range<language::Anchor>,
+}
+
+#[derive(Clone)]
+pub enum PendingSlashCommandStatus {
+ Idle,
+ Running { _task: Shared<Task<()>> },
+ Error(String),
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct SavedMessage {
+ pub id: MessageId,
+ pub start: usize,
+ pub metadata: MessageMetadata,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct SavedContext {
+ pub id: Option<ContextId>,
+ pub zed: String,
+ pub version: String,
+ pub text: String,
+ pub messages: Vec<SavedMessage>,
+ pub summary: String,
+ pub slash_command_output_sections:
+ Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
+}
+
+impl SavedContext {
+ pub const VERSION: &'static str = "0.4.0";
+
+ pub fn from_json(json: &str) -> Result<Self> {
+ let saved_context_json = serde_json::from_str::<serde_json::Value>(json)?;
+ match saved_context_json
+ .get("version")
+ .ok_or_else(|| anyhow!("version not found"))?
+ {
+ serde_json::Value::String(version) => match version.as_str() {
+ SavedContext::VERSION => {
+ Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
+ }
+ SavedContextV0_3_0::VERSION => {
+ let saved_context =
+ serde_json::from_value::<SavedContextV0_3_0>(saved_context_json)?;
+ Ok(saved_context.upgrade())
+ }
+ SavedContextV0_2_0::VERSION => {
+ let saved_context =
+ serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
+ Ok(saved_context.upgrade())
+ }
+ SavedContextV0_1_0::VERSION => {
+ let saved_context =
+ serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
+ Ok(saved_context.upgrade())
+ }
+ _ => Err(anyhow!("unrecognized saved context version: {}", version)),
+ },
+ _ => Err(anyhow!("version not found on saved context")),
+ }
+ }
+
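+ /// Converts the saved context into the insert, update, and summary operations needed to reconstruct it on top of the already-restored buffer text.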
+ fn into_ops(
+ self,
+ buffer: &Model<Buffer>,
+ cx: &mut ModelContext<Context>,
+ ) -> Vec<ContextOperation> {
+ let mut operations = Vec::new();
+ let mut version = clock::Global::new();
+ let mut next_timestamp = clock::Lamport::new(ReplicaId::default());
+
+ let mut first_message_metadata = None;
+ for message in self.messages {
+ if message.id == MessageId(clock::Lamport::default()) {
+ first_message_metadata = Some(message.metadata);
+ } else {
+ operations.push(ContextOperation::InsertMessage {
+ anchor: MessageAnchor {
+ id: message.id,
+ start: buffer.read(cx).anchor_before(message.start),
+ },
+ metadata: MessageMetadata {
+ role: message.metadata.role,
+ status: message.metadata.status,
+ timestamp: message.metadata.timestamp,
+ },
+ version: version.clone(),
+ });
+ version.observe(message.id.0);
+ next_timestamp.observe(message.id.0);
+ }
+ }
+
+ if let Some(metadata) = first_message_metadata {
+ let timestamp = next_timestamp.tick();
+ operations.push(ContextOperation::UpdateMessage {
+ message_id: MessageId(clock::Lamport::default()),
+ metadata: MessageMetadata {
+ role: metadata.role,
+ status: metadata.status,
+ timestamp,
+ },
+ version: version.clone(),
+ });
+ version.observe(timestamp);
+ }
+
+ let timestamp = next_timestamp.tick();
+ operations.push(ContextOperation::SlashCommandFinished {
+ id: SlashCommandId(timestamp),
+ output_range: language::Anchor::MIN..language::Anchor::MAX,
+ sections: self
+ .slash_command_output_sections
+ .into_iter()
+ .map(|section| {
+ let buffer = buffer.read(cx);
+ SlashCommandOutputSection {
+ range: buffer.anchor_after(section.range.start)
+ ..buffer.anchor_before(section.range.end),
+ icon: section.icon,
+ label: section.label,
+ }
+ })
+ .collect(),
+ version: version.clone(),
+ });
+ version.observe(timestamp);
+
+ let timestamp = next_timestamp.tick();
+ operations.push(ContextOperation::UpdateSummary {
+ summary: ContextSummary {
+ text: self.summary,
+ done: true,
+ timestamp,
+ },
+ version: version.clone(),
+ });
+ version.observe(timestamp);
+
+ operations
+ }
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+struct SavedMessageIdPreV0_4_0(usize);
+
+#[derive(Serialize, Deserialize)]
+struct SavedMessagePreV0_4_0 {
+ id: SavedMessageIdPreV0_4_0,
+ start: usize,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+struct SavedMessageMetadataPreV0_4_0 {
+ role: Role,
+ status: MessageStatus,
+}
+
+#[derive(Serialize, Deserialize)]
+struct SavedContextV0_3_0 {
+ id: Option<ContextId>,
+ zed: String,
+ version: String,
+ text: String,
+ messages: Vec<SavedMessagePreV0_4_0>,
+ message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
+ summary: String,
+ slash_command_output_sections: Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
+}
+
+impl SavedContextV0_3_0 {
+ const VERSION: &'static str = "0.3.0";
+
+ fn upgrade(self) -> SavedContext {
+ SavedContext {
+ id: self.id,
+ zed: self.zed,
+ version: SavedContext::VERSION.into(),
+ text: self.text,
+ messages: self
+ .messages
+ .into_iter()
+ .filter_map(|message| {
+ let metadata = self.message_metadata.get(&message.id)?;
+ let timestamp = clock::Lamport {
+ replica_id: ReplicaId::default(),
+ value: message.id.0 as u32,
+ };
+ Some(SavedMessage {
+ id: MessageId(timestamp),
+ start: message.start,
+ metadata: MessageMetadata {
+ role: metadata.role,
+ status: metadata.status.clone(),
+ timestamp,
+ },
+ })
+ })
+ .collect(),
+ summary: self.summary,
+ slash_command_output_sections: self.slash_command_output_sections,
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize)]
+struct SavedContextV0_2_0 {
+ id: Option<ContextId>,
+ zed: String,
+ version: String,
+ text: String,
+ messages: Vec<SavedMessagePreV0_4_0>,
+ message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
+ summary: String,
+}
+
+impl SavedContextV0_2_0 {
+ const VERSION: &'static str = "0.2.0";
+
+ fn upgrade(self) -> SavedContext {
+ SavedContextV0_3_0 {
+ id: self.id,
+ zed: self.zed,
+ version: SavedContextV0_3_0::VERSION.to_string(),
+ text: self.text,
+ messages: self.messages,
+ message_metadata: self.message_metadata,
+ summary: self.summary,
+ slash_command_output_sections: Vec::new(),
+ }
+ .upgrade()
+ }
+}
+
+#[derive(Serialize, Deserialize)]
+struct SavedContextV0_1_0 {
+ id: Option<ContextId>,
+ zed: String,
+ version: String,
+ text: String,
+ messages: Vec<SavedMessagePreV0_4_0>,
+ message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
+ summary: String,
+ api_url: Option<String>,
+ model: OpenAiModel,
+}
+
+impl SavedContextV0_1_0 {
+ const VERSION: &'static str = "0.1.0";
+
+ fn upgrade(self) -> SavedContext {
+ SavedContextV0_2_0 {
+ id: self.id,
+ zed: self.zed,
+ version: SavedContextV0_2_0::VERSION.to_string(),
+ text: self.text,
+ messages: self.messages,
+ message_metadata: self.message_metadata,
+ summary: self.summary,
+ }
+ .upgrade()
+ }
+}
+
+#[derive(Clone)]
+pub struct SavedContextMetadata {
+ pub title: String,
+ pub path: PathBuf,
+ pub mtime: chrono::DateTime<chrono::Local>,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ assistant_panel,
+ slash_command::{active_command, file_command},
+ FakeCompletionProvider, MessageId,
+ };
+ use assistant_slash_command::{ArgumentCompletion, SlashCommand};
+ use fs::FakeFs;
+ use gpui::{AppContext, TestAppContext, WeakView};
+ use language::LspAdapterDelegate;
+ use parking_lot::Mutex;
+ use project::Project;
+ use rand::prelude::*;
+ use rope::Rope;
+ use serde_json::json;
+ use settings::SettingsStore;
+ use std::{cell::RefCell, env, path::Path, rc::Rc, sync::atomic::AtomicBool};
+ use text::network::Network;
+ use ui::WindowContext;
+ use unindent::Unindent;
+ use util::{test::marked_text_ranges, RandomCharIter};
+ use workspace::Workspace;
+
+ #[gpui::test]
+ fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+ let settings_store = SettingsStore::test(cx);
+ FakeCompletionProvider::setup_test(cx);
+ cx.set_global(settings_store);
+ assistant_panel::init(cx);
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
+
+ let context = cx.new_model(|cx| Context::local(registry, None, cx));
+ let buffer = context.read(cx).buffer.clone();
+
+ let message_1 = context.read(cx).message_anchors[0].clone();
+ assert_eq!(
+ messages(&context, cx),
+ vec![(message_1.id, Role::User, 0..0)]
+ );
+
+ let message_2 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..1),
+ (message_2.id, Role::Assistant, 1..1)
+ ]
+ );
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..3)
+ ]
+ );
+
+ let message_3 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_3.id, Role::User, 4..4)
+ ]
+ );
+
+ let message_4 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_4.id, Role::User, 4..5),
+ (message_3.id, Role::User, 5..5),
+ ]
+ );
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_4.id, Role::User, 4..6),
+ (message_3.id, Role::User, 6..7),
+ ]
+ );
+
+ // Deleting across message boundaries merges the messages.
+ buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..3),
+ (message_3.id, Role::User, 3..4),
+ ]
+ );
+
+ // Undoing the deletion should also undo the merge.
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_4.id, Role::User, 4..6),
+ (message_3.id, Role::User, 6..7),
+ ]
+ );
+
+ // Redoing the deletion should also redo the merge.
+ buffer.update(cx, |buffer, cx| buffer.redo(cx));
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..3),
+ (message_3.id, Role::User, 3..4),
+ ]
+ );
+
+ // Ensure we can still insert after a merged message.
+ let message_5 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..3),
+ (message_5.id, Role::System, 3..4),
+ (message_3.id, Role::User, 4..5)
+ ]
+ );
+ }
+
+ #[gpui::test]
+ fn test_message_splitting(cx: &mut AppContext) {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ FakeCompletionProvider::setup_test(cx);
+ assistant_panel::init(cx);
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
+
+ let context = cx.new_model(|cx| Context::local(registry, None, cx));
+ let buffer = context.read(cx).buffer.clone();
+
+ let message_1 = context.read(cx).message_anchors[0].clone();
+ assert_eq!(
+ messages(&context, cx),
+ vec![(message_1.id, Role::User, 0..0)]
+ );
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
+ });
+
+ let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
+ let message_2 = message_2.unwrap();
+
+ // We recycle newlines in the middle of a split message
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_2.id, Role::User, 4..16),
+ ]
+ );
+
+ let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
+ let message_3 = message_3.unwrap();
+
+ // We don't recycle newlines at the end of a split message
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..17),
+ ]
+ );
+
+ let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
+ let message_4 = message_4.unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..9),
+ (message_4.id, Role::User, 9..17),
+ ]
+ );
+
+ let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
+ let message_5 = message_5.unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..9),
+ (message_4.id, Role::User, 9..10),
+ (message_5.id, Role::User, 10..18),
+ ]
+ );
+
+ let (message_6, message_7) =
+ context.update(cx, |context, cx| context.split_message(14..16, cx));
+ let message_6 = message_6.unwrap();
+ let message_7 = message_7.unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..9),
+ (message_4.id, Role::User, 9..10),
+ (message_5.id, Role::User, 10..14),
+ (message_6.id, Role::User, 14..17),
+ (message_7.id, Role::User, 17..19),
+ ]
+ );
+ }
+
+ #[gpui::test]
+ fn test_messages_for_offsets(cx: &mut AppContext) {
+ let settings_store = SettingsStore::test(cx);
+ FakeCompletionProvider::setup_test(cx);
+ cx.set_global(settings_store);
+ assistant_panel::init(cx);
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
+ let context = cx.new_model(|cx| Context::local(registry, None, cx));
+ let buffer = context.read(cx).buffer.clone();
+
+ let message_1 = context.read(cx).message_anchors[0].clone();
+ assert_eq!(
+ messages(&context, cx),
+ vec![(message_1.id, Role::User, 0..0)]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
+ let message_2 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
+
+ let message_3 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
+
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_2.id, Role::User, 4..8),
+ (message_3.id, Role::User, 8..11)
+ ]
+ );
+
+ assert_eq!(
+ message_ids_for_offsets(&context, &[0, 4, 9], cx),
+ [message_1.id, message_2.id, message_3.id]
+ );
+ assert_eq!(
+ message_ids_for_offsets(&context, &[0, 1, 11], cx),
+ [message_1.id, message_3.id]
+ );
+
+ let message_4 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_2.id, Role::User, 4..8),
+ (message_3.id, Role::User, 8..12),
+ (message_4.id, Role::User, 12..12)
+ ]
+ );
+ assert_eq!(
+ message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
+ [message_1.id, message_2.id, message_3.id, message_4.id]
+ );
+
+ fn message_ids_for_offsets(
+ context: &Model<Context>,
+ offsets: &[usize],
+ cx: &AppContext,
+ ) -> Vec<MessageId> {
+ context
+ .read(cx)
+ .messages_for_offsets(offsets.iter().copied(), cx)
+ .into_iter()
+ .map(|message| message.id)
+ .collect()
+ }
+ }
+
+ #[gpui::test]
+ async fn test_slash_commands(cx: &mut TestAppContext) {
+ let settings_store = cx.update(SettingsStore::test);
+ cx.set_global(settings_store);
+ cx.update(FakeCompletionProvider::setup_test);
+ cx.update(Project::init_settings);
+ cx.update(assistant_panel::init);
+ let fs = FakeFs::new(cx.background_executor.clone());
+
+ fs.insert_tree(
+ "/test",
+ json!({
+ "src": {
+ "lib.rs": "fn one() -> usize { 1 }",
+ "main.rs": "
+ use crate::one;
+ fn main() { one(); }
+ ".unindent(),
+ }
+ }),
+ )
+ .await;
+
+ let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
+ slash_command_registry.register_command(file_command::FileSlashCommand, false);
+ slash_command_registry.register_command(active_command::ActiveSlashCommand, false);
+
+ let registry = Arc::new(LanguageRegistry::test(cx.executor()));
+ let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
+
+ let output_ranges = Rc::new(RefCell::new(HashSet::default()));
+ context.update(cx, |_, cx| {
+ cx.subscribe(&context, {
+ let ranges = output_ranges.clone();
+ move |_, _, event, _| match event {
+ ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
+ for range in removed {
+ ranges.borrow_mut().remove(range);
+ }
+ for command in updated {
+ ranges.borrow_mut().insert(command.source_range.clone());
+ }
+ }
+ _ => {}
+ }
+ })
+ .detach();
+ });
+
+ let buffer = context.read_with(cx, |context, _| context.buffer.clone());
+
+ // Insert a slash command.
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
+ });
+ assert_text_and_output_ranges(
+ &buffer,
+ &output_ranges.borrow(),
+ "
+ «/file src/lib.rs»
+ "
+ .unindent()
+ .trim_end(),
+ cx,
+ );
+
+ // Edit the argument of the slash command.
+ buffer.update(cx, |buffer, cx| {
+ let edit_offset = buffer.text().find("lib.rs").unwrap();
+ buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
+ });
+ assert_text_and_output_ranges(
+ &buffer,
+ &output_ranges.borrow(),
+ "
+ «/file src/main.rs»
+ "
+ .unindent()
+ .trim_end(),
+ cx,
+ );
+
+ // Edit the name of the slash command, using one that doesn't exist.
+ buffer.update(cx, |buffer, cx| {
+ let edit_offset = buffer.text().find("/file").unwrap();
+ buffer.edit(
+ [(edit_offset..edit_offset + "/file".len(), "/unknown")],
+ None,
+ cx,
+ );
+ });
+ assert_text_and_output_ranges(
+ &buffer,
+ &output_ranges.borrow(),
+ "
+ /unknown src/main.rs
+ "
+ .unindent()
+ .trim_end(),
+ cx,
+ );
+
+ #[track_caller]
+ fn assert_text_and_output_ranges(
+ buffer: &Model<Buffer>,
+ ranges: &HashSet<Range<language::Anchor>>,
+ expected_marked_text: &str,
+ cx: &mut TestAppContext,
+ ) {
+ let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
+ let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
+ let mut ranges = ranges
+ .iter()
+ .map(|range| range.to_offset(buffer))
+ .collect::<Vec<_>>();
+ ranges.sort_by_key(|a| a.start);
+ (buffer.text(), ranges)
+ });
+
+ assert_eq!(actual_text, expected_text);
+ assert_eq!(actual_ranges, expected_ranges);
+ }
+ }
+
+ #[test]
+ fn test_parse_next_edit_suggestion() {
+ let text = "
+ some output:
+
+ ```edit src/foo.rs
+ let a = 1;
+ let b = 2;
+ ---
+ let w = 1;
+ let x = 2;
+ let y = 3;
+ let z = 4;
+ ```
+
+ some more output:
+
+ ```edit src/foo.rs
+ let c = 1;
+ ---
+ ```
+
+ and the conclusion.
+ "
+ .unindent();
+
+ let rope = Rope::from(text.as_str());
+ let mut lines = rope.chunks().lines();
+ let mut suggestions = vec![];
+ while let Some(suggestion) = parse_next_edit_suggestion(&mut lines) {
+ suggestions.push((
+ suggestion.path.clone(),
+ text[suggestion.old_text_range].to_string(),
+ text[suggestion.new_text_range].to_string(),
+ ));
+ }
+
+ assert_eq!(
+ suggestions,
+ vec![
+ (
+ Path::new("src/foo.rs").into(),
+ [
+ " let a = 1;", //
+ " let b = 2;",
+ "",
+ ]
+ .join("\n"),
+ [
+ " let w = 1;",
+ " let x = 2;",
+ " let y = 3;",
+ " let z = 4;",
+ "",
+ ]
+ .join("\n"),
+ ),
+ (
+ Path::new("src/foo.rs").into(),
+ [
+ " let c = 1;", //
+ "",
+ ]
+ .join("\n"),
+ String::new(),
+ )
+ ]
+ );
+ }
+
+ #[gpui::test]
+ async fn test_serialization(cx: &mut TestAppContext) {
+ let settings_store = cx.update(SettingsStore::test);
+ cx.set_global(settings_store);
+ cx.update(FakeCompletionProvider::setup_test);
+ cx.update(assistant_panel::init);
+ let registry = Arc::new(LanguageRegistry::test(cx.executor()));
+ let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
+ let buffer = context.read_with(cx, |context, _| context.buffer.clone());
+ let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
+ let message_1 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ let message_2 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
+ buffer.finalize_last_transaction();
+ });
+ let _message_3 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
+ assert_eq!(
+ cx.read(|cx| messages(&context, cx)),
+ [
+ (message_0, Role::User, 0..2),
+ (message_1.id, Role::Assistant, 2..6),
+ (message_2.id, Role::System, 6..6),
+ ]
+ );
+
+ let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
+ let deserialized_context = cx.new_model(|cx| {
+ Context::deserialize(
+ serialized_context,
+ Default::default(),
+ registry.clone(),
+ None,
+ cx,
+ )
+ });
+ let deserialized_buffer =
+ deserialized_context.read_with(cx, |context, _| context.buffer.clone());
+ assert_eq!(
+ deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
+ "a\nb\nc\n"
+ );
+ assert_eq!(
+ cx.read(|cx| messages(&deserialized_context, cx)),
+ [
+ (message_0, Role::User, 0..2),
+ (message_1.id, Role::Assistant, 2..6),
+ (message_2.id, Role::System, 6..6),
+ ]
+ );
+ }
+
+ #[gpui::test(iterations = 100)]
+ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
+ let min_peers = env::var("MIN_PEERS")
+ .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
+ .unwrap_or(2);
+ let max_peers = env::var("MAX_PEERS")
+ .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
+ .unwrap_or(5);
+ let operations = env::var("OPERATIONS")
+ .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
+ .unwrap_or(50);
+
+ let settings_store = cx.update(SettingsStore::test);
+ cx.set_global(settings_store);
+ cx.update(FakeCompletionProvider::setup_test);
+ cx.update(assistant_panel::init);
+ let slash_commands = cx.update(SlashCommandRegistry::default_global);
+ slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
+ slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
+ slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
+
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
+ let network = Arc::new(Mutex::new(Network::new(rng.clone())));
+ let mut contexts = Vec::new();
+
+ let num_peers = rng.gen_range(min_peers..=max_peers);
+ let context_id = ContextId::new();
+ for i in 0..num_peers {
+ let context = cx.new_model(|cx| {
+ Context::new(
+ context_id.clone(),
+ i as ReplicaId,
+ language::Capability::ReadWrite,
+ registry.clone(),
+ None,
+ cx,
+ )
+ });
+
+ cx.update(|cx| {
+ cx.subscribe(&context, {
+ let network = network.clone();
+ move |_, event, _| {
+ if let ContextEvent::Operation(op) = event {
+ network
+ .lock()
+ .broadcast(i as ReplicaId, vec![op.to_proto()]);
+ }
+ }
+ })
+ .detach();
+ });
+
+ contexts.push(context);
+ network.lock().add_peer(i as ReplicaId);
+ }
+
+ let mut mutation_count = operations;
+
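+ // Keep mutating until all operations have been performed, every peer has reconnected, and the network has delivered all in-flight messages, so the replicas can converge.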
+ while mutation_count > 0
+ || !network.lock().is_idle()
+ || network.lock().contains_disconnected_peers()
+ {
+ let context_index = rng.gen_range(0..contexts.len());
+ let context = &contexts[context_index];
+
+ match rng.gen_range(0..100) {
+ 0..=29 if mutation_count > 0 => {
+ log::info!("Context {}: edit buffer", context_index);
+ context.update(cx, |context, cx| {
+ context
+ .buffer
+ .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
+ });
+ mutation_count -= 1;
+ }
+ 30..=44 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
+ log::info!("Context {}: split message at {:?}", context_index, range);
+ context.split_message(range, cx);
+ });
+ mutation_count -= 1;
+ }
+ 45..=59 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ if let Some(message) = context.messages(cx).choose(&mut rng) {
+ let role = *[Role::User, Role::Assistant, Role::System]
+ .choose(&mut rng)
+ .unwrap();
+ log::info!(
+ "Context {}: insert message after {:?} with {:?}",
+ context_index,
+ message.id,
+ role
+ );
+ context.insert_message_after(message.id, role, MessageStatus::Done, cx);
+ }
+ });
+ mutation_count -= 1;
+ }
+ 60..=74 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ let command_text = "/".to_string()
+ + slash_commands
+ .command_names()
+ .choose(&mut rng)
+ .unwrap()
+ .clone()
+ .as_ref();
+
+ let command_range = context.buffer.update(cx, |buffer, cx| {
+ let offset = buffer.random_byte_range(0, &mut rng).start;
+ buffer.edit(
+ [(offset..offset, format!("\n{}\n", command_text))],
+ None,
+ cx,
+ );
+ offset + 1..offset + 1 + command_text.len()
+ });
+
+ let output_len = rng.gen_range(1..=10);
+ let output_text = RandomCharIter::new(&mut rng)
+ .filter(|c| *c != '\r')
+ .take(output_len)
+ .collect::<String>();
+
+ let num_sections = rng.gen_range(0..=3);
+ let mut sections = Vec::with_capacity(num_sections);
+ for _ in 0..num_sections {
+ let section_start = rng.gen_range(0..output_len);
+ let section_end = rng.gen_range(section_start..=output_len);
+ sections.push(SlashCommandOutputSection {
+ range: section_start..section_end,
+ icon: ui::IconName::Ai,
+ label: "section".into(),
+ });
+ }
+
+ log::info!(
+ "Context {}: insert slash command output at {:?} with {:?}",
+ context_index,
+ command_range,
+ sections
+ );
+
+ let command_range =
+ context.buffer.read(cx).anchor_after(command_range.start)
+ ..context.buffer.read(cx).anchor_after(command_range.end);
+ context.insert_command_output(
+ command_range,
+ Task::ready(Ok(SlashCommandOutput {
+ text: output_text,
+ sections,
+ run_commands_in_text: false,
+ })),
+ true,
+ cx,
+ );
+ });
+ cx.run_until_parked();
+ mutation_count -= 1;
+ }
+ 75..=84 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ if let Some(message) = context.messages(cx).choose(&mut rng) {
+ let new_status = match rng.gen_range(0..3) {
+ 0 => MessageStatus::Done,
+ 1 => MessageStatus::Pending,
+ _ => MessageStatus::Error(SharedString::from("Random error")),
+ };
+ log::info!(
+ "Context {}: update message {:?} status to {:?}",
+ context_index,
+ message.id,
+ new_status
+ );
+ context.update_metadata(message.id, cx, |metadata| {
+ metadata.status = new_status;
+ });
+ }
+ });
+ mutation_count -= 1;
+ }
+ _ => {
+ let replica_id = context_index as ReplicaId;
+ if network.lock().is_disconnected(replica_id) {
+ network.lock().reconnect_peer(replica_id, 0);
+
+ let (ops_to_send, ops_to_receive) = cx.read(|cx| {
+ let host_context = &contexts[0].read(cx);
+ let guest_context = context.read(cx);
+ (
+ guest_context.serialize_ops(&host_context.version(cx), cx),
+ host_context.serialize_ops(&guest_context.version(cx), cx),
+ )
+ });
+ let ops_to_send = ops_to_send.await;
+ let ops_to_receive = ops_to_receive
+ .await
+ .into_iter()
+ .map(ContextOperation::from_proto)
+ .collect::<Result<Vec<_>>>()
+ .unwrap();
+ log::info!(
+ "Context {}: reconnecting. Sent {} operations, received {} operations",
+ context_index,
+ ops_to_send.len(),
+ ops_to_receive.len()
+ );
+
+ network.lock().broadcast(replica_id, ops_to_send);
+ context
+ .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
+ .unwrap();
+ } else if rng.gen_bool(0.1) && replica_id != 0 {
+ log::info!("Context {}: disconnecting", context_index);
+ network.lock().disconnect_peer(replica_id);
+ } else if network.lock().has_unreceived(replica_id) {
+ log::info!("Context {}: applying operations", context_index);
+ let ops = network.lock().receive(replica_id);
+ let ops = ops
+ .into_iter()
+ .map(ContextOperation::from_proto)
+ .collect::<Result<Vec<_>>>()
+ .unwrap();
+ context
+ .update(cx, |context, cx| context.apply_ops(ops, cx))
+ .unwrap();
+ }
+ }
+ }
+ }
+
+ cx.read(|cx| {
+ let first_context = contexts[0].read(cx);
+ for context in &contexts[1..] {
+ let context = context.read(cx);
+ assert!(context.pending_ops.is_empty());
+ assert_eq!(
+ context.buffer.read(cx).text(),
+ first_context.buffer.read(cx).text(),
+ "Context {} text != Context 0 text",
+ context.buffer.read(cx).replica_id()
+ );
+ assert_eq!(
+ context.message_anchors,
+ first_context.message_anchors,
+ "Context {} messages != Context 0 messages",
+ context.buffer.read(cx).replica_id()
+ );
+ assert_eq!(
+ context.messages_metadata,
+ first_context.messages_metadata,
+ "Context {} message metadata != Context 0 message metadata",
+ context.buffer.read(cx).replica_id()
+ );
+ assert_eq!(
+ context.slash_command_output_sections,
+ first_context.slash_command_output_sections,
+ "Context {} slash command output sections != Context 0 slash command output sections",
+ context.buffer.read(cx).replica_id()
+ );
+ }
+ });
+ }
+
+ fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
+ context
+ .read(cx)
+ .messages(cx)
+ .map(|message| (message.id, message.role, message.offset_range))
+ .collect()
+ }
+
+ #[derive(Clone)]
+ struct FakeSlashCommand(String);
+
+ impl SlashCommand for FakeSlashCommand {
+ fn name(&self) -> String {
+ self.0.clone()
+ }
+
+ fn description(&self) -> String {
+ format!("Fake slash command: {}", self.0)
+ }
+
+ fn menu_text(&self) -> String {
+ format!("Run fake command: {}", self.0)
+ }
+
+ fn complete_argument(
+ self: Arc<Self>,
+ _query: String,
+ _cancel: Arc<AtomicBool>,
+ _workspace: Option<WeakView<Workspace>>,
+ _cx: &mut AppContext,
+ ) -> Task<Result<Vec<ArgumentCompletion>>> {
+ Task::ready(Ok(vec![]))
+ }
+
+ fn requires_argument(&self) -> bool {
+ false
+ }
+
+ fn run(
+ self: Arc<Self>,
+ _argument: Option<&str>,
+ _workspace: WeakView<Workspace>,
+ _delegate: Arc<dyn LspAdapterDelegate>,
+ _cx: &mut WindowContext,
+ ) -> Task<Result<SlashCommandOutput>> {
+ Task::ready(Ok(SlashCommandOutput {
+ text: format!("Executed fake command: {}", self.0),
+ sections: vec![],
+ run_commands_in_text: false,
+ }))
+ }
+ }
+}
@@ -1,97 +1,117 @@
-use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
-use anyhow::{anyhow, Result};
-use assistant_slash_command::SlashCommandOutputSection;
-use collections::HashMap;
+use crate::{
+ Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
+ SavedContextMetadata,
+};
+use anyhow::{anyhow, Context as _, Result};
+use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
+use clock::ReplicaId;
use fs::Fs;
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
-use gpui::{AppContext, Model, ModelContext, Task};
+use gpui::{AppContext, AsyncAppContext, Context as _, Model, ModelContext, Task, WeakModel};
+use language::LanguageRegistry;
use paths::contexts_dir;
+use project::Project;
use regex::Regex;
-use serde::{Deserialize, Serialize};
-use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc, time::Duration};
-use ui::Context;
+use std::{
+ cmp::Reverse,
+ ffi::OsStr,
+ mem,
+ path::{Path, PathBuf},
+ sync::Arc,
+ time::Duration,
+};
use util::{ResultExt, TryFutureExt};
-#[derive(Serialize, Deserialize)]
-pub struct SavedMessage {
- pub id: MessageId,
- pub start: usize,
-}
-
-#[derive(Serialize, Deserialize)]
-pub struct SavedContext {
- pub id: Option<String>,
- pub zed: String,
- pub version: String,
- pub text: String,
- pub messages: Vec<SavedMessage>,
- pub message_metadata: HashMap<MessageId, MessageMetadata>,
- pub summary: String,
- pub slash_command_output_sections: Vec<SlashCommandOutputSection<usize>>,
-}
-
-impl SavedContext {
- pub const VERSION: &'static str = "0.3.0";
-}
-
-#[derive(Serialize, Deserialize)]
-pub struct SavedContextV0_2_0 {
- pub id: Option<String>,
- pub zed: String,
- pub version: String,
- pub text: String,
- pub messages: Vec<SavedMessage>,
- pub message_metadata: HashMap<MessageId, MessageMetadata>,
- pub summary: String,
-}
-
-#[derive(Serialize, Deserialize)]
-struct SavedContextV0_1_0 {
- id: Option<String>,
- zed: String,
- version: String,
- text: String,
- messages: Vec<SavedMessage>,
- message_metadata: HashMap<MessageId, MessageMetadata>,
- summary: String,
- api_url: Option<String>,
- model: OpenAiModel,
+pub fn init(client: &Arc<Client>) {
+ client.add_model_message_handler(ContextStore::handle_advertise_contexts);
+ client.add_model_request_handler(ContextStore::handle_open_context);
+ client.add_model_message_handler(ContextStore::handle_update_context);
+ client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
}
#[derive(Clone)]
-pub struct SavedContextMetadata {
- pub title: String,
- pub path: PathBuf,
- pub mtime: chrono::DateTime<chrono::Local>,
+pub struct RemoteContextMetadata {
+ pub id: ContextId,
+ pub summary: Option<String>,
}
pub struct ContextStore {
+ contexts: Vec<ContextHandle>,
contexts_metadata: Vec<SavedContextMetadata>,
+ host_contexts: Vec<RemoteContextMetadata>,
fs: Arc<dyn Fs>,
+ languages: Arc<LanguageRegistry>,
+ telemetry: Arc<Telemetry>,
_watch_updates: Task<Option<()>>,
+ client: Arc<Client>,
+ project: Model<Project>,
+ project_is_shared: bool,
+ client_subscription: Option<client::Subscription>,
+ _project_subscriptions: Vec<gpui::Subscription>,
+}
+
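+/// While the project is shared, contexts are retained strongly so collaborators can still open them; otherwise they are held weakly and can be dropped once nothing else is using them.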
+enum ContextHandle {
+ Weak(WeakModel<Context>),
+ Strong(Model<Context>),
+}
+
+impl ContextHandle {
+ fn upgrade(&self) -> Option<Model<Context>> {
+ match self {
+ ContextHandle::Weak(weak) => weak.upgrade(),
+ ContextHandle::Strong(strong) => Some(strong.clone()),
+ }
+ }
+
+ fn downgrade(&self) -> WeakModel<Context> {
+ match self {
+ ContextHandle::Weak(weak) => weak.clone(),
+ ContextHandle::Strong(strong) => strong.downgrade(),
+ }
+ }
}
impl ContextStore {
- pub fn new(fs: Arc<dyn Fs>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
+ pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
+ let fs = project.read(cx).fs().clone();
+ let languages = project.read(cx).languages().clone();
+ let telemetry = project.read(cx).client().telemetry().clone();
cx.spawn(|mut cx| async move {
const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
- let this = cx.new_model(|cx: &mut ModelContext<Self>| Self {
- contexts_metadata: Vec::new(),
- fs,
- _watch_updates: cx.spawn(|this, mut cx| {
- async move {
- while events.next().await.is_some() {
- this.update(&mut cx, |this, cx| this.reload(cx))?
- .await
- .log_err();
+ let this = cx.new_model(|cx: &mut ModelContext<Self>| {
+ let mut this = Self {
+ contexts: Vec::new(),
+ contexts_metadata: Vec::new(),
+ host_contexts: Vec::new(),
+ fs,
+ languages,
+ telemetry,
+ _watch_updates: cx.spawn(|this, mut cx| {
+ async move {
+ while events.next().await.is_some() {
+ this.update(&mut cx, |this, cx| this.reload(cx))?
+ .await
+ .log_err();
+ }
+ anyhow::Ok(())
}
- anyhow::Ok(())
- }
- .log_err()
- }),
+ .log_err()
+ }),
+ client_subscription: None,
+ _project_subscriptions: vec![
+ cx.observe(&project, Self::handle_project_changed),
+ cx.subscribe(&project, Self::handle_project_event),
+ ],
+ project_is_shared: false,
+ client: project.read(cx).client(),
+ project: project.clone(),
+ };
+ this.handle_project_changed(project, cx);
+ this.synchronize_contexts(cx);
+ this
})?;
this.update(&mut cx, |this, cx| this.reload(cx))?
.await
@@ -100,52 +120,431 @@ impl ContextStore {
})
}
- pub fn load(&self, path: PathBuf, cx: &AppContext) -> Task<Result<SavedContext>> {
- let fs = self.fs.clone();
- cx.background_executor().spawn(async move {
- let saved_context = fs.load(&path).await?;
- let saved_context_json = serde_json::from_str::<serde_json::Value>(&saved_context)?;
- match saved_context_json
- .get("version")
- .ok_or_else(|| anyhow!("version not found"))?
- {
- serde_json::Value::String(version) => match version.as_str() {
- SavedContext::VERSION => {
- Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
- }
- "0.2.0" => {
- let saved_context =
- serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
- Ok(SavedContext {
- id: saved_context.id,
- zed: saved_context.zed,
- version: saved_context.version,
- text: saved_context.text,
- messages: saved_context.messages,
- message_metadata: saved_context.message_metadata,
- summary: saved_context.summary,
- slash_command_output_sections: Vec::new(),
+ async fn handle_advertise_contexts(
+ this: Model<Self>,
+ envelope: TypedEnvelope<proto::AdvertiseContexts>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |this, cx| {
+ this.host_contexts = envelope
+ .payload
+ .contexts
+ .into_iter()
+ .map(|context| RemoteContextMetadata {
+ id: ContextId::from_proto(context.context_id),
+ summary: context.summary,
+ })
+ .collect();
+ cx.notify();
+ })
+ }
+
+ async fn handle_open_context(
+ this: Model<Self>,
+ envelope: TypedEnvelope<proto::OpenContext>,
+ mut cx: AsyncAppContext,
+ ) -> Result<proto::OpenContextResponse> {
+ let context_id = ContextId::from_proto(envelope.payload.context_id);
+ let operations = this.update(&mut cx, |this, cx| {
+ if this.project.read(cx).is_remote() {
+ return Err(anyhow!("only the host contexts can be opened"));
+ }
+
+ let context = this
+ .loaded_context_for_id(&context_id, cx)
+ .context("context not found")?;
+ if context.read(cx).replica_id() != ReplicaId::default() {
+ return Err(anyhow!("context must be opened via the host"));
+ }
+
+ anyhow::Ok(
+ context
+ .read(cx)
+ .serialize_ops(&ContextVersion::default(), cx),
+ )
+ })??;
+ let operations = operations.await;
+ Ok(proto::OpenContextResponse {
+ context: Some(proto::Context { operations }),
+ })
+ }
+
+ async fn handle_update_context(
+ this: Model<Self>,
+ envelope: TypedEnvelope<proto::UpdateContext>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |this, cx| {
+ let context_id = ContextId::from_proto(envelope.payload.context_id);
+ if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
+ let operation_proto = envelope.payload.operation.context("invalid operation")?;
+ let operation = ContextOperation::from_proto(operation_proto)?;
+ context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
+ }
+ Ok(())
+ })?
+ }
+
+ async fn handle_synchronize_contexts(
+ this: Model<Self>,
+ envelope: TypedEnvelope<proto::SynchronizeContexts>,
+ mut cx: AsyncAppContext,
+ ) -> Result<proto::SynchronizeContextsResponse> {
+ this.update(&mut cx, |this, cx| {
+ if this.project.read(cx).is_remote() {
+ return Err(anyhow!("only the host can synchronize contexts"));
+ }
+
+ let mut local_versions = Vec::new();
+ for remote_version_proto in envelope.payload.contexts {
+ let remote_version = ContextVersion::from_proto(&remote_version_proto);
+ let context_id = ContextId::from_proto(remote_version_proto.context_id);
+ if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
+ let context = context.read(cx);
+ let operations = context.serialize_ops(&remote_version, cx);
+ local_versions.push(context.version(cx).to_proto(context_id.clone()));
+ let client = this.client.clone();
+ let project_id = envelope.payload.project_id;
+ cx.background_executor()
+ .spawn(async move {
+ let operations = operations.await;
+ for operation in operations {
+ client.send(proto::UpdateContext {
+ project_id,
+ context_id: context_id.to_proto(),
+ operation: Some(operation),
+ })?;
+ }
+ anyhow::Ok(())
})
+ .detach_and_log_err(cx);
+ }
+ }
+
+ this.advertise_contexts(cx);
+
+ anyhow::Ok(proto::SynchronizeContextsResponse {
+ contexts: local_versions,
+ })
+ })?
+ }
+
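+ // When the project becomes shared, hold contexts strongly, subscribe to the project's collaboration channel, and advertise our contexts; when sharing stops, drop the subscription.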
+ fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
+ let is_shared = self.project.read(cx).is_shared();
+ let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
+ if is_shared == was_shared {
+ return;
+ }
+
+ if is_shared {
+ self.contexts.retain_mut(|context| {
+ if let Some(strong_context) = context.upgrade() {
+ *context = ContextHandle::Strong(strong_context);
+ true
+ } else {
+ false
+ }
+ });
+ let remote_id = self.project.read(cx).remote_id().unwrap();
+ self.client_subscription = self
+ .client
+ .subscribe_to_entity(remote_id)
+ .log_err()
+ .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
+ self.advertise_contexts(cx);
+ } else {
+ self.client_subscription = None;
+ }
+ }
+
+ fn handle_project_event(
+ &mut self,
+ _: Model<Project>,
+ event: &project::Event,
+ cx: &mut ModelContext<Self>,
+ ) {
+ match event {
+ project::Event::Reshared => {
+ self.advertise_contexts(cx);
+ }
+ project::Event::HostReshared | project::Event::Rejoined => {
+ self.synchronize_contexts(cx);
+ }
+ project::Event::DisconnectedFromHost => {
+ self.contexts.retain_mut(|context| {
+ if let Some(strong_context) = context.upgrade() {
+ *context = ContextHandle::Weak(context.downgrade());
+ strong_context.update(cx, |context, cx| {
+ if context.replica_id() != ReplicaId::default() {
+ context.set_capability(language::Capability::ReadOnly, cx);
+ }
+ });
+ true
+ } else {
+ false
}
- "0.1.0" => {
- let saved_context =
- serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
- Ok(SavedContext {
- id: saved_context.id,
- zed: saved_context.zed,
- version: saved_context.version,
- text: saved_context.text,
- messages: saved_context.messages,
- message_metadata: saved_context.message_metadata,
- summary: saved_context.summary,
- slash_command_output_sections: Vec::new(),
- })
+ });
+ self.host_contexts.clear();
+ cx.notify();
+ }
+ _ => {}
+ }
+ }
+
+ pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
+ let context = cx.new_model(|cx| {
+ Context::local(self.languages.clone(), Some(self.telemetry.clone()), cx)
+ });
+ self.register_context(&context, cx);
+ context
+ }
+
+ pub fn open_local_context(
+ &mut self,
+ path: PathBuf,
+ cx: &ModelContext<Self>,
+ ) -> Task<Result<Model<Context>>> {
+ if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
+ return Task::ready(Ok(existing_context));
+ }
+
+ let fs = self.fs.clone();
+ let languages = self.languages.clone();
+ let telemetry = self.telemetry.clone();
+ let load = cx.background_executor().spawn({
+ let path = path.clone();
+ async move {
+ let saved_context = fs.load(&path).await?;
+ SavedContext::from_json(&saved_context)
+ }
+ });
+
+ cx.spawn(|this, mut cx| async move {
+ let saved_context = load.await?;
+ let context = cx.new_model(|cx| {
+ Context::deserialize(saved_context, path.clone(), languages, Some(telemetry), cx)
+ })?;
+ this.update(&mut cx, |this, cx| {
+ if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
+ existing_context
+ } else {
+ this.register_context(&context, cx);
+ context
+ }
+ })
+ })
+ }
+
+ fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
+ self.contexts.iter().find_map(|context| {
+ let context = context.upgrade()?;
+ if context.read(cx).path() == Some(path) {
+ Some(context)
+ } else {
+ None
+ }
+ })
+ }
+
+ fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option<Model<Context>> {
+ self.contexts.iter().find_map(|context| {
+ let context = context.upgrade()?;
+ if context.read(cx).id() == id {
+ Some(context)
+ } else {
+ None
+ }
+ })
+ }
+
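+ // Request the full operation log for the context from the host, build a local replica, and apply the operations before registering it.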
+ pub fn open_remote_context(
+ &mut self,
+ context_id: ContextId,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Model<Context>>> {
+ let project = self.project.read(cx);
+ let Some(project_id) = project.remote_id() else {
+ return Task::ready(Err(anyhow!("project was not remote")));
+ };
+ if project.is_local() {
+ return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
+ }
+
+ if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
+ return Task::ready(Ok(context));
+ }
+
+ let replica_id = project.replica_id();
+ let capability = project.capability();
+ let language_registry = self.languages.clone();
+ let telemetry = self.telemetry.clone();
+ let request = self.client.request(proto::OpenContext {
+ project_id,
+ context_id: context_id.to_proto(),
+ });
+ cx.spawn(|this, mut cx| async move {
+ let response = request.await?;
+ let context_proto = response.context.context("invalid context")?;
+ let context = cx.new_model(|cx| {
+ Context::new(
+ context_id.clone(),
+ replica_id,
+ capability,
+ language_registry,
+ Some(telemetry),
+ cx,
+ )
+ })?;
+ let operations = cx
+ .background_executor()
+ .spawn(async move {
+ context_proto
+ .operations
+ .into_iter()
+ .map(ContextOperation::from_proto)
+ .collect::<Result<Vec<_>>>()
+ })
+ .await?;
+ context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
+ this.update(&mut cx, |this, cx| {
+ if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
+ existing_context
+ } else {
+ this.register_context(&context, cx);
+ this.synchronize_contexts(cx);
+ context
+ }
+ })
+ })
+ }
+
+ fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
+ let handle = if self.project_is_shared {
+ ContextHandle::Strong(context.clone())
+ } else {
+ ContextHandle::Weak(context.downgrade())
+ };
+ self.contexts.push(handle);
+ self.advertise_contexts(cx);
+ cx.subscribe(context, Self::handle_context_event).detach();
+ }
+
+ fn handle_context_event(
+ &mut self,
+ context: Model<Context>,
+ event: &ContextEvent,
+ cx: &mut ModelContext<Self>,
+ ) {
+ let Some(project_id) = self.project.read(cx).remote_id() else {
+ return;
+ };
+
+ match event {
+ ContextEvent::SummaryChanged => {
+ self.advertise_contexts(cx);
+ }
+ ContextEvent::Operation(operation) => {
+ let context_id = context.read(cx).id().to_proto();
+ let operation = operation.to_proto();
+ self.client
+ .send(proto::UpdateContext {
+ project_id,
+ context_id,
+ operation: Some(operation),
+ })
+ .log_err();
+ }
+ _ => {}
+ }
+ }
+
+ fn advertise_contexts(&self, cx: &AppContext) {
+ let Some(project_id) = self.project.read(cx).remote_id() else {
+ return;
+ };
+
+ // For now, only the host can advertise their open contexts.
+ if self.project.read(cx).is_remote() {
+ return;
+ }
+
+ let contexts = self
+ .contexts
+ .iter()
+ .rev()
+ .filter_map(|context| {
+ let context = context.upgrade()?.read(cx);
+ if context.replica_id() == ReplicaId::default() {
+ Some(proto::ContextMetadata {
+ context_id: context.id().to_proto(),
+ summary: context.summary().map(|summary| summary.text.clone()),
+ })
+ } else {
+ None
+ }
+ })
+ .collect();
+ self.client
+ .send(proto::AdvertiseContexts {
+ project_id,
+ contexts,
+ })
+ .ok();
+ }
+
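+ // Exchange versions with the host for every guest-replica context: the host streams us any operations we are missing, and we push back any operations it is missing.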
+ fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
+ let Some(project_id) = self.project.read(cx).remote_id() else {
+ return;
+ };
+
+ let contexts = self
+ .contexts
+ .iter()
+ .filter_map(|context| {
+ let context = context.upgrade()?.read(cx);
+ if context.replica_id() != ReplicaId::default() {
+ Some(context.version(cx).to_proto(context.id().clone()))
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ let client = self.client.clone();
+ let request = self.client.request(proto::SynchronizeContexts {
+ project_id,
+ contexts,
+ });
+ cx.spawn(|this, cx| async move {
+ let response = request.await?;
+
+ let mut context_ids = Vec::new();
+ let mut operations = Vec::new();
+ this.read_with(&cx, |this, cx| {
+ for context_version_proto in response.contexts {
+ let context_version = ContextVersion::from_proto(&context_version_proto);
+ let context_id = ContextId::from_proto(context_version_proto.context_id);
+ if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
+ context_ids.push(context_id);
+ operations.push(context.read(cx).serialize_ops(&context_version, cx));
}
- _ => Err(anyhow!("unrecognized saved context version: {}", version)),
- },
- _ => Err(anyhow!("version not found on saved context")),
+ }
+ })?;
+
+ let operations = futures::future::join_all(operations).await;
+ for (context_id, operations) in context_ids.into_iter().zip(operations) {
+ for operation in operations {
+ client.send(proto::UpdateContext {
+ project_id,
+ context_id: context_id.to_proto(),
+ operation: Some(operation),
+ })?;
+ }
}
+
+ anyhow::Ok(())
})
+ .detach_and_log_err(cx);
}
pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
@@ -178,6 +577,10 @@ impl ContextStore {
})
}
+ pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
+ &self.host_contexts
+ }
+
fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move {
@@ -3,7 +3,6 @@ use crate::{
InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use anyhow::{anyhow, Result};
-use assistant_slash_command::SlashCommandRegistry;
use chrono::{DateTime, Utc};
use collections::{HashMap, HashSet};
use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle};
@@ -448,7 +447,6 @@ impl PromptLibrary {
self.set_active_prompt(Some(prompt_id), cx);
} else if let Some(prompt_metadata) = self.store.metadata(prompt_id) {
let language_registry = self.language_registry.clone();
- let commands = SlashCommandRegistry::global(cx);
let prompt = self.store.load(prompt_id);
self.pending_load = cx.spawn(|this, mut cx| async move {
let prompt = prompt.await;
@@ -477,7 +475,7 @@ impl PromptLibrary {
editor.set_use_modal_editing(false);
editor.set_current_line_highlight(Some(CurrentLineHighlight::None));
editor.set_completion_provider(Box::new(
- SlashCommandCompletionProvider::new(commands, None, None),
+ SlashCommandCompletionProvider::new(None, None),
));
if focus {
editor.focus(cx);
@@ -31,7 +31,6 @@ pub mod tabs_command;
pub mod term_command;
pub(crate) struct SlashCommandCompletionProvider {
- commands: Arc<SlashCommandRegistry>,
cancel_flag: Mutex<Arc<AtomicBool>>,
editor: Option<WeakView<ContextEditor>>,
workspace: Option<WeakView<Workspace>>,
@@ -46,14 +45,12 @@ pub(crate) struct SlashCommandLine {
impl SlashCommandCompletionProvider {
pub fn new(
- commands: Arc<SlashCommandRegistry>,
editor: Option<WeakView<ContextEditor>>,
workspace: Option<WeakView<Workspace>>,
) -> Self {
Self {
cancel_flag: Mutex::new(Arc::new(AtomicBool::new(false))),
editor,
- commands,
workspace,
}
}
@@ -65,8 +62,8 @@ impl SlashCommandCompletionProvider {
name_range: Range<Anchor>,
cx: &mut WindowContext,
) -> Task<Result<Vec<project::Completion>>> {
- let candidates = self
- .commands
+ let commands = SlashCommandRegistry::global(cx);
+ let candidates = commands
.command_names()
.into_iter()
.enumerate()
@@ -76,7 +73,6 @@ impl SlashCommandCompletionProvider {
char_bag: def.as_ref().into(),
})
.collect::<Vec<_>>();
- let commands = self.commands.clone();
let command_name = command_name.to_string();
let editor = self.editor.clone();
let workspace = self.workspace.clone();
@@ -155,7 +151,8 @@ impl SlashCommandCompletionProvider {
flag.store(true, SeqCst);
*flag = new_cancel_flag.clone();
- if let Some(command) = self.commands.command(command_name) {
+ let commands = SlashCommandRegistry::global(cx);
+ if let Some(command) = commands.command(command_name) {
let completions = command.complete_argument(
argument,
new_cancel_flag.clone(),
@@ -67,7 +67,7 @@ pub struct SlashCommandOutput {
pub run_commands_in_text: bool,
}
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
pub range: Range<T>,
pub icon: IconName,
@@ -18,4 +18,5 @@ test-support = ["dep:parking_lot"]
[dependencies]
chrono.workspace = true
parking_lot = { workspace = true, optional = true }
+serde.workspace = true
smallvec.workspace = true
@@ -1,5 +1,6 @@
mod system_clock;
+use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::{
cmp::{self, Ordering},
@@ -16,7 +17,7 @@ pub type Seq = u32;
/// A [Lamport timestamp](https://en.wikipedia.org/wiki/Lamport_timestamp),
/// used to determine the ordering of events in the editor.
-#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
+#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct Lamport {
pub replica_id: ReplicaId,
pub value: Seq,
@@ -161,6 +162,10 @@ impl Lamport {
}
}
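+ /// Packs this timestamp into a single `u64`, placing the Lamport value in the high 32 bits and the replica id in the low 32 bits.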
+ pub fn as_u64(self) -> u64 {
+ ((self.value as u64) << 32) | (self.replica_id as u64)
+ }
+
pub fn tick(&mut self) -> Self {
let timestamp = *self;
self.value += 1;
@@ -71,6 +71,7 @@ util.workspace = true
uuid.workspace = true
[dev-dependencies]
+assistant = { workspace = true, features = ["test-support"] }
async-trait.workspace = true
audio.workspace = true
call = { workspace = true, features = ["test-support"] }
@@ -595,6 +595,14 @@ impl Server {
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
+ .add_request_handler(user_handler(
+ forward_mutating_project_request::<proto::OpenContext>,
+ ))
+ .add_request_handler(user_handler(
+ forward_mutating_project_request::<proto::SynchronizeContexts>,
+ ))
+ .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
+ .add_message_handler(update_context)
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
@@ -3056,6 +3064,53 @@ async fn update_buffer(
Ok(())
}
+async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(message.project_id);
+
+ let operation = message.operation.as_ref().context("invalid operation")?;
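+ // Work out the minimum capability needed to apply this operation: selection-only buffer operations are allowed from read-only participants, while edits and other context operations require write access.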
+ let capability = match operation.variant.as_ref() {
+ Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
+ if let Some(buffer_op) = buffer_op.operation.as_ref() {
+ match buffer_op.variant {
+ None | Some(proto::operation::Variant::UpdateSelections(_)) => {
+ Capability::ReadOnly
+ }
+ _ => Capability::ReadWrite,
+ }
+ } else {
+ Capability::ReadWrite
+ }
+ }
+ Some(_) => Capability::ReadWrite,
+ None => Capability::ReadOnly,
+ };
+
+ let guard = session
+ .db()
+ .await
+ .connections_for_buffer_update(
+ project_id,
+ session.principal_id(),
+ session.connection_id,
+ capability,
+ )
+ .await?;
+
+ let (host, guests) = &*guard;
+
+ broadcast(
+ Some(session.connection_id),
+ guests.iter().chain([host]).copied(),
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, message.clone())
+ },
+ );
+
+ Ok(())
+}
+
/// Notify other participants that a project has been updated.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
request: T,
@@ -6,6 +6,7 @@ use crate::{
},
};
use anyhow::{anyhow, Result};
+use assistant::ContextStore;
use call::{room, ActiveCall, ParticipantLocation, Room};
use client::{User, RECEIVE_TIMEOUT};
use collections::{HashMap, HashSet};
@@ -6449,3 +6450,123 @@ async fn test_preview_tabs(cx: &mut TestAppContext) {
assert!(!pane.can_navigate_forward());
});
}
+
+#[gpui::test(iterations = 10)]
+async fn test_context_collaboration_with_reconnect(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ server
+ .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+ .await;
+ let active_call_a = cx_a.read(ActiveCall::global);
+
+ client_a.fs().insert_tree("/a", Default::default()).await;
+ let (project_a, _) = client_a.build_local_project("/a", cx_a).await;
+ let project_id = active_call_a
+ .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
+ .await
+ .unwrap();
+ let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
+
+ // Client A sees that a guest has joined.
+ executor.run_until_parked();
+
+ project_a.read_with(cx_a, |project, _| {
+ assert_eq!(project.collaborators().len(), 1);
+ });
+ project_b.read_with(cx_b, |project, _| {
+ assert_eq!(project.collaborators().len(), 1);
+ });
+
+ let context_store_a = cx_a
+ .update(|cx| ContextStore::new(project_a.clone(), cx))
+ .await
+ .unwrap();
+ let context_store_b = cx_b
+ .update(|cx| ContextStore::new(project_b.clone(), cx))
+ .await
+ .unwrap();
+
+ // Client A creates a new context.
+ let context_a = context_store_a.update(cx_a, |store, cx| store.create(cx));
+ executor.run_until_parked();
+
+ // Client B retrieves host's contexts and joins one.
+ let context_b = context_store_b
+ .update(cx_b, |store, cx| {
+ let host_contexts = store.host_contexts().to_vec();
+ assert_eq!(host_contexts.len(), 1);
+ store.open_remote_context(host_contexts[0].id.clone(), cx)
+ })
+ .await
+ .unwrap();
+
+ // Host and guest make concurrent changes.
+ context_a.update(cx_a, |context, cx| {
+ context.buffer().update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "Host change\n")], None, cx)
+ })
+ });
+ context_b.update(cx_b, |context, cx| {
+ context.buffer().update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "Guest change\n")], None, cx)
+ })
+ });
+ executor.run_until_parked();
+ assert_eq!(
+ context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()),
+ "Guest change\nHost change\n"
+ );
+ assert_eq!(
+ context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()),
+ "Guest change\nHost change\n"
+ );
+
+ // Disconnect client A and make some changes while disconnected.
+ server.disconnect_client(client_a.peer_id().unwrap());
+ server.forbid_connections();
+ context_a.update(cx_a, |context, cx| {
+ context.buffer().update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "Host offline change\n")], None, cx)
+ })
+ });
+ context_b.update(cx_b, |context, cx| {
+ context.buffer().update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "Guest offline change\n")], None, cx)
+ })
+ });
+ executor.run_until_parked();
+ assert_eq!(
+ context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()),
+ "Host offline change\nGuest change\nHost change\n"
+ );
+ assert_eq!(
+ context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()),
+ "Guest offline change\nGuest change\nHost change\n"
+ );
+
+ // Allow client A to reconnect and verify that contexts converge.
+ server.allow_connections();
+ executor.advance_clock(RECEIVE_TIMEOUT);
+ assert_eq!(
+ context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()),
+ "Guest offline change\nHost offline change\nGuest change\nHost change\n"
+ );
+ assert_eq!(
+ context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()),
+ "Guest offline change\nHost offline change\nGuest change\nHost change\n"
+ );
+
+ // Client A disconnects without being able to reconnect. Context B becomes read-only.
+ server.forbid_connections();
+ server.disconnect_client(client_a.peer_id().unwrap());
+ executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
+ context_b.read_with(cx_b, |context, cx| {
+ assert!(context.buffer().read(cx).read_only());
+ });
+}
@@ -294,6 +294,8 @@ impl TestServer {
menu::init();
dev_server_projects::init(client.clone(), cx);
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
+ assistant::FakeCompletionProvider::setup_test(cx);
+ assistant::context_store::init(&client);
});
client
@@ -1903,6 +1903,10 @@ impl Buffer {
self.deferred_ops.insert(deferred_ops);
}
+ pub fn has_deferred_ops(&self) -> bool {
+ !self.deferred_ops.is_empty() || self.text.has_deferred_ops()
+ }
+
fn can_apply_op(&self, operation: &Operation) -> bool {
match operation {
Operation::Buffer(_) => {
@@ -1,7 +1,7 @@
//! Handles conversions of `language` items to and from the [`rpc`] protocol.
use crate::{diagnostic_set::DiagnosticEntry, CursorShape, Diagnostic};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context as _, Result};
use clock::ReplicaId;
use lsp::{DiagnosticSeverity, LanguageServerId};
use rpc::proto;
@@ -231,6 +231,21 @@ pub fn serialize_anchor(anchor: &Anchor) -> proto::Anchor {
}
}
+pub fn serialize_anchor_range(range: Range<Anchor>) -> proto::AnchorRange {
+ proto::AnchorRange {
+ start: Some(serialize_anchor(&range.start)),
+ end: Some(serialize_anchor(&range.end)),
+ }
+}
+
+/// Deserializes a [`Range<Anchor>`] from the RPC representation.
+pub fn deserialize_anchor_range(range: proto::AnchorRange) -> Result<Range<Anchor>> {
+ Ok(
+ deserialize_anchor(range.start.context("invalid anchor")?).context("invalid anchor")?
+ ..deserialize_anchor(range.end.context("invalid anchor")?).context("invalid anchor")?,
+ )
+}
+
// This behavior is currently copied in the collab database, for snapshotting channel notes
/// Deserializes an [`crate::Operation`] from the RPC representation.
pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operation> {
@@ -355,6 +355,9 @@ pub enum Event {
},
CollaboratorJoined(proto::PeerId),
CollaboratorLeft(proto::PeerId),
+ HostReshared,
+ Reshared,
+ Rejoined,
RefreshInlayHints,
RevealInProjectPanel(ProjectEntryId),
SnippetEdit(BufferId, Vec<(lsp::Range, Snippet)>),
@@ -1716,6 +1719,7 @@ impl Project {
self.shared_buffers.clear();
self.set_collaborators_from_proto(message.collaborators, cx)?;
self.metadata_changed(cx);
+ cx.emit(Event::Reshared);
Ok(())
}
@@ -1753,6 +1757,7 @@ impl Project {
.collect();
self.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync)
.unwrap();
+ cx.emit(Event::Rejoined);
cx.notify();
Ok(())
}
@@ -1805,9 +1810,11 @@ impl Project {
}
}
- self.client.send(proto::UnshareProject {
- project_id: remote_id,
- })?;
+ self.client
+ .send(proto::UnshareProject {
+ project_id: remote_id,
+ })
+ .ok();
Ok(())
} else {
@@ -8810,6 +8817,7 @@ impl Project {
.retain(|_, buffer| !matches!(buffer, OpenBuffer::Operations(_)));
this.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync)
.unwrap();
+ cx.emit(Event::HostReshared);
}
cx.emit(Event::CollaboratorUpdated {
@@ -255,7 +255,14 @@ message Envelope {
TaskTemplates task_templates = 206;
LinkedEditingRange linked_editing_range = 209;
- LinkedEditingRangeResponse linked_editing_range_response = 210; // current max
+ LinkedEditingRangeResponse linked_editing_range_response = 210;
+
+ AdvertiseContexts advertise_contexts = 211;
+ OpenContext open_context = 212;
+ OpenContextResponse open_context_response = 213;
+ UpdateContext update_context = 214;
+ SynchronizeContexts synchronize_contexts = 215;
+ SynchronizeContextsResponse synchronize_contexts_response = 216; // current max
}
reserved 158 to 161;
@@ -2222,3 +2229,117 @@ message TaskSourceKind {
string name = 1;
}
}
+
+message ContextMessageStatus {
+ oneof variant {
+ Done done = 1;
+ Pending pending = 2;
+ Error error = 3;
+ }
+
+ message Done {}
+
+ message Pending {}
+
+ message Error {
+ string message = 1;
+ }
+}
+
+message ContextMessage {
+ LamportTimestamp id = 1;
+ Anchor start = 2;
+ LanguageModelRole role = 3;
+ ContextMessageStatus status = 4;
+}
+
+message SlashCommandOutputSection {
+ AnchorRange range = 1;
+ string icon_name = 2;
+ string label = 3;
+}
+
+message ContextOperation {
+ oneof variant {
+ InsertMessage insert_message = 1;
+ UpdateMessage update_message = 2;
+ UpdateSummary update_summary = 3;
+ SlashCommandFinished slash_command_finished = 4;
+ BufferOperation buffer_operation = 5;
+ }
+
+ message InsertMessage {
+ ContextMessage message = 1;
+ repeated VectorClockEntry version = 2;
+ }
+
+ message UpdateMessage {
+ LamportTimestamp message_id = 1;
+ LanguageModelRole role = 2;
+ ContextMessageStatus status = 3;
+ LamportTimestamp timestamp = 4;
+ repeated VectorClockEntry version = 5;
+ }
+
+ message UpdateSummary {
+ string summary = 1;
+ bool done = 2;
+ LamportTimestamp timestamp = 3;
+ repeated VectorClockEntry version = 4;
+ }
+
+ message SlashCommandFinished {
+ LamportTimestamp id = 1;
+ AnchorRange output_range = 2;
+ repeated SlashCommandOutputSection sections = 3;
+ repeated VectorClockEntry version = 4;
+ }
+
+ message BufferOperation {
+ Operation operation = 1;
+ }
+}
+
+message Context {
+ repeated ContextOperation operations = 1;
+}
+
+message ContextMetadata {
+ string context_id = 1;
+ optional string summary = 2;
+}
+
+message AdvertiseContexts {
+ uint64 project_id = 1;
+ repeated ContextMetadata contexts = 2;
+}
+
+message OpenContext {
+ uint64 project_id = 1;
+ string context_id = 2;
+}
+
+message OpenContextResponse {
+ Context context = 1;
+}
+
+message UpdateContext {
+ uint64 project_id = 1;
+ string context_id = 2;
+ ContextOperation operation = 3;
+}
+
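+// Describes how much of a context a peer has seen: one vector clock for context-level operations and one for the underlying text buffer.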
+message ContextVersion {
+ string context_id = 1;
+ repeated VectorClockEntry context_version = 2;
+ repeated VectorClockEntry buffer_version = 3;
+}
+
+message SynchronizeContexts {
+ uint64 project_id = 1;
+ repeated ContextVersion contexts = 2;
+}
+
+message SynchronizeContextsResponse {
+ repeated ContextVersion contexts = 1;
+}
@@ -337,7 +337,13 @@ messages!(
(OpenNewBuffer, Foreground),
(RestartLanguageServers, Foreground),
(LinkedEditingRange, Background),
- (LinkedEditingRangeResponse, Background)
+ (LinkedEditingRangeResponse, Background),
+ (AdvertiseContexts, Foreground),
+ (OpenContext, Foreground),
+ (OpenContextResponse, Foreground),
+ (UpdateContext, Foreground),
+ (SynchronizeContexts, Foreground),
+ (SynchronizeContextsResponse, Foreground),
);
request_messages!(
@@ -449,7 +455,9 @@ request_messages!(
(DeleteDevServerProject, Ack),
(RegenerateDevServerToken, RegenerateDevServerTokenResponse),
(RenameDevServer, Ack),
- (RestartLanguageServers, Ack)
+ (RestartLanguageServers, Ack),
+ (OpenContext, OpenContextResponse),
+ (SynchronizeContexts, SynchronizeContextsResponse),
);
entity_messages!(
@@ -511,6 +519,10 @@ entity_messages!(
UpdateWorktree,
UpdateWorktreeSettings,
LspExtExpandMacro,
+ AdvertiseContexts,
+ OpenContext,
+ UpdateContext,
+ SynchronizeContexts,
);
entity_messages!(
@@ -1,12 +1,15 @@
+use std::fmt::Debug;
+
use clock::ReplicaId;
+use collections::{BTreeMap, HashSet};
pub struct Network<T: Clone, R: rand::Rng> {
- inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>,
- all_messages: Vec<T>,
+ inboxes: BTreeMap<ReplicaId, Vec<Envelope<T>>>,
+ disconnected_peers: HashSet<ReplicaId>,
rng: R,
}
-#[derive(Clone)]
+#[derive(Clone, Debug)]
struct Envelope<T: Clone> {
message: T,
}
@@ -14,8 +17,8 @@ struct Envelope<T: Clone> {
impl<T: Clone, R: rand::Rng> Network<T, R> {
pub fn new(rng: R) -> Self {
Network {
- inboxes: Default::default(),
- all_messages: Vec::new(),
+ inboxes: BTreeMap::default(),
+ disconnected_peers: HashSet::default(),
rng,
}
}
@@ -24,6 +27,24 @@ impl<T: Clone, R: rand::Rng> Network<T, R> {
self.inboxes.insert(id, Vec::new());
}
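+ /// Simulates the peer going offline: its inbox is cleared and it neither sends nor receives broadcasts until it reconnects.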
+ pub fn disconnect_peer(&mut self, id: ReplicaId) {
+ self.disconnected_peers.insert(id);
+ self.inboxes.get_mut(&id).unwrap().clear();
+ }
+
+ pub fn reconnect_peer(&mut self, id: ReplicaId, replicate_from: ReplicaId) {
+ assert!(self.disconnected_peers.remove(&id));
+ self.replicate(replicate_from, id);
+ }
+
+ pub fn is_disconnected(&self, id: ReplicaId) -> bool {
+ self.disconnected_peers.contains(&id)
+ }
+
+ pub fn contains_disconnected_peers(&self) -> bool {
+ !self.disconnected_peers.is_empty()
+ }
+
pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) {
self.inboxes
.insert(new_replica_id, self.inboxes[&old_replica_id].clone());
@@ -34,8 +55,13 @@ impl<T: Clone, R: rand::Rng> Network<T, R> {
}
pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
+ // Drop messages from disconnected peers.
+ if self.disconnected_peers.contains(&sender) {
+ return;
+ }
+
for (replica, inbox) in self.inboxes.iter_mut() {
- if *replica != sender {
+ if *replica != sender && !self.disconnected_peers.contains(replica) {
for message in &messages {
// Insert one or more duplicates of this message, potentially *before* the previous
// message sent by this peer to simulate out-of-order delivery.
@@ -51,7 +77,6 @@ impl<T: Clone, R: rand::Rng> Network<T, R> {
}
}
}
- self.all_messages.extend(messages);
}
pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {
@@ -1265,6 +1265,10 @@ impl Buffer {
}
}
+ pub fn has_deferred_ops(&self) -> bool {
+ !self.deferred_ops.is_empty()
+ }
+
pub fn peek_undo_stack(&self) -> Option<&HistoryEntry> {
self.history.undo_stack.last()
}
@@ -1,6 +1,6 @@
use gpui::{svg, AnimationElement, Hsla, IntoElement, Rems, Transformation};
use serde::{Deserialize, Serialize};
-use strum::EnumIter;
+use strum::{EnumIter, EnumString, IntoStaticStr};
use crate::{prelude::*, Indicator};
@@ -90,7 +90,9 @@ impl IconSize {
}
}
-#[derive(Debug, PartialEq, Copy, Clone, EnumIter, Serialize, Deserialize)]
+#[derive(
+ Debug, Eq, PartialEq, Copy, Clone, EnumIter, EnumString, IntoStaticStr, Serialize, Deserialize,
+)]
pub enum IconName {
Ai,
ArrowCircle,