diff --git a/Cargo.lock b/Cargo.lock
index 9aa674cbb69aaa52df5466dda41d8b9c2d9be5b1..c0f6ef03c296306a73264461a8767ccd1b346c20 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5314,6 +5314,7 @@ dependencies = [
"serde_json",
"settings",
"supermaven",
+ "sweep_ai",
"telemetry",
"theme",
"ui",
@@ -16590,6 +16591,33 @@ dependencies = [
"zeno",
]
+[[package]]
+name = "sweep_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "arrayvec",
+ "brotli",
+ "client",
+ "collections",
+ "edit_prediction",
+ "feature_flags",
+ "futures 0.3.31",
+ "gpui",
+ "http_client",
+ "indoc",
+ "language",
+ "project",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "serde_json",
+ "tree-sitter-rust",
+ "util",
+ "workspace",
+ "zlog",
+]
+
[[package]]
name = "symphonia"
version = "0.5.5"
@@ -21316,6 +21344,7 @@ dependencies = [
"snippets_ui",
"supermaven",
"svg_preview",
+ "sweep_ai",
"sysinfo 0.37.2",
"system_specs",
"tab_switcher",
diff --git a/Cargo.toml b/Cargo.toml
index be56964f753cded4b1e054583b989f798c3ca1e3..e74647c6320f149d8eadad08ff3624859fe76624 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -165,6 +165,7 @@ members = [
"crates/sum_tree",
"crates/supermaven",
"crates/supermaven_api",
+ "crates/sweep_ai",
"crates/codestral",
"crates/svg_preview",
"crates/system_specs",
@@ -398,6 +399,7 @@ streaming_diff = { path = "crates/streaming_diff" }
sum_tree = { path = "crates/sum_tree" }
supermaven = { path = "crates/supermaven" }
supermaven_api = { path = "crates/supermaven_api" }
+sweep_ai = { path = "crates/sweep_ai" }
codestral = { path = "crates/codestral" }
system_specs = { path = "crates/system_specs" }
tab_switcher = { path = "crates/tab_switcher" }
@@ -478,6 +480,7 @@ bitflags = "2.6.0"
blade-graphics = { version = "0.7.0" }
blade-macros = { version = "0.3.0" }
blade-util = { version = "0.3.0" }
+brotli = "8.0.2"
bytes = "1.0"
cargo_metadata = "0.19"
cargo_toml = "0.21"
diff --git a/assets/icons/sweep_ai.svg b/assets/icons/sweep_ai.svg
new file mode 100644
index 0000000000000000000000000000000000000000..c78d12727d78ddcc2a86bdb3e46349752cadaf7d
--- /dev/null
+++ b/assets/icons/sweep_ai.svg
@@ -0,0 +1 @@
+
diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs
index da1543a2790599fbe590f4e29d3594588bd2f351..6396b68cbc5f805466618bd460f9ed46ce05d086 100644
--- a/crates/agent_ui/src/agent_ui.rs
+++ b/crates/agent_ui/src/agent_ui.rs
@@ -346,7 +346,9 @@ fn update_command_palette_filter(cx: &mut App) {
filter.show_namespace("supermaven");
filter.show_action_types(edit_prediction_actions.iter());
}
- EditPredictionProvider::Zed | EditPredictionProvider::Codestral => {
+ EditPredictionProvider::Zed
+ | EditPredictionProvider::Codestral
+ | EditPredictionProvider::Experimental(_) => {
filter.show_namespace("edit_prediction");
filter.hide_namespace("copilot");
filter.hide_namespace("supermaven");
diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml
index 189db7f7bac3eaea36a154424c4e7702f1387d24..3ed3d9411510ad2d978b221d8cb3412465a66879 100644
--- a/crates/edit_prediction_button/Cargo.toml
+++ b/crates/edit_prediction_button/Cargo.toml
@@ -18,18 +18,19 @@ client.workspace = true
cloud_llm_client.workspace = true
codestral.workspace = true
copilot.workspace = true
+edit_prediction.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
gpui.workspace = true
indoc.workspace = true
-edit_prediction.workspace = true
language.workspace = true
paths.workspace = true
project.workspace = true
regex.workspace = true
settings.workspace = true
supermaven.workspace = true
+sweep_ai.workspace = true
telemetry.workspace = true
ui.workspace = true
workspace.workspace = true
diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs
index 685d408205863e5ad110a5e57891c0695f998cfb..51f228db76aaee5e286dd950c17dd01b303d29b8 100644
--- a/crates/edit_prediction_button/src/edit_prediction_button.rs
+++ b/crates/edit_prediction_button/src/edit_prediction_button.rs
@@ -18,12 +18,15 @@ use language::{
};
use project::DisableAiSettings;
use regex::Regex;
-use settings::{Settings, SettingsStore, update_settings_file};
+use settings::{
+ EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file,
+};
use std::{
sync::{Arc, LazyLock},
time::Duration,
};
use supermaven::{AccountStatus, Supermaven};
+use sweep_ai::SweepFeatureFlag;
use ui::{
Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton,
IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*,
@@ -78,7 +81,7 @@ impl Render for EditPredictionButton {
let all_language_settings = all_language_settings(None, cx);
- match all_language_settings.edit_predictions.provider {
+ match &all_language_settings.edit_predictions.provider {
EditPredictionProvider::None => div().hidden(),
EditPredictionProvider::Copilot => {
@@ -297,6 +300,15 @@ impl Render for EditPredictionButton {
.with_handle(self.popover_menu_handle.clone()),
)
}
+ EditPredictionProvider::Experimental(provider_name) => {
+ if *provider_name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::()
+ {
+ div().child(Icon::new(IconName::SweepAi))
+ } else {
+ div()
+ }
+ }
EditPredictionProvider::Zed => {
let enabled = self.editor_enabled.unwrap_or(true);
@@ -525,7 +537,7 @@ impl EditPredictionButton {
set_completion_provider(fs.clone(), cx, provider);
})
}
- EditPredictionProvider::None => continue,
+ EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => continue,
};
}
}
diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs
index a0865773ac394722c113a43fe323de218b2f145a..da3b298751d9c1921d14722490e3cbc680292099 100644
--- a/crates/icons/src/icons.rs
+++ b/crates/icons/src/icons.rs
@@ -217,6 +217,7 @@ pub enum IconName {
SupermavenError,
SupermavenInit,
SwatchBook,
+ SweepAi,
Tab,
Terminal,
TerminalAlt,
diff --git a/crates/reqwest_client/src/reqwest_client.rs b/crates/reqwest_client/src/reqwest_client.rs
index c2a58877b32ab6049edc5b50f7ad025f0c83f46e..4213a239ec813f255139a97770a74608371fb73e 100644
--- a/crates/reqwest_client/src/reqwest_client.rs
+++ b/crates/reqwest_client/src/reqwest_client.rs
@@ -80,20 +80,22 @@ impl ReqwestClient {
}
}
+pub fn runtime() -> &'static tokio::runtime::Runtime {
+ RUNTIME.get_or_init(|| {
+ tokio::runtime::Builder::new_multi_thread()
+ // Since we now have two executors, let's try to keep our footprint small
+ .worker_threads(1)
+ .enable_all()
+ .build()
+ .expect("Failed to initialize HTTP client")
+ })
+}
+
impl From for ReqwestClient {
fn from(client: reqwest::Client) -> Self {
let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
log::debug!("no tokio runtime found, creating one for Reqwest...");
- let runtime = RUNTIME.get_or_init(|| {
- tokio::runtime::Builder::new_multi_thread()
- // Since we now have two executors, let's try to keep our footprint small
- .worker_threads(1)
- .enable_all()
- .build()
- .expect("Failed to initialize HTTP client")
- });
-
- runtime.handle().clone()
+ runtime().handle().clone()
});
Self {
client,
diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs
index fc11dd4956a50906951af8fa43a7dacc61568f70..ed70116862bbda6af22d4027a406535ae0c19d67 100644
--- a/crates/settings/src/settings_content/language.rs
+++ b/crates/settings/src/settings_content/language.rs
@@ -3,7 +3,7 @@ use std::num::NonZeroU32;
use collections::{HashMap, HashSet};
use gpui::{Modifiers, SharedString};
use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Serialize, de::Error as _};
use serde_with::skip_serializing_none;
use settings_macros::MergeFrom;
use std::sync::Arc;
@@ -68,9 +68,7 @@ pub struct FeaturesContent {
}
/// The provider that supplies edit predictions.
-#[derive(
- Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom,
-)]
+#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, JsonSchema, MergeFrom)]
#[serde(rename_all = "snake_case")]
pub enum EditPredictionProvider {
None,
@@ -79,6 +77,47 @@ pub enum EditPredictionProvider {
Supermaven,
Zed,
Codestral,
+ Experimental(&'static str),
+}
+
+pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
+
+impl<'de> Deserialize<'de> for EditPredictionProvider {
+ fn deserialize(deserializer: D) -> Result
+ where
+ D: serde::Deserializer<'de>,
+ {
+ #[derive(Deserialize)]
+ #[serde(rename_all = "snake_case")]
+ pub enum Content {
+ None,
+ Copilot,
+ Supermaven,
+ Zed,
+ Codestral,
+ Experimental(String),
+ }
+
+ Ok(match Content::deserialize(deserializer)? {
+ Content::None => EditPredictionProvider::None,
+ Content::Copilot => EditPredictionProvider::Copilot,
+ Content::Supermaven => EditPredictionProvider::Supermaven,
+ Content::Zed => EditPredictionProvider::Zed,
+ Content::Codestral => EditPredictionProvider::Codestral,
+ Content::Experimental(name) => {
+ if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME {
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+ )
+ } else {
+ return Err(D::Error::custom(format!(
+ "Unknown experimental edit prediction provider: {}",
+ name
+ )));
+ }
+ }
+ })
+ }
}
impl EditPredictionProvider {
@@ -88,7 +127,8 @@ impl EditPredictionProvider {
EditPredictionProvider::None
| EditPredictionProvider::Copilot
| EditPredictionProvider::Supermaven
- | EditPredictionProvider::Codestral => false,
+ | EditPredictionProvider::Codestral
+ | EditPredictionProvider::Experimental(_) => false,
}
}
}
diff --git a/crates/sweep_ai/Cargo.toml b/crates/sweep_ai/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..4edf7ea1bb6af9a6657ccfe310c0253b118ec2e7
--- /dev/null
+++ b/crates/sweep_ai/Cargo.toml
@@ -0,0 +1,43 @@
+[package]
+name = "sweep_ai"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+exclude = ["fixtures"]
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/sweep_ai.rs"
+doctest = false
+
+[dependencies]
+anyhow.workspace = true
+arrayvec.workspace = true
+brotli.workspace = true
+client.workspace = true
+collections.workspace = true
+edit_prediction.workspace = true
+feature_flags.workspace = true
+futures.workspace = true
+gpui.workspace = true
+http_client.workspace = true
+language.workspace = true
+project.workspace = true
+release_channel.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true
+workspace.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
+http_client = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
+language = { workspace = true, features = ["test-support"] }
+reqwest_client = { workspace = true, features = ["test-support"] }
+tree-sitter-rust.workspace = true
+workspace = { workspace = true, features = ["test-support"] }
+zlog.workspace = true
diff --git a/crates/sweep_ai/LICENSE-GPL b/crates/sweep_ai/LICENSE-GPL
new file mode 120000
index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4
--- /dev/null
+++ b/crates/sweep_ai/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/sweep_ai/src/api.rs b/crates/sweep_ai/src/api.rs
new file mode 100644
index 0000000000000000000000000000000000000000..edb392885e476e3924d285613af1f0a4e8be8599
--- /dev/null
+++ b/crates/sweep_ai/src/api.rs
@@ -0,0 +1,90 @@
+use std::{path::Path, sync::Arc};
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize)]
+pub struct AutocompleteRequest {
+ pub debug_info: Arc,
+ pub repo_name: String,
+ pub branch: Option,
+ pub file_path: Arc,
+ pub file_contents: String,
+ pub recent_changes: String,
+ pub cursor_position: usize,
+ pub original_file_contents: String,
+ pub file_chunks: Vec,
+ pub retrieval_chunks: Vec,
+ pub recent_user_actions: Vec,
+ pub multiple_suggestions: bool,
+ pub privacy_mode_enabled: bool,
+ pub changes_above_cursor: bool,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct FileChunk {
+ pub file_path: String,
+ pub start_line: usize,
+ pub end_line: usize,
+ pub content: String,
+ pub timestamp: Option,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct RetrievalChunk {
+ pub file_path: String,
+ pub start_line: usize,
+ pub end_line: usize,
+ pub content: String,
+ pub timestamp: u64,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct UserAction {
+ pub action_type: ActionType,
+ pub line_number: usize,
+ pub offset: usize,
+ pub file_path: String,
+ pub timestamp: u64,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum ActionType {
+ CursorMovement,
+ InsertChar,
+ DeleteChar,
+ InsertSelection,
+ DeleteSelection,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct AutocompleteResponse {
+ pub autocomplete_id: String,
+ pub start_index: usize,
+ pub end_index: usize,
+ pub completion: String,
+ #[allow(dead_code)]
+ pub confidence: f64,
+ #[allow(dead_code)]
+ pub logprobs: Option,
+ #[allow(dead_code)]
+ pub finish_reason: Option,
+ #[allow(dead_code)]
+ pub elapsed_time_ms: u64,
+ #[allow(dead_code)]
+ #[serde(default, rename = "completions")]
+ pub additional_completions: Vec,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, Deserialize)]
+pub struct AdditionalCompletion {
+ pub start_index: usize,
+ pub end_index: usize,
+ pub completion: String,
+ pub confidence: f64,
+ pub autocomplete_id: String,
+ pub logprobs: Option,
+ pub finish_reason: Option,
+}
diff --git a/crates/sweep_ai/src/sweep_ai.rs b/crates/sweep_ai/src/sweep_ai.rs
new file mode 100644
index 0000000000000000000000000000000000000000..e8a2522c0b34896ad09fd8a8d346e2ba31c9a1e7
--- /dev/null
+++ b/crates/sweep_ai/src/sweep_ai.rs
@@ -0,0 +1,776 @@
+mod api;
+
+use anyhow::{Context as _, Result};
+use arrayvec::ArrayVec;
+use client::telemetry;
+use collections::HashMap;
+use feature_flags::FeatureFlag;
+use futures::AsyncReadExt as _;
+use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
+use http_client::{AsyncBody, Method};
+use language::{Anchor, Buffer, BufferSnapshot, EditPreview, ToOffset as _, ToPoint, text_diff};
+use project::Project;
+use release_channel::{AppCommitSha, AppVersion};
+use std::collections::{VecDeque, hash_map};
+use std::fmt::{self, Display};
+use std::mem;
+use std::{
+ cmp,
+ fmt::Write,
+ ops::Range,
+ path::Path,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+use util::ResultExt;
+use util::rel_path::RelPath;
+use workspace::Workspace;
+
+use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
+
+const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
+const MAX_EVENT_COUNT: usize = 16;
+
+const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+
+pub struct SweepFeatureFlag;
+
+impl FeatureFlag for SweepFeatureFlag {
+ const NAME: &str = "sweep-ai";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
+}
+
+#[derive(Clone)]
+struct SweepAiGlobal(Entity);
+
+impl Global for SweepAiGlobal {}
+
+#[derive(Clone)]
+pub struct EditPrediction {
+ id: EditPredictionId,
+ path: Arc,
+ edits: Arc<[(Range, Arc)]>,
+ snapshot: BufferSnapshot,
+ edit_preview: EditPreview,
+}
+
+impl EditPrediction {
+ fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> {
+ edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
+ }
+}
+
+impl fmt::Debug for EditPrediction {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("EditPrediction")
+ .field("path", &self.path)
+ .field("edits", &self.edits)
+ .finish_non_exhaustive()
+ }
+}
+
+#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct EditPredictionId(String);
+
+impl Display for EditPredictionId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+pub struct SweepAi {
+ projects: HashMap,
+ debug_info: Arc,
+}
+
+struct SweepAiProject {
+ events: VecDeque,
+ registered_buffers: HashMap,
+}
+
+impl SweepAi {
+ pub fn global(cx: &mut App) -> Option> {
+ cx.try_global::()
+ .map(|global| global.0.clone())
+ }
+
+ pub fn register(cx: &mut App) -> Entity {
+ Self::global(cx).unwrap_or_else(|| {
+ let entity = cx.new(|cx| Self::new(cx));
+ cx.set_global(SweepAiGlobal(entity.clone()));
+ entity
+ })
+ }
+
+ pub fn clear_history(&mut self) {
+ for sweep_ai_project in self.projects.values_mut() {
+ sweep_ai_project.events.clear();
+ }
+ }
+
+ fn new(cx: &mut Context) -> Self {
+ Self {
+ projects: HashMap::default(),
+ debug_info: format!(
+ "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
+ version = AppVersion::global(cx),
+ sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
+ os = telemetry::os_name(),
+ )
+ .into(),
+ }
+ }
+
+ fn get_or_init_sweep_ai_project(
+ &mut self,
+ project: &Entity,
+ cx: &mut Context,
+ ) -> &mut SweepAiProject {
+ let project_id = project.entity_id();
+ match self.projects.entry(project_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ cx.observe_release(project, move |this, _, _cx| {
+ this.projects.remove(&project_id);
+ })
+ .detach();
+ entry.insert(SweepAiProject {
+ events: VecDeque::with_capacity(MAX_EVENT_COUNT),
+ registered_buffers: HashMap::default(),
+ })
+ }
+ }
+ }
+
+ fn push_event(sweep_ai_project: &mut SweepAiProject, event: Event) {
+ let events = &mut sweep_ai_project.events;
+
+ if let Some(Event::BufferChange {
+ new_snapshot: last_new_snapshot,
+ timestamp: last_timestamp,
+ ..
+ }) = events.back_mut()
+ {
+ // Coalesce edits for the same buffer when they happen one after the other.
+ let Event::BufferChange {
+ old_snapshot,
+ new_snapshot,
+ timestamp,
+ } = &event;
+
+ if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+ && old_snapshot.remote_id() == last_new_snapshot.remote_id()
+ && old_snapshot.version == last_new_snapshot.version
+ {
+ *last_new_snapshot = new_snapshot.clone();
+ *last_timestamp = *timestamp;
+ return;
+ }
+ }
+
+ if events.len() >= MAX_EVENT_COUNT {
+ // These are halved instead of popping to improve prompt caching.
+ events.drain(..MAX_EVENT_COUNT / 2);
+ }
+
+ events.push_back(event);
+ }
+
+ pub fn register_buffer(
+ &mut self,
+ buffer: &Entity,
+ project: &Entity,
+ cx: &mut Context,
+ ) {
+ let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
+ Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
+ }
+
+ fn register_buffer_impl<'a>(
+ sweep_ai_project: &'a mut SweepAiProject,
+ buffer: &Entity,
+ project: &Entity,
+ cx: &mut Context,
+ ) -> &'a mut RegisteredBuffer {
+ let buffer_id = buffer.entity_id();
+ match sweep_ai_project.registered_buffers.entry(buffer_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let snapshot = buffer.read(cx).snapshot();
+ let project_entity_id = project.entity_id();
+ entry.insert(RegisteredBuffer {
+ snapshot,
+ _subscriptions: [
+ cx.subscribe(buffer, {
+ let project = project.downgrade();
+ move |this, buffer, event, cx| {
+ if let language::BufferEvent::Edited = event
+ && let Some(project) = project.upgrade()
+ {
+ this.report_changes_for_buffer(&buffer, &project, cx);
+ }
+ }
+ }),
+ cx.observe_release(buffer, move |this, _buffer, _cx| {
+ let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ sweep_ai_project.registered_buffers.remove(&buffer_id);
+ }),
+ ],
+ })
+ }
+ }
+ }
+
+ pub fn request_completion(
+ &mut self,
+ workspace: &WeakEntity,
+ project: &Entity,
+ active_buffer: &Entity,
+ position: language::Anchor,
+ cx: &mut Context,
+ ) -> Task>> {
+ let snapshot = active_buffer.read(cx).snapshot();
+ let debug_info = self.debug_info.clone();
+ let full_path: Arc = snapshot
+ .file()
+ .map(|file| file.full_path(cx))
+ .unwrap_or_else(|| "untitled".into())
+ .into();
+
+ let project_file = project::File::from_dyn(snapshot.file());
+ let repo_name = project_file
+ .map(|file| file.worktree.read(cx).root_name_str())
+ .unwrap_or("untitled")
+ .into();
+ let offset = position.to_offset(&snapshot);
+
+ let project_state = self.get_or_init_sweep_ai_project(project, cx);
+ let events = project_state.events.clone();
+ let http_client = cx.http_client();
+
+ let Some(recent_buffers) = workspace
+ .read_with(cx, |workspace, cx| {
+ workspace
+ .recent_navigation_history_iter(cx)
+ .filter_map(|(project_path, _)| {
+ let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+
+ if active_buffer == &buffer {
+ None
+ } else {
+ Some(buffer.read(cx).snapshot())
+ }
+ })
+ .take(3)
+ .collect::>()
+ })
+ .log_err()
+ else {
+ return Task::ready(Ok(None));
+ };
+
+ let result = cx.background_spawn({
+ let full_path = full_path.clone();
+ async move {
+ let text = snapshot.text();
+
+ let mut recent_changes = String::new();
+
+ for event in events {
+ writeln!(&mut recent_changes, "{event}")?;
+ }
+
+ let file_chunks = recent_buffers
+ .into_iter()
+ .map(|snapshot| {
+ let end_point = language::Point::new(30, 0).min(snapshot.max_point());
+ FileChunk {
+ content: snapshot
+ .text_for_range(language::Point::zero()..end_point)
+ .collect(),
+ file_path: snapshot
+ .file()
+ .map(|f| f.path().as_unix_str())
+ .unwrap_or("untitled")
+ .to_string(),
+ start_line: 0,
+ end_line: end_point.row as usize,
+ timestamp: snapshot.file().and_then(|file| {
+ Some(
+ file.disk_state()
+ .mtime()?
+ .to_seconds_and_nanos_for_persistence()?
+ .0,
+ )
+ }),
+ }
+ })
+ .collect();
+
+ let request_body = AutocompleteRequest {
+ debug_info,
+ repo_name,
+ file_path: full_path.clone(),
+ file_contents: text.clone(),
+ original_file_contents: text,
+ cursor_position: offset,
+ recent_changes: recent_changes.clone(),
+ changes_above_cursor: true,
+ multiple_suggestions: false,
+ branch: None,
+ file_chunks,
+ retrieval_chunks: vec![],
+ recent_user_actions: vec![],
+ // TODO
+ privacy_mode_enabled: false,
+ };
+
+ let mut buf: Vec = Vec::new();
+ let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+ serde_json::to_writer(writer, &request_body)?;
+ let body: AsyncBody = buf.into();
+
+ let request = http_client::Request::builder()
+ .uri(SWEEP_API_URL)
+ .header("Content-Type", "application/json")
+ .header(
+ "Authorization",
+ format!("Bearer {}", std::env::var("SWEEP_TOKEN").unwrap()),
+ )
+ .header("Connection", "keep-alive")
+ .header("Content-Encoding", "br")
+ .method(Method::POST)
+ .body(body)?;
+
+ let mut response = http_client.send(request).await?;
+
+ let mut body: Vec = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+
+ if !response.status().is_success() {
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ String::from_utf8_lossy(&body),
+ );
+ };
+
+ let response: AutocompleteResponse = serde_json::from_slice(&body)?;
+
+ let old_text = snapshot
+ .text_for_range(response.start_index..response.end_index)
+ .collect::();
+ let edits = text_diff(&old_text, &response.completion)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(response.start_index + range.start)
+ ..snapshot.anchor_before(response.start_index + range.end),
+ text,
+ )
+ })
+ .collect::>();
+
+ anyhow::Ok((response.autocomplete_id, edits, snapshot))
+ }
+ });
+
+ let buffer = active_buffer.clone();
+
+ cx.spawn(async move |_, cx| {
+ let (id, edits, old_snapshot) = result.await?;
+
+ if edits.is_empty() {
+ return anyhow::Ok(None);
+ }
+
+ let Some((edits, new_snapshot, preview_task)) =
+ buffer.read_with(cx, |buffer, cx| {
+ let new_snapshot = buffer.snapshot();
+
+ let edits: Arc<[(Range, Arc)]> =
+ edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
+ .into();
+ let preview_task = buffer.preview_edits(edits.clone(), cx);
+
+ Some((edits, new_snapshot, preview_task))
+ })?
+ else {
+ return anyhow::Ok(None);
+ };
+
+ let prediction = EditPrediction {
+ id: EditPredictionId(id),
+ path: full_path,
+ edits,
+ snapshot: new_snapshot,
+ edit_preview: preview_task.await,
+ };
+
+ anyhow::Ok(Some(prediction))
+ })
+ }
+
+ fn report_changes_for_buffer(
+ &mut self,
+ buffer: &Entity,
+ project: &Entity,
+ cx: &mut Context,
+ ) -> BufferSnapshot {
+ let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
+
+ let new_snapshot = buffer.read(cx).snapshot();
+ if new_snapshot.version != registered_buffer.snapshot.version {
+ let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+ Self::push_event(
+ sweep_ai_project,
+ Event::BufferChange {
+ old_snapshot,
+ new_snapshot: new_snapshot.clone(),
+ timestamp: Instant::now(),
+ },
+ );
+ }
+
+ new_snapshot
+ }
+}
+
+struct RegisteredBuffer {
+ snapshot: BufferSnapshot,
+ _subscriptions: [gpui::Subscription; 2],
+}
+
+#[derive(Clone)]
+pub enum Event {
+ BufferChange {
+ old_snapshot: BufferSnapshot,
+ new_snapshot: BufferSnapshot,
+ timestamp: Instant,
+ },
+}
+
+impl Display for Event {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Event::BufferChange {
+ old_snapshot,
+ new_snapshot,
+ ..
+ } => {
+ let old_path = old_snapshot
+ .file()
+ .map(|f| f.path().as_ref())
+ .unwrap_or(RelPath::unix("untitled").unwrap());
+ let new_path = new_snapshot
+ .file()
+ .map(|f| f.path().as_ref())
+ .unwrap_or(RelPath::unix("untitled").unwrap());
+ if old_path != new_path {
+ // TODO confirm how to do this for sweep
+ // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
+ }
+
+ let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
+ if !diff.is_empty() {
+ write!(
+ f,
+ "File: {}:\n{}\n",
+ new_path.display(util::paths::PathStyle::Posix),
+ diff
+ )?
+ }
+
+ fmt::Result::Ok(())
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct CurrentEditPrediction {
+ buffer_id: EntityId,
+ completion: EditPrediction,
+}
+
+impl CurrentEditPrediction {
+ fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
+ if self.buffer_id != old_completion.buffer_id {
+ return true;
+ }
+
+ let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
+ return true;
+ };
+ let Some(new_edits) = self.completion.interpolate(snapshot) else {
+ return false;
+ };
+
+ if old_edits.len() == 1 && new_edits.len() == 1 {
+ let (old_range, old_text) = &old_edits[0];
+ let (new_range, new_text) = &new_edits[0];
+ new_range == old_range && new_text.starts_with(old_text.as_ref())
+ } else {
+ true
+ }
+ }
+}
+
+struct PendingCompletion {
+ id: usize,
+ _task: Task<()>,
+}
+
+pub struct SweepAiEditPredictionProvider {
+ workspace: WeakEntity,
+ sweep_ai: Entity,
+ pending_completions: ArrayVec,
+ next_pending_completion_id: usize,
+ current_completion: Option,
+ last_request_timestamp: Instant,
+ project: Entity,
+}
+
+impl SweepAiEditPredictionProvider {
+ pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+
+ pub fn new(
+ sweep_ai: Entity,
+ workspace: WeakEntity,
+ project: Entity,
+ ) -> Self {
+ Self {
+ sweep_ai,
+ pending_completions: ArrayVec::new(),
+ next_pending_completion_id: 0,
+ current_completion: None,
+ last_request_timestamp: Instant::now(),
+ project,
+ workspace,
+ }
+ }
+}
+
+impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
+ fn name() -> &'static str {
+ "zed-predict"
+ }
+
+ fn display_name() -> &'static str {
+ "Zed's Edit Predictions"
+ }
+
+ fn show_completions_in_menu() -> bool {
+ true
+ }
+
+ fn show_tab_accept_marker() -> bool {
+ true
+ }
+
+ fn is_enabled(
+ &self,
+ _buffer: &Entity,
+ _cursor_position: language::Anchor,
+ _cx: &App,
+ ) -> bool {
+ true
+ }
+
+ fn is_refreshing(&self) -> bool {
+ !self.pending_completions.is_empty()
+ }
+
+ fn refresh(
+ &mut self,
+ buffer: Entity,
+ position: language::Anchor,
+ _debounce: bool,
+ cx: &mut Context,
+ ) {
+ if let Some(current_completion) = self.current_completion.as_ref() {
+ let snapshot = buffer.read(cx).snapshot();
+ if current_completion
+ .completion
+ .interpolate(&snapshot)
+ .is_some()
+ {
+ return;
+ }
+ }
+
+ let pending_completion_id = self.next_pending_completion_id;
+ self.next_pending_completion_id += 1;
+ let last_request_timestamp = self.last_request_timestamp;
+
+ let project = self.project.clone();
+ let workspace = self.workspace.clone();
+ let task = cx.spawn(async move |this, cx| {
+ if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
+ .checked_duration_since(Instant::now())
+ {
+ cx.background_executor().timer(timeout).await;
+ }
+
+ let completion_request = this.update(cx, |this, cx| {
+ this.last_request_timestamp = Instant::now();
+ this.sweep_ai.update(cx, |sweep_ai, cx| {
+ sweep_ai.request_completion(&workspace, &project, &buffer, position, cx)
+ })
+ });
+
+ let completion = match completion_request {
+ Ok(completion_request) => {
+ let completion_request = completion_request.await;
+ completion_request.map(|c| {
+ c.map(|completion| CurrentEditPrediction {
+ buffer_id: buffer.entity_id(),
+ completion,
+ })
+ })
+ }
+ Err(error) => Err(error),
+ };
+
+ let Some(new_completion) = completion
+ .context("edit prediction failed")
+ .log_err()
+ .flatten()
+ else {
+ this.update(cx, |this, cx| {
+ if this.pending_completions[0].id == pending_completion_id {
+ this.pending_completions.remove(0);
+ } else {
+ this.pending_completions.clear();
+ }
+
+ cx.notify();
+ })
+ .ok();
+ return;
+ };
+
+ this.update(cx, |this, cx| {
+ if this.pending_completions[0].id == pending_completion_id {
+ this.pending_completions.remove(0);
+ } else {
+ this.pending_completions.clear();
+ }
+
+ if let Some(old_completion) = this.current_completion.as_ref() {
+ let snapshot = buffer.read(cx).snapshot();
+ if new_completion.should_replace_completion(old_completion, &snapshot) {
+ this.current_completion = Some(new_completion);
+ }
+ } else {
+ this.current_completion = Some(new_completion);
+ }
+
+ cx.notify();
+ })
+ .ok();
+ });
+
+ // We always maintain at most two pending completions. When we already
+ // have two, we replace the newest one.
+ if self.pending_completions.len() <= 1 {
+ self.pending_completions.push(PendingCompletion {
+ id: pending_completion_id,
+ _task: task,
+ });
+ } else if self.pending_completions.len() == 2 {
+ self.pending_completions.pop();
+ self.pending_completions.push(PendingCompletion {
+ id: pending_completion_id,
+ _task: task,
+ });
+ }
+ }
+
+ fn cycle(
+ &mut self,
+ _buffer: Entity,
+ _cursor_position: language::Anchor,
+ _direction: edit_prediction::Direction,
+ _cx: &mut Context,
+ ) {
+ // Right now we don't support cycling.
+ }
+
+ fn accept(&mut self, _cx: &mut Context) {
+ self.pending_completions.clear();
+ }
+
+ fn discard(&mut self, _cx: &mut Context) {
+ self.pending_completions.clear();
+ self.current_completion.take();
+ }
+
+ fn suggest(
+ &mut self,
+ buffer: &Entity,
+ cursor_position: language::Anchor,
+ cx: &mut Context,
+ ) -> Option {
+ let CurrentEditPrediction {
+ buffer_id,
+ completion,
+ ..
+ } = self.current_completion.as_mut()?;
+
+ // Invalidate previous completion if it was generated for a different buffer.
+ if *buffer_id != buffer.entity_id() {
+ self.current_completion.take();
+ return None;
+ }
+
+ let buffer = buffer.read(cx);
+ let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
+ self.current_completion.take();
+ return None;
+ };
+
+ let cursor_row = cursor_position.to_point(buffer).row;
+ let (closest_edit_ix, (closest_edit_range, _)) =
+ edits.iter().enumerate().min_by_key(|(_, (range, _))| {
+ let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
+ let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
+ cmp::min(distance_from_start, distance_from_end)
+ })?;
+
+ let mut edit_start_ix = closest_edit_ix;
+ for (range, _) in edits[..edit_start_ix].iter().rev() {
+ let distance_from_closest_edit =
+ closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
+ if distance_from_closest_edit <= 1 {
+ edit_start_ix -= 1;
+ } else {
+ break;
+ }
+ }
+
+ let mut edit_end_ix = closest_edit_ix + 1;
+ for (range, _) in &edits[edit_end_ix..] {
+ let distance_from_closest_edit =
+ range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
+ if distance_from_closest_edit <= 1 {
+ edit_end_ix += 1;
+ } else {
+ break;
+ }
+ }
+
+ Some(edit_prediction::EditPrediction::Local {
+ id: Some(completion.id.to_string().into()),
+ edits: edits[edit_start_ix..edit_end_ix].to_vec(),
+ edit_preview: Some(completion.edit_preview.clone()),
+ })
+ }
+}
diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml
index ca81955e33b524fe27b8777566473e89a03a5558..79892fefdd7776e2fd7f99cbfa6caf24bb174a4b 100644
--- a/crates/zed/Cargo.toml
+++ b/crates/zed/Cargo.toml
@@ -133,6 +133,7 @@ snippet_provider.workspace = true
snippets_ui.workspace = true
supermaven.workspace = true
svg_preview.workspace = true
+sweep_ai.workspace = true
sysinfo.workspace = true
tab_switcher.workspace = true
task.workspace = true
diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs
index 74b6687f62c641ce4076778efa4369a45529f4f9..1723ca91f143c8529e14e24e0bdd85dc7b1c14d4 100644
--- a/crates/zed/src/zed/edit_prediction_registry.rs
+++ b/crates/zed/src/zed/edit_prediction_registry.rs
@@ -7,9 +7,10 @@ use feature_flags::FeatureFlagAppExt;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::MistralLanguageModelProvider;
-use settings::SettingsStore;
+use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenCompletionProvider};
+use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag};
use ui::Window;
use zeta::ZetaEditPredictionProvider;
use zeta2::Zeta2FeatureFlag;
@@ -202,6 +203,38 @@ fn assign_edit_prediction_provider(
let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
+ EditPredictionProvider::Experimental(name) => {
+ if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::()
+ {
+ if let Some(project) = editor.project()
+ && let Some(workspace) = editor.workspace()
+ {
+ let sweep_ai = sweep_ai::SweepAi::register(cx);
+
+ if let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
+ {
+ sweep_ai.update(cx, |sweep_ai, cx| {
+ sweep_ai.register_buffer(buffer, project, cx);
+ });
+ }
+
+ let provider = cx.new(|_| {
+ sweep_ai::SweepAiEditPredictionProvider::new(
+ sweep_ai,
+ workspace.downgrade(),
+ project.clone(),
+ )
+ });
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
+ }
+ } else {
+ editor.set_edit_prediction_provider::(
+ None, window, cx,
+ );
+ }
+ }
EditPredictionProvider::Zed => {
if user_store.read(cx).current_user().is_some() {
let mut worktree = None;