diff --git a/Cargo.lock b/Cargo.lock index 9aa674cbb69aaa52df5466dda41d8b9c2d9be5b1..c0f6ef03c296306a73264461a8767ccd1b346c20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5314,6 +5314,7 @@ dependencies = [ "serde_json", "settings", "supermaven", + "sweep_ai", "telemetry", "theme", "ui", @@ -16590,6 +16591,33 @@ dependencies = [ "zeno", ] +[[package]] +name = "sweep_ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "arrayvec", + "brotli", + "client", + "collections", + "edit_prediction", + "feature_flags", + "futures 0.3.31", + "gpui", + "http_client", + "indoc", + "language", + "project", + "release_channel", + "reqwest_client", + "serde", + "serde_json", + "tree-sitter-rust", + "util", + "workspace", + "zlog", +] + [[package]] name = "symphonia" version = "0.5.5" @@ -21316,6 +21344,7 @@ dependencies = [ "snippets_ui", "supermaven", "svg_preview", + "sweep_ai", "sysinfo 0.37.2", "system_specs", "tab_switcher", diff --git a/Cargo.toml b/Cargo.toml index be56964f753cded4b1e054583b989f798c3ca1e3..e74647c6320f149d8eadad08ff3624859fe76624 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -165,6 +165,7 @@ members = [ "crates/sum_tree", "crates/supermaven", "crates/supermaven_api", + "crates/sweep_ai", "crates/codestral", "crates/svg_preview", "crates/system_specs", @@ -398,6 +399,7 @@ streaming_diff = { path = "crates/streaming_diff" } sum_tree = { path = "crates/sum_tree" } supermaven = { path = "crates/supermaven" } supermaven_api = { path = "crates/supermaven_api" } +sweep_ai = { path = "crates/sweep_ai" } codestral = { path = "crates/codestral" } system_specs = { path = "crates/system_specs" } tab_switcher = { path = "crates/tab_switcher" } @@ -478,6 +480,7 @@ bitflags = "2.6.0" blade-graphics = { version = "0.7.0" } blade-macros = { version = "0.3.0" } blade-util = { version = "0.3.0" } +brotli = "8.0.2" bytes = "1.0" cargo_metadata = "0.19" cargo_toml = "0.21" diff --git a/assets/icons/sweep_ai.svg b/assets/icons/sweep_ai.svg new file mode 100644 index 0000000000000000000000000000000000000000..c78d12727d78ddcc2a86bdb3e46349752cadaf7d --- /dev/null +++ b/assets/icons/sweep_ai.svg @@ -0,0 +1 @@ + diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index da1543a2790599fbe590f4e29d3594588bd2f351..6396b68cbc5f805466618bd460f9ed46ce05d086 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -346,7 +346,9 @@ fn update_command_palette_filter(cx: &mut App) { filter.show_namespace("supermaven"); filter.show_action_types(edit_prediction_actions.iter()); } - EditPredictionProvider::Zed | EditPredictionProvider::Codestral => { + EditPredictionProvider::Zed + | EditPredictionProvider::Codestral + | EditPredictionProvider::Experimental(_) => { filter.show_namespace("edit_prediction"); filter.hide_namespace("copilot"); filter.hide_namespace("supermaven"); diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml index 189db7f7bac3eaea36a154424c4e7702f1387d24..3ed3d9411510ad2d978b221d8cb3412465a66879 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -18,18 +18,19 @@ client.workspace = true cloud_llm_client.workspace = true codestral.workspace = true copilot.workspace = true +edit_prediction.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true gpui.workspace = true indoc.workspace = true -edit_prediction.workspace = true language.workspace = true paths.workspace = true project.workspace = true regex.workspace = true settings.workspace = true supermaven.workspace = true +sweep_ai.workspace = true telemetry.workspace = true ui.workspace = true workspace.workspace = true diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 685d408205863e5ad110a5e57891c0695f998cfb..51f228db76aaee5e286dd950c17dd01b303d29b8 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -18,12 +18,15 @@ use language::{ }; use project::DisableAiSettings; use regex::Regex; -use settings::{Settings, SettingsStore, update_settings_file}; +use settings::{ + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file, +}; use std::{ sync::{Arc, LazyLock}, time::Duration, }; use supermaven::{AccountStatus, Supermaven}; +use sweep_ai::SweepFeatureFlag; use ui::{ Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton, IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*, @@ -78,7 +81,7 @@ impl Render for EditPredictionButton { let all_language_settings = all_language_settings(None, cx); - match all_language_settings.edit_predictions.provider { + match &all_language_settings.edit_predictions.provider { EditPredictionProvider::None => div().hidden(), EditPredictionProvider::Copilot => { @@ -297,6 +300,15 @@ impl Render for EditPredictionButton { .with_handle(self.popover_menu_handle.clone()), ) } + EditPredictionProvider::Experimental(provider_name) => { + if *provider_name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME + && cx.has_flag::() + { + div().child(Icon::new(IconName::SweepAi)) + } else { + div() + } + } EditPredictionProvider::Zed => { let enabled = self.editor_enabled.unwrap_or(true); @@ -525,7 +537,7 @@ impl EditPredictionButton { set_completion_provider(fs.clone(), cx, provider); }) } - EditPredictionProvider::None => continue, + EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => continue, }; } } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index a0865773ac394722c113a43fe323de218b2f145a..da3b298751d9c1921d14722490e3cbc680292099 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -217,6 +217,7 @@ pub enum IconName { SupermavenError, SupermavenInit, SwatchBook, + SweepAi, Tab, Terminal, TerminalAlt, diff --git a/crates/reqwest_client/src/reqwest_client.rs b/crates/reqwest_client/src/reqwest_client.rs index c2a58877b32ab6049edc5b50f7ad025f0c83f46e..4213a239ec813f255139a97770a74608371fb73e 100644 --- a/crates/reqwest_client/src/reqwest_client.rs +++ b/crates/reqwest_client/src/reqwest_client.rs @@ -80,20 +80,22 @@ impl ReqwestClient { } } +pub fn runtime() -> &'static tokio::runtime::Runtime { + RUNTIME.get_or_init(|| { + tokio::runtime::Builder::new_multi_thread() + // Since we now have two executors, let's try to keep our footprint small + .worker_threads(1) + .enable_all() + .build() + .expect("Failed to initialize HTTP client") + }) +} + impl From for ReqwestClient { fn from(client: reqwest::Client) -> Self { let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| { log::debug!("no tokio runtime found, creating one for Reqwest..."); - let runtime = RUNTIME.get_or_init(|| { - tokio::runtime::Builder::new_multi_thread() - // Since we now have two executors, let's try to keep our footprint small - .worker_threads(1) - .enable_all() - .build() - .expect("Failed to initialize HTTP client") - }); - - runtime.handle().clone() + runtime().handle().clone() }); Self { client, diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs index fc11dd4956a50906951af8fa43a7dacc61568f70..ed70116862bbda6af22d4027a406535ae0c19d67 100644 --- a/crates/settings/src/settings_content/language.rs +++ b/crates/settings/src/settings_content/language.rs @@ -3,7 +3,7 @@ use std::num::NonZeroU32; use collections::{HashMap, HashSet}; use gpui::{Modifiers, SharedString}; use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize, de::Error as _}; use serde_with::skip_serializing_none; use settings_macros::MergeFrom; use std::sync::Arc; @@ -68,9 +68,7 @@ pub struct FeaturesContent { } /// The provider that supplies edit predictions. -#[derive( - Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom, -)] +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, JsonSchema, MergeFrom)] #[serde(rename_all = "snake_case")] pub enum EditPredictionProvider { None, @@ -79,6 +77,47 @@ pub enum EditPredictionProvider { Supermaven, Zed, Codestral, + Experimental(&'static str), +} + +pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep"; + +impl<'de> Deserialize<'de> for EditPredictionProvider { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(rename_all = "snake_case")] + pub enum Content { + None, + Copilot, + Supermaven, + Zed, + Codestral, + Experimental(String), + } + + Ok(match Content::deserialize(deserializer)? { + Content::None => EditPredictionProvider::None, + Content::Copilot => EditPredictionProvider::Copilot, + Content::Supermaven => EditPredictionProvider::Supermaven, + Content::Zed => EditPredictionProvider::Zed, + Content::Codestral => EditPredictionProvider::Codestral, + Content::Experimental(name) => { + if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME { + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) + } else { + return Err(D::Error::custom(format!( + "Unknown experimental edit prediction provider: {}", + name + ))); + } + } + }) + } } impl EditPredictionProvider { @@ -88,7 +127,8 @@ impl EditPredictionProvider { EditPredictionProvider::None | EditPredictionProvider::Copilot | EditPredictionProvider::Supermaven - | EditPredictionProvider::Codestral => false, + | EditPredictionProvider::Codestral + | EditPredictionProvider::Experimental(_) => false, } } } diff --git a/crates/sweep_ai/Cargo.toml b/crates/sweep_ai/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4edf7ea1bb6af9a6657ccfe310c0253b118ec2e7 --- /dev/null +++ b/crates/sweep_ai/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "sweep_ai" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" +exclude = ["fixtures"] + +[lints] +workspace = true + +[lib] +path = "src/sweep_ai.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +arrayvec.workspace = true +brotli.workspace = true +client.workspace = true +collections.workspace = true +edit_prediction.workspace = true +feature_flags.workspace = true +futures.workspace = true +gpui.workspace = true +http_client.workspace = true +language.workspace = true +project.workspace = true +release_channel.workspace = true +serde.workspace = true +serde_json.workspace = true +util.workspace = true +workspace.workspace = true + +[dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } +http_client = { workspace = true, features = ["test-support"] } +indoc.workspace = true +language = { workspace = true, features = ["test-support"] } +reqwest_client = { workspace = true, features = ["test-support"] } +tree-sitter-rust.workspace = true +workspace = { workspace = true, features = ["test-support"] } +zlog.workspace = true diff --git a/crates/sweep_ai/LICENSE-GPL b/crates/sweep_ai/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/sweep_ai/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/sweep_ai/src/api.rs b/crates/sweep_ai/src/api.rs new file mode 100644 index 0000000000000000000000000000000000000000..edb392885e476e3924d285613af1f0a4e8be8599 --- /dev/null +++ b/crates/sweep_ai/src/api.rs @@ -0,0 +1,90 @@ +use std::{path::Path, sync::Arc}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize)] +pub struct AutocompleteRequest { + pub debug_info: Arc, + pub repo_name: String, + pub branch: Option, + pub file_path: Arc, + pub file_contents: String, + pub recent_changes: String, + pub cursor_position: usize, + pub original_file_contents: String, + pub file_chunks: Vec, + pub retrieval_chunks: Vec, + pub recent_user_actions: Vec, + pub multiple_suggestions: bool, + pub privacy_mode_enabled: bool, + pub changes_above_cursor: bool, +} + +#[derive(Debug, Clone, Serialize)] +pub struct FileChunk { + pub file_path: String, + pub start_line: usize, + pub end_line: usize, + pub content: String, + pub timestamp: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct RetrievalChunk { + pub file_path: String, + pub start_line: usize, + pub end_line: usize, + pub content: String, + pub timestamp: u64, +} + +#[derive(Debug, Clone, Serialize)] +pub struct UserAction { + pub action_type: ActionType, + pub line_number: usize, + pub offset: usize, + pub file_path: String, + pub timestamp: u64, +} + +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum ActionType { + CursorMovement, + InsertChar, + DeleteChar, + InsertSelection, + DeleteSelection, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AutocompleteResponse { + pub autocomplete_id: String, + pub start_index: usize, + pub end_index: usize, + pub completion: String, + #[allow(dead_code)] + pub confidence: f64, + #[allow(dead_code)] + pub logprobs: Option, + #[allow(dead_code)] + pub finish_reason: Option, + #[allow(dead_code)] + pub elapsed_time_ms: u64, + #[allow(dead_code)] + #[serde(default, rename = "completions")] + pub additional_completions: Vec, +} + +#[allow(dead_code)] +#[derive(Debug, Clone, Deserialize)] +pub struct AdditionalCompletion { + pub start_index: usize, + pub end_index: usize, + pub completion: String, + pub confidence: f64, + pub autocomplete_id: String, + pub logprobs: Option, + pub finish_reason: Option, +} diff --git a/crates/sweep_ai/src/sweep_ai.rs b/crates/sweep_ai/src/sweep_ai.rs new file mode 100644 index 0000000000000000000000000000000000000000..e8a2522c0b34896ad09fd8a8d346e2ba31c9a1e7 --- /dev/null +++ b/crates/sweep_ai/src/sweep_ai.rs @@ -0,0 +1,776 @@ +mod api; + +use anyhow::{Context as _, Result}; +use arrayvec::ArrayVec; +use client::telemetry; +use collections::HashMap; +use feature_flags::FeatureFlag; +use futures::AsyncReadExt as _; +use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity}; +use http_client::{AsyncBody, Method}; +use language::{Anchor, Buffer, BufferSnapshot, EditPreview, ToOffset as _, ToPoint, text_diff}; +use project::Project; +use release_channel::{AppCommitSha, AppVersion}; +use std::collections::{VecDeque, hash_map}; +use std::fmt::{self, Display}; +use std::mem; +use std::{ + cmp, + fmt::Write, + ops::Range, + path::Path, + sync::Arc, + time::{Duration, Instant}, +}; +use util::ResultExt; +use util::rel_path::RelPath; +use workspace::Workspace; + +use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk}; + +const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); +const MAX_EVENT_COUNT: usize = 16; + +const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; + +pub struct SweepFeatureFlag; + +impl FeatureFlag for SweepFeatureFlag { + const NAME: &str = "sweep-ai"; + + fn enabled_for_staff() -> bool { + false + } +} + +#[derive(Clone)] +struct SweepAiGlobal(Entity); + +impl Global for SweepAiGlobal {} + +#[derive(Clone)] +pub struct EditPrediction { + id: EditPredictionId, + path: Arc, + edits: Arc<[(Range, Arc)]>, + snapshot: BufferSnapshot, + edit_preview: EditPreview, +} + +impl EditPrediction { + fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> { + edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) + } +} + +impl fmt::Debug for EditPrediction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EditPrediction") + .field("path", &self.path) + .field("edits", &self.edits) + .finish_non_exhaustive() + } +} + +#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] +pub struct EditPredictionId(String); + +impl Display for EditPredictionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +pub struct SweepAi { + projects: HashMap, + debug_info: Arc, +} + +struct SweepAiProject { + events: VecDeque, + registered_buffers: HashMap, +} + +impl SweepAi { + pub fn global(cx: &mut App) -> Option> { + cx.try_global::() + .map(|global| global.0.clone()) + } + + pub fn register(cx: &mut App) -> Entity { + Self::global(cx).unwrap_or_else(|| { + let entity = cx.new(|cx| Self::new(cx)); + cx.set_global(SweepAiGlobal(entity.clone())); + entity + }) + } + + pub fn clear_history(&mut self) { + for sweep_ai_project in self.projects.values_mut() { + sweep_ai_project.events.clear(); + } + } + + fn new(cx: &mut Context) -> Self { + Self { + projects: HashMap::default(), + debug_info: format!( + "Zed v{version} ({sha}) - OS: {os} - Zed v{version}", + version = AppVersion::global(cx), + sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()), + os = telemetry::os_name(), + ) + .into(), + } + } + + fn get_or_init_sweep_ai_project( + &mut self, + project: &Entity, + cx: &mut Context, + ) -> &mut SweepAiProject { + let project_id = project.entity_id(); + match self.projects.entry(project_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + cx.observe_release(project, move |this, _, _cx| { + this.projects.remove(&project_id); + }) + .detach(); + entry.insert(SweepAiProject { + events: VecDeque::with_capacity(MAX_EVENT_COUNT), + registered_buffers: HashMap::default(), + }) + } + } + } + + fn push_event(sweep_ai_project: &mut SweepAiProject, event: Event) { + let events = &mut sweep_ai_project.events; + + if let Some(Event::BufferChange { + new_snapshot: last_new_snapshot, + timestamp: last_timestamp, + .. + }) = events.back_mut() + { + // Coalesce edits for the same buffer when they happen one after the other. + let Event::BufferChange { + old_snapshot, + new_snapshot, + timestamp, + } = &event; + + if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL + && old_snapshot.remote_id() == last_new_snapshot.remote_id() + && old_snapshot.version == last_new_snapshot.version + { + *last_new_snapshot = new_snapshot.clone(); + *last_timestamp = *timestamp; + return; + } + } + + if events.len() >= MAX_EVENT_COUNT { + // These are halved instead of popping to improve prompt caching. + events.drain(..MAX_EVENT_COUNT / 2); + } + + events.push_back(event); + } + + pub fn register_buffer( + &mut self, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) { + let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx); + Self::register_buffer_impl(sweep_ai_project, buffer, project, cx); + } + + fn register_buffer_impl<'a>( + sweep_ai_project: &'a mut SweepAiProject, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> &'a mut RegisteredBuffer { + let buffer_id = buffer.entity_id(); + match sweep_ai_project.registered_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let snapshot = buffer.read(cx).snapshot(); + let project_entity_id = project.entity_id(); + entry.insert(RegisteredBuffer { + snapshot, + _subscriptions: [ + cx.subscribe(buffer, { + let project = project.downgrade(); + move |this, buffer, event, cx| { + if let language::BufferEvent::Edited = event + && let Some(project) = project.upgrade() + { + this.report_changes_for_buffer(&buffer, &project, cx); + } + } + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + sweep_ai_project.registered_buffers.remove(&buffer_id); + }), + ], + }) + } + } + } + + pub fn request_completion( + &mut self, + workspace: &WeakEntity, + project: &Entity, + active_buffer: &Entity, + position: language::Anchor, + cx: &mut Context, + ) -> Task>> { + let snapshot = active_buffer.read(cx).snapshot(); + let debug_info = self.debug_info.clone(); + let full_path: Arc = snapshot + .file() + .map(|file| file.full_path(cx)) + .unwrap_or_else(|| "untitled".into()) + .into(); + + let project_file = project::File::from_dyn(snapshot.file()); + let repo_name = project_file + .map(|file| file.worktree.read(cx).root_name_str()) + .unwrap_or("untitled") + .into(); + let offset = position.to_offset(&snapshot); + + let project_state = self.get_or_init_sweep_ai_project(project, cx); + let events = project_state.events.clone(); + let http_client = cx.http_client(); + + let Some(recent_buffers) = workspace + .read_with(cx, |workspace, cx| { + workspace + .recent_navigation_history_iter(cx) + .filter_map(|(project_path, _)| { + let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; + + if active_buffer == &buffer { + None + } else { + Some(buffer.read(cx).snapshot()) + } + }) + .take(3) + .collect::>() + }) + .log_err() + else { + return Task::ready(Ok(None)); + }; + + let result = cx.background_spawn({ + let full_path = full_path.clone(); + async move { + let text = snapshot.text(); + + let mut recent_changes = String::new(); + + for event in events { + writeln!(&mut recent_changes, "{event}")?; + } + + let file_chunks = recent_buffers + .into_iter() + .map(|snapshot| { + let end_point = language::Point::new(30, 0).min(snapshot.max_point()); + FileChunk { + content: snapshot + .text_for_range(language::Point::zero()..end_point) + .collect(), + file_path: snapshot + .file() + .map(|f| f.path().as_unix_str()) + .unwrap_or("untitled") + .to_string(), + start_line: 0, + end_line: end_point.row as usize, + timestamp: snapshot.file().and_then(|file| { + Some( + file.disk_state() + .mtime()? + .to_seconds_and_nanos_for_persistence()? + .0, + ) + }), + } + }) + .collect(); + + let request_body = AutocompleteRequest { + debug_info, + repo_name, + file_path: full_path.clone(), + file_contents: text.clone(), + original_file_contents: text, + cursor_position: offset, + recent_changes: recent_changes.clone(), + changes_above_cursor: true, + multiple_suggestions: false, + branch: None, + file_chunks, + retrieval_chunks: vec![], + recent_user_actions: vec![], + // TODO + privacy_mode_enabled: false, + }; + + let mut buf: Vec = Vec::new(); + let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); + serde_json::to_writer(writer, &request_body)?; + let body: AsyncBody = buf.into(); + + let request = http_client::Request::builder() + .uri(SWEEP_API_URL) + .header("Content-Type", "application/json") + .header( + "Authorization", + format!("Bearer {}", std::env::var("SWEEP_TOKEN").unwrap()), + ) + .header("Connection", "keep-alive") + .header("Content-Encoding", "br") + .method(Method::POST) + .body(body)?; + + let mut response = http_client.send(request).await?; + + let mut body: Vec = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let response: AutocompleteResponse = serde_json::from_slice(&body)?; + + let old_text = snapshot + .text_for_range(response.start_index..response.end_index) + .collect::(); + let edits = text_diff(&old_text, &response.completion) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(response.start_index + range.start) + ..snapshot.anchor_before(response.start_index + range.end), + text, + ) + }) + .collect::>(); + + anyhow::Ok((response.autocomplete_id, edits, snapshot)) + } + }); + + let buffer = active_buffer.clone(); + + cx.spawn(async move |_, cx| { + let (id, edits, old_snapshot) = result.await?; + + if edits.is_empty() { + return anyhow::Ok(None); + } + + let Some((edits, new_snapshot, preview_task)) = + buffer.read_with(cx, |buffer, cx| { + let new_snapshot = buffer.snapshot(); + + let edits: Arc<[(Range, Arc)]> = + edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)? + .into(); + let preview_task = buffer.preview_edits(edits.clone(), cx); + + Some((edits, new_snapshot, preview_task)) + })? + else { + return anyhow::Ok(None); + }; + + let prediction = EditPrediction { + id: EditPredictionId(id), + path: full_path, + edits, + snapshot: new_snapshot, + edit_preview: preview_task.await, + }; + + anyhow::Ok(Some(prediction)) + }) + } + + fn report_changes_for_buffer( + &mut self, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> BufferSnapshot { + let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx); + let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx); + + let new_snapshot = buffer.read(cx).snapshot(); + if new_snapshot.version != registered_buffer.snapshot.version { + let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); + Self::push_event( + sweep_ai_project, + Event::BufferChange { + old_snapshot, + new_snapshot: new_snapshot.clone(), + timestamp: Instant::now(), + }, + ); + } + + new_snapshot + } +} + +struct RegisteredBuffer { + snapshot: BufferSnapshot, + _subscriptions: [gpui::Subscription; 2], +} + +#[derive(Clone)] +pub enum Event { + BufferChange { + old_snapshot: BufferSnapshot, + new_snapshot: BufferSnapshot, + timestamp: Instant, + }, +} + +impl Display for Event { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Event::BufferChange { + old_snapshot, + new_snapshot, + .. + } => { + let old_path = old_snapshot + .file() + .map(|f| f.path().as_ref()) + .unwrap_or(RelPath::unix("untitled").unwrap()); + let new_path = new_snapshot + .file() + .map(|f| f.path().as_ref()) + .unwrap_or(RelPath::unix("untitled").unwrap()); + if old_path != new_path { + // TODO confirm how to do this for sweep + // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?; + } + + let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); + if !diff.is_empty() { + write!( + f, + "File: {}:\n{}\n", + new_path.display(util::paths::PathStyle::Posix), + diff + )? + } + + fmt::Result::Ok(()) + } + } + } +} + +#[derive(Debug, Clone)] +struct CurrentEditPrediction { + buffer_id: EntityId, + completion: EditPrediction, +} + +impl CurrentEditPrediction { + fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool { + if self.buffer_id != old_completion.buffer_id { + return true; + } + + let Some(old_edits) = old_completion.completion.interpolate(snapshot) else { + return true; + }; + let Some(new_edits) = self.completion.interpolate(snapshot) else { + return false; + }; + + if old_edits.len() == 1 && new_edits.len() == 1 { + let (old_range, old_text) = &old_edits[0]; + let (new_range, new_text) = &new_edits[0]; + new_range == old_range && new_text.starts_with(old_text.as_ref()) + } else { + true + } + } +} + +struct PendingCompletion { + id: usize, + _task: Task<()>, +} + +pub struct SweepAiEditPredictionProvider { + workspace: WeakEntity, + sweep_ai: Entity, + pending_completions: ArrayVec, + next_pending_completion_id: usize, + current_completion: Option, + last_request_timestamp: Instant, + project: Entity, +} + +impl SweepAiEditPredictionProvider { + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + + pub fn new( + sweep_ai: Entity, + workspace: WeakEntity, + project: Entity, + ) -> Self { + Self { + sweep_ai, + pending_completions: ArrayVec::new(), + next_pending_completion_id: 0, + current_completion: None, + last_request_timestamp: Instant::now(), + project, + workspace, + } + } +} + +impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider { + fn name() -> &'static str { + "zed-predict" + } + + fn display_name() -> &'static str { + "Zed's Edit Predictions" + } + + fn show_completions_in_menu() -> bool { + true + } + + fn show_tab_accept_marker() -> bool { + true + } + + fn is_enabled( + &self, + _buffer: &Entity, + _cursor_position: language::Anchor, + _cx: &App, + ) -> bool { + true + } + + fn is_refreshing(&self) -> bool { + !self.pending_completions.is_empty() + } + + fn refresh( + &mut self, + buffer: Entity, + position: language::Anchor, + _debounce: bool, + cx: &mut Context, + ) { + if let Some(current_completion) = self.current_completion.as_ref() { + let snapshot = buffer.read(cx).snapshot(); + if current_completion + .completion + .interpolate(&snapshot) + .is_some() + { + return; + } + } + + let pending_completion_id = self.next_pending_completion_id; + self.next_pending_completion_id += 1; + let last_request_timestamp = self.last_request_timestamp; + + let project = self.project.clone(); + let workspace = self.workspace.clone(); + let task = cx.spawn(async move |this, cx| { + if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) + .checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; + } + + let completion_request = this.update(cx, |this, cx| { + this.last_request_timestamp = Instant::now(); + this.sweep_ai.update(cx, |sweep_ai, cx| { + sweep_ai.request_completion(&workspace, &project, &buffer, position, cx) + }) + }); + + let completion = match completion_request { + Ok(completion_request) => { + let completion_request = completion_request.await; + completion_request.map(|c| { + c.map(|completion| CurrentEditPrediction { + buffer_id: buffer.entity_id(), + completion, + }) + }) + } + Err(error) => Err(error), + }; + + let Some(new_completion) = completion + .context("edit prediction failed") + .log_err() + .flatten() + else { + this.update(cx, |this, cx| { + if this.pending_completions[0].id == pending_completion_id { + this.pending_completions.remove(0); + } else { + this.pending_completions.clear(); + } + + cx.notify(); + }) + .ok(); + return; + }; + + this.update(cx, |this, cx| { + if this.pending_completions[0].id == pending_completion_id { + this.pending_completions.remove(0); + } else { + this.pending_completions.clear(); + } + + if let Some(old_completion) = this.current_completion.as_ref() { + let snapshot = buffer.read(cx).snapshot(); + if new_completion.should_replace_completion(old_completion, &snapshot) { + this.current_completion = Some(new_completion); + } + } else { + this.current_completion = Some(new_completion); + } + + cx.notify(); + }) + .ok(); + }); + + // We always maintain at most two pending completions. When we already + // have two, we replace the newest one. + if self.pending_completions.len() <= 1 { + self.pending_completions.push(PendingCompletion { + id: pending_completion_id, + _task: task, + }); + } else if self.pending_completions.len() == 2 { + self.pending_completions.pop(); + self.pending_completions.push(PendingCompletion { + id: pending_completion_id, + _task: task, + }); + } + } + + fn cycle( + &mut self, + _buffer: Entity, + _cursor_position: language::Anchor, + _direction: edit_prediction::Direction, + _cx: &mut Context, + ) { + // Right now we don't support cycling. + } + + fn accept(&mut self, _cx: &mut Context) { + self.pending_completions.clear(); + } + + fn discard(&mut self, _cx: &mut Context) { + self.pending_completions.clear(); + self.current_completion.take(); + } + + fn suggest( + &mut self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) -> Option { + let CurrentEditPrediction { + buffer_id, + completion, + .. + } = self.current_completion.as_mut()?; + + // Invalidate previous completion if it was generated for a different buffer. + if *buffer_id != buffer.entity_id() { + self.current_completion.take(); + return None; + } + + let buffer = buffer.read(cx); + let Some(edits) = completion.interpolate(&buffer.snapshot()) else { + self.current_completion.take(); + return None; + }; + + let cursor_row = cursor_position.to_point(buffer).row; + let (closest_edit_ix, (closest_edit_range, _)) = + edits.iter().enumerate().min_by_key(|(_, (range, _))| { + let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row); + let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row); + cmp::min(distance_from_start, distance_from_end) + })?; + + let mut edit_start_ix = closest_edit_ix; + for (range, _) in edits[..edit_start_ix].iter().rev() { + let distance_from_closest_edit = + closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row; + if distance_from_closest_edit <= 1 { + edit_start_ix -= 1; + } else { + break; + } + } + + let mut edit_end_ix = closest_edit_ix + 1; + for (range, _) in &edits[edit_end_ix..] { + let distance_from_closest_edit = + range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row; + if distance_from_closest_edit <= 1 { + edit_end_ix += 1; + } else { + break; + } + } + + Some(edit_prediction::EditPrediction::Local { + id: Some(completion.id.to_string().into()), + edits: edits[edit_start_ix..edit_end_ix].to_vec(), + edit_preview: Some(completion.edit_preview.clone()), + }) + } +} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index ca81955e33b524fe27b8777566473e89a03a5558..79892fefdd7776e2fd7f99cbfa6caf24bb174a4b 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -133,6 +133,7 @@ snippet_provider.workspace = true snippets_ui.workspace = true supermaven.workspace = true svg_preview.workspace = true +sweep_ai.workspace = true sysinfo.workspace = true tab_switcher.workspace = true task.workspace = true diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 74b6687f62c641ce4076778efa4369a45529f4f9..1723ca91f143c8529e14e24e0bdd85dc7b1c14d4 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -7,9 +7,10 @@ use feature_flags::FeatureFlagAppExt; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; use language_models::MistralLanguageModelProvider; -use settings::SettingsStore; +use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore}; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; +use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag}; use ui::Window; use zeta::ZetaEditPredictionProvider; use zeta2::Zeta2FeatureFlag; @@ -202,6 +203,38 @@ fn assign_edit_prediction_provider( let provider = cx.new(|_| CodestralCompletionProvider::new(http_client)); editor.set_edit_prediction_provider(Some(provider), window, cx); } + EditPredictionProvider::Experimental(name) => { + if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME + && cx.has_flag::() + { + if let Some(project) = editor.project() + && let Some(workspace) = editor.workspace() + { + let sweep_ai = sweep_ai::SweepAi::register(cx); + + if let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() + { + sweep_ai.update(cx, |sweep_ai, cx| { + sweep_ai.register_buffer(buffer, project, cx); + }); + } + + let provider = cx.new(|_| { + sweep_ai::SweepAiEditPredictionProvider::new( + sweep_ai, + workspace.downgrade(), + project.clone(), + ) + }); + editor.set_edit_prediction_provider(Some(provider), window, cx); + } + } else { + editor.set_edit_prediction_provider::( + None, window, cx, + ); + } + } EditPredictionProvider::Zed => { if user_store.read(cx).current_user().is_some() { let mut worktree = None;