Add experimental Sweep edit prediction provider (#42927)

Agus Zubiaga , Max Brunsfeld , and Ben Kunkle created

Only for staff

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

Cargo.lock                                                  |  29 
Cargo.toml                                                  |   3 
assets/icons/sweep_ai.svg                                   |   0 
crates/agent_ui/src/agent_ui.rs                             |   4 
crates/edit_prediction_button/Cargo.toml                    |   3 
crates/edit_prediction_button/src/edit_prediction_button.rs |  18 
crates/icons/src/icons.rs                                   |   1 
crates/reqwest_client/src/reqwest_client.rs                 |  22 
crates/settings/src/settings_content/language.rs            |  50 
crates/sweep_ai/Cargo.toml                                  |  43 
crates/sweep_ai/LICENSE-GPL                                 |   1 
crates/sweep_ai/src/api.rs                                  |  90 
crates/sweep_ai/src/sweep_ai.rs                             | 776 +++++++
crates/zed/Cargo.toml                                       |   1 
crates/zed/src/zed/edit_prediction_registry.rs              |  35 
15 files changed, 1,055 insertions(+), 21 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5314,6 +5314,7 @@ dependencies = [
  "serde_json",
  "settings",
  "supermaven",
+ "sweep_ai",
  "telemetry",
  "theme",
  "ui",
@@ -16590,6 +16591,33 @@ dependencies = [
  "zeno",
 ]
 
+[[package]]
+name = "sweep_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "arrayvec",
+ "brotli",
+ "client",
+ "collections",
+ "edit_prediction",
+ "feature_flags",
+ "futures 0.3.31",
+ "gpui",
+ "http_client",
+ "indoc",
+ "language",
+ "project",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "serde_json",
+ "tree-sitter-rust",
+ "util",
+ "workspace",
+ "zlog",
+]
+
 [[package]]
 name = "symphonia"
 version = "0.5.5"
@@ -21316,6 +21344,7 @@ dependencies = [
  "snippets_ui",
  "supermaven",
  "svg_preview",
+ "sweep_ai",
  "sysinfo 0.37.2",
  "system_specs",
  "tab_switcher",

Cargo.toml 🔗

@@ -165,6 +165,7 @@ members = [
     "crates/sum_tree",
     "crates/supermaven",
     "crates/supermaven_api",
+    "crates/sweep_ai",
     "crates/codestral",
     "crates/svg_preview",
     "crates/system_specs",
@@ -398,6 +399,7 @@ streaming_diff = { path = "crates/streaming_diff" }
 sum_tree = { path = "crates/sum_tree" }
 supermaven = { path = "crates/supermaven" }
 supermaven_api = { path = "crates/supermaven_api" }
+sweep_ai = { path = "crates/sweep_ai" }
 codestral = { path = "crates/codestral" }
 system_specs = { path = "crates/system_specs" }
 tab_switcher = { path = "crates/tab_switcher" }
@@ -478,6 +480,7 @@ bitflags = "2.6.0"
 blade-graphics = { version = "0.7.0" }
 blade-macros = { version = "0.3.0" }
 blade-util = { version = "0.3.0" }
+brotli = "8.0.2"
 bytes = "1.0"
 cargo_metadata = "0.19"
 cargo_toml = "0.21"

crates/agent_ui/src/agent_ui.rs 🔗

@@ -346,7 +346,9 @@ fn update_command_palette_filter(cx: &mut App) {
                     filter.show_namespace("supermaven");
                     filter.show_action_types(edit_prediction_actions.iter());
                 }
-                EditPredictionProvider::Zed | EditPredictionProvider::Codestral => {
+                EditPredictionProvider::Zed
+                | EditPredictionProvider::Codestral
+                | EditPredictionProvider::Experimental(_) => {
                     filter.show_namespace("edit_prediction");
                     filter.hide_namespace("copilot");
                     filter.hide_namespace("supermaven");

crates/edit_prediction_button/Cargo.toml 🔗

@@ -18,18 +18,19 @@ client.workspace = true
 cloud_llm_client.workspace = true
 codestral.workspace = true
 copilot.workspace = true
+edit_prediction.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
 fs.workspace = true
 gpui.workspace = true
 indoc.workspace = true
-edit_prediction.workspace = true
 language.workspace = true
 paths.workspace = true
 project.workspace = true
 regex.workspace = true
 settings.workspace = true
 supermaven.workspace = true
+sweep_ai.workspace = true
 telemetry.workspace = true
 ui.workspace = true
 workspace.workspace = true

crates/edit_prediction_button/src/edit_prediction_button.rs 🔗

@@ -18,12 +18,15 @@ use language::{
 };
 use project::DisableAiSettings;
 use regex::Regex;
-use settings::{Settings, SettingsStore, update_settings_file};
+use settings::{
+    EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file,
+};
 use std::{
     sync::{Arc, LazyLock},
     time::Duration,
 };
 use supermaven::{AccountStatus, Supermaven};
+use sweep_ai::SweepFeatureFlag;
 use ui::{
     Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton,
     IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*,
@@ -78,7 +81,7 @@ impl Render for EditPredictionButton {
 
         let all_language_settings = all_language_settings(None, cx);
 
-        match all_language_settings.edit_predictions.provider {
+        match &all_language_settings.edit_predictions.provider {
             EditPredictionProvider::None => div().hidden(),
 
             EditPredictionProvider::Copilot => {
@@ -297,6 +300,15 @@ impl Render for EditPredictionButton {
                         .with_handle(self.popover_menu_handle.clone()),
                 )
             }
+            EditPredictionProvider::Experimental(provider_name) => {
+                if *provider_name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+                    && cx.has_flag::<SweepFeatureFlag>()
+                {
+                    div().child(Icon::new(IconName::SweepAi))
+                } else {
+                    div()
+                }
+            }
 
             EditPredictionProvider::Zed => {
                 let enabled = self.editor_enabled.unwrap_or(true);
@@ -525,7 +537,7 @@ impl EditPredictionButton {
                             set_completion_provider(fs.clone(), cx, provider);
                         })
                     }
-                    EditPredictionProvider::None => continue,
+                    EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => continue,
                 };
             }
         }

crates/icons/src/icons.rs 🔗

@@ -217,6 +217,7 @@ pub enum IconName {
     SupermavenError,
     SupermavenInit,
     SwatchBook,
+    SweepAi,
     Tab,
     Terminal,
     TerminalAlt,

crates/reqwest_client/src/reqwest_client.rs 🔗

@@ -80,20 +80,22 @@ impl ReqwestClient {
     }
 }
 
+pub fn runtime() -> &'static tokio::runtime::Runtime {
+    RUNTIME.get_or_init(|| {
+        tokio::runtime::Builder::new_multi_thread()
+            // Since we now have two executors, let's try to keep our footprint small
+            .worker_threads(1)
+            .enable_all()
+            .build()
+            .expect("Failed to initialize HTTP client")
+    })
+}
+
 impl From<reqwest::Client> for ReqwestClient {
     fn from(client: reqwest::Client) -> Self {
         let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
             log::debug!("no tokio runtime found, creating one for Reqwest...");
-            let runtime = RUNTIME.get_or_init(|| {
-                tokio::runtime::Builder::new_multi_thread()
-                    // Since we now have two executors, let's try to keep our footprint small
-                    .worker_threads(1)
-                    .enable_all()
-                    .build()
-                    .expect("Failed to initialize HTTP client")
-            });
-
-            runtime.handle().clone()
+            runtime().handle().clone()
         });
         Self {
             client,

crates/settings/src/settings_content/language.rs 🔗

@@ -3,7 +3,7 @@ use std::num::NonZeroU32;
 use collections::{HashMap, HashSet};
 use gpui::{Modifiers, SharedString};
 use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Serialize, de::Error as _};
 use serde_with::skip_serializing_none;
 use settings_macros::MergeFrom;
 use std::sync::Arc;
@@ -68,9 +68,7 @@ pub struct FeaturesContent {
 }
 
 /// The provider that supplies edit predictions.
-#[derive(
-    Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom,
-)]
+#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, JsonSchema, MergeFrom)]
 #[serde(rename_all = "snake_case")]
 pub enum EditPredictionProvider {
     None,
@@ -79,6 +77,47 @@ pub enum EditPredictionProvider {
     Supermaven,
     Zed,
     Codestral,
+    Experimental(&'static str),
+}
+
+pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
+
+impl<'de> Deserialize<'de> for EditPredictionProvider {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        #[serde(rename_all = "snake_case")]
+        pub enum Content {
+            None,
+            Copilot,
+            Supermaven,
+            Zed,
+            Codestral,
+            Experimental(String),
+        }
+
+        Ok(match Content::deserialize(deserializer)? {
+            Content::None => EditPredictionProvider::None,
+            Content::Copilot => EditPredictionProvider::Copilot,
+            Content::Supermaven => EditPredictionProvider::Supermaven,
+            Content::Zed => EditPredictionProvider::Zed,
+            Content::Codestral => EditPredictionProvider::Codestral,
+            Content::Experimental(name) => {
+                if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME {
+                    EditPredictionProvider::Experimental(
+                        EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+                    )
+                } else {
+                    return Err(D::Error::custom(format!(
+                        "Unknown experimental edit prediction provider: {}",
+                        name
+                    )));
+                }
+            }
+        })
+    }
 }
 
 impl EditPredictionProvider {
@@ -88,7 +127,8 @@ impl EditPredictionProvider {
             EditPredictionProvider::None
             | EditPredictionProvider::Copilot
             | EditPredictionProvider::Supermaven
-            | EditPredictionProvider::Codestral => false,
+            | EditPredictionProvider::Codestral
+            | EditPredictionProvider::Experimental(_) => false,
         }
     }
 }

crates/sweep_ai/Cargo.toml 🔗

@@ -0,0 +1,43 @@
+[package]
+name = "sweep_ai"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+exclude = ["fixtures"]
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/sweep_ai.rs"
+doctest = false
+
+[dependencies]
+anyhow.workspace = true
+arrayvec.workspace = true
+brotli.workspace = true
+client.workspace = true
+collections.workspace = true
+edit_prediction.workspace = true
+feature_flags.workspace = true
+futures.workspace = true
+gpui.workspace = true
+http_client.workspace = true
+language.workspace = true
+project.workspace = true
+release_channel.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true
+workspace.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
+http_client = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
+language = { workspace = true, features = ["test-support"] }
+reqwest_client = { workspace = true, features = ["test-support"] }
+tree-sitter-rust.workspace = true
+workspace = { workspace = true, features = ["test-support"] }
+zlog.workspace = true

crates/sweep_ai/src/api.rs 🔗

@@ -0,0 +1,90 @@
+use std::{path::Path, sync::Arc};
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize)]
+pub struct AutocompleteRequest {
+    pub debug_info: Arc<str>,
+    pub repo_name: String,
+    pub branch: Option<String>,
+    pub file_path: Arc<Path>,
+    pub file_contents: String,
+    pub recent_changes: String,
+    pub cursor_position: usize,
+    pub original_file_contents: String,
+    pub file_chunks: Vec<FileChunk>,
+    pub retrieval_chunks: Vec<RetrievalChunk>,
+    pub recent_user_actions: Vec<UserAction>,
+    pub multiple_suggestions: bool,
+    pub privacy_mode_enabled: bool,
+    pub changes_above_cursor: bool,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct FileChunk {
+    pub file_path: String,
+    pub start_line: usize,
+    pub end_line: usize,
+    pub content: String,
+    pub timestamp: Option<u64>,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct RetrievalChunk {
+    pub file_path: String,
+    pub start_line: usize,
+    pub end_line: usize,
+    pub content: String,
+    pub timestamp: u64,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct UserAction {
+    pub action_type: ActionType,
+    pub line_number: usize,
+    pub offset: usize,
+    pub file_path: String,
+    pub timestamp: u64,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum ActionType {
+    CursorMovement,
+    InsertChar,
+    DeleteChar,
+    InsertSelection,
+    DeleteSelection,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct AutocompleteResponse {
+    pub autocomplete_id: String,
+    pub start_index: usize,
+    pub end_index: usize,
+    pub completion: String,
+    #[allow(dead_code)]
+    pub confidence: f64,
+    #[allow(dead_code)]
+    pub logprobs: Option<serde_json::Value>,
+    #[allow(dead_code)]
+    pub finish_reason: Option<String>,
+    #[allow(dead_code)]
+    pub elapsed_time_ms: u64,
+    #[allow(dead_code)]
+    #[serde(default, rename = "completions")]
+    pub additional_completions: Vec<AdditionalCompletion>,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Clone, Deserialize)]
+pub struct AdditionalCompletion {
+    pub start_index: usize,
+    pub end_index: usize,
+    pub completion: String,
+    pub confidence: f64,
+    pub autocomplete_id: String,
+    pub logprobs: Option<serde_json::Value>,
+    pub finish_reason: Option<String>,
+}

crates/sweep_ai/src/sweep_ai.rs 🔗

@@ -0,0 +1,776 @@
+mod api;
+
+use anyhow::{Context as _, Result};
+use arrayvec::ArrayVec;
+use client::telemetry;
+use collections::HashMap;
+use feature_flags::FeatureFlag;
+use futures::AsyncReadExt as _;
+use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
+use http_client::{AsyncBody, Method};
+use language::{Anchor, Buffer, BufferSnapshot, EditPreview, ToOffset as _, ToPoint, text_diff};
+use project::Project;
+use release_channel::{AppCommitSha, AppVersion};
+use std::collections::{VecDeque, hash_map};
+use std::fmt::{self, Display};
+use std::mem;
+use std::{
+    cmp,
+    fmt::Write,
+    ops::Range,
+    path::Path,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use util::ResultExt;
+use util::rel_path::RelPath;
+use workspace::Workspace;
+
+use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
+
+const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
+const MAX_EVENT_COUNT: usize = 16;
+
+const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+
+pub struct SweepFeatureFlag;
+
+impl FeatureFlag for SweepFeatureFlag {
+    const NAME: &str = "sweep-ai";
+
+    fn enabled_for_staff() -> bool {
+        false
+    }
+}
+
+#[derive(Clone)]
+struct SweepAiGlobal(Entity<SweepAi>);
+
+impl Global for SweepAiGlobal {}
+
+#[derive(Clone)]
+pub struct EditPrediction {
+    id: EditPredictionId,
+    path: Arc<Path>,
+    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+    snapshot: BufferSnapshot,
+    edit_preview: EditPreview,
+}
+
+impl EditPrediction {
+    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
+        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
+    }
+}
+
+impl fmt::Debug for EditPrediction {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("EditPrediction")
+            .field("path", &self.path)
+            .field("edits", &self.edits)
+            .finish_non_exhaustive()
+    }
+}
+
+#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct EditPredictionId(String);
+
+impl Display for EditPredictionId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+pub struct SweepAi {
+    projects: HashMap<EntityId, SweepAiProject>,
+    debug_info: Arc<str>,
+}
+
+struct SweepAiProject {
+    events: VecDeque<Event>,
+    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+}
+
+impl SweepAi {
+    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
+        cx.try_global::<SweepAiGlobal>()
+            .map(|global| global.0.clone())
+    }
+
+    pub fn register(cx: &mut App) -> Entity<Self> {
+        Self::global(cx).unwrap_or_else(|| {
+            let entity = cx.new(|cx| Self::new(cx));
+            cx.set_global(SweepAiGlobal(entity.clone()));
+            entity
+        })
+    }
+
+    pub fn clear_history(&mut self) {
+        for sweep_ai_project in self.projects.values_mut() {
+            sweep_ai_project.events.clear();
+        }
+    }
+
+    fn new(cx: &mut Context<Self>) -> Self {
+        Self {
+            projects: HashMap::default(),
+            debug_info: format!(
+                "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
+                version = AppVersion::global(cx),
+                sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
+                os = telemetry::os_name(),
+            )
+            .into(),
+        }
+    }
+
+    fn get_or_init_sweep_ai_project(
+        &mut self,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> &mut SweepAiProject {
+        let project_id = project.entity_id();
+        match self.projects.entry(project_id) {
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+            hash_map::Entry::Vacant(entry) => {
+                cx.observe_release(project, move |this, _, _cx| {
+                    this.projects.remove(&project_id);
+                })
+                .detach();
+                entry.insert(SweepAiProject {
+                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
+                    registered_buffers: HashMap::default(),
+                })
+            }
+        }
+    }
+
+    fn push_event(sweep_ai_project: &mut SweepAiProject, event: Event) {
+        let events = &mut sweep_ai_project.events;
+
+        if let Some(Event::BufferChange {
+            new_snapshot: last_new_snapshot,
+            timestamp: last_timestamp,
+            ..
+        }) = events.back_mut()
+        {
+            // Coalesce edits for the same buffer when they happen one after the other.
+            let Event::BufferChange {
+                old_snapshot,
+                new_snapshot,
+                timestamp,
+            } = &event;
+
+            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
+                && old_snapshot.version == last_new_snapshot.version
+            {
+                *last_new_snapshot = new_snapshot.clone();
+                *last_timestamp = *timestamp;
+                return;
+            }
+        }
+
+        if events.len() >= MAX_EVENT_COUNT {
+            // These are halved instead of popping to improve prompt caching.
+            events.drain(..MAX_EVENT_COUNT / 2);
+        }
+
+        events.push_back(event);
+    }
+
+    pub fn register_buffer(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) {
+        let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
+        Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
+    }
+
+    fn register_buffer_impl<'a>(
+        sweep_ai_project: &'a mut SweepAiProject,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> &'a mut RegisteredBuffer {
+        let buffer_id = buffer.entity_id();
+        match sweep_ai_project.registered_buffers.entry(buffer_id) {
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+            hash_map::Entry::Vacant(entry) => {
+                let snapshot = buffer.read(cx).snapshot();
+                let project_entity_id = project.entity_id();
+                entry.insert(RegisteredBuffer {
+                    snapshot,
+                    _subscriptions: [
+                        cx.subscribe(buffer, {
+                            let project = project.downgrade();
+                            move |this, buffer, event, cx| {
+                                if let language::BufferEvent::Edited = event
+                                    && let Some(project) = project.upgrade()
+                                {
+                                    this.report_changes_for_buffer(&buffer, &project, cx);
+                                }
+                            }
+                        }),
+                        cx.observe_release(buffer, move |this, _buffer, _cx| {
+                            let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
+                            else {
+                                return;
+                            };
+                            sweep_ai_project.registered_buffers.remove(&buffer_id);
+                        }),
+                    ],
+                })
+            }
+        }
+    }
+
+    pub fn request_completion(
+        &mut self,
+        workspace: &WeakEntity<Workspace>,
+        project: &Entity<Project>,
+        active_buffer: &Entity<Buffer>,
+        position: language::Anchor,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<Option<EditPrediction>>> {
+        let snapshot = active_buffer.read(cx).snapshot();
+        let debug_info = self.debug_info.clone();
+        let full_path: Arc<Path> = snapshot
+            .file()
+            .map(|file| file.full_path(cx))
+            .unwrap_or_else(|| "untitled".into())
+            .into();
+
+        let project_file = project::File::from_dyn(snapshot.file());
+        let repo_name = project_file
+            .map(|file| file.worktree.read(cx).root_name_str())
+            .unwrap_or("untitled")
+            .into();
+        let offset = position.to_offset(&snapshot);
+
+        let project_state = self.get_or_init_sweep_ai_project(project, cx);
+        let events = project_state.events.clone();
+        let http_client = cx.http_client();
+
+        let Some(recent_buffers) = workspace
+            .read_with(cx, |workspace, cx| {
+                workspace
+                    .recent_navigation_history_iter(cx)
+                    .filter_map(|(project_path, _)| {
+                        let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+
+                        if active_buffer == &buffer {
+                            None
+                        } else {
+                            Some(buffer.read(cx).snapshot())
+                        }
+                    })
+                    .take(3)
+                    .collect::<Vec<_>>()
+            })
+            .log_err()
+        else {
+            return Task::ready(Ok(None));
+        };
+
+        let result = cx.background_spawn({
+            let full_path = full_path.clone();
+            async move {
+                let text = snapshot.text();
+
+                let mut recent_changes = String::new();
+
+                for event in events {
+                    writeln!(&mut recent_changes, "{event}")?;
+                }
+
+                let file_chunks = recent_buffers
+                    .into_iter()
+                    .map(|snapshot| {
+                        let end_point = language::Point::new(30, 0).min(snapshot.max_point());
+                        FileChunk {
+                            content: snapshot
+                                .text_for_range(language::Point::zero()..end_point)
+                                .collect(),
+                            file_path: snapshot
+                                .file()
+                                .map(|f| f.path().as_unix_str())
+                                .unwrap_or("untitled")
+                                .to_string(),
+                            start_line: 0,
+                            end_line: end_point.row as usize,
+                            timestamp: snapshot.file().and_then(|file| {
+                                Some(
+                                    file.disk_state()
+                                        .mtime()?
+                                        .to_seconds_and_nanos_for_persistence()?
+                                        .0,
+                                )
+                            }),
+                        }
+                    })
+                    .collect();
+
+                let request_body = AutocompleteRequest {
+                    debug_info,
+                    repo_name,
+                    file_path: full_path.clone(),
+                    file_contents: text.clone(),
+                    original_file_contents: text,
+                    cursor_position: offset,
+                    recent_changes: recent_changes.clone(),
+                    changes_above_cursor: true,
+                    multiple_suggestions: false,
+                    branch: None,
+                    file_chunks,
+                    retrieval_chunks: vec![],
+                    recent_user_actions: vec![],
+                    // TODO
+                    privacy_mode_enabled: false,
+                };
+
+                let mut buf: Vec<u8> = Vec::new();
+                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+                serde_json::to_writer(writer, &request_body)?;
+                let body: AsyncBody = buf.into();
+
+                let request = http_client::Request::builder()
+                    .uri(SWEEP_API_URL)
+                    .header("Content-Type", "application/json")
+                    .header(
+                        "Authorization",
+                        format!("Bearer {}", std::env::var("SWEEP_TOKEN").unwrap()),
+                    )
+                    .header("Connection", "keep-alive")
+                    .header("Content-Encoding", "br")
+                    .method(Method::POST)
+                    .body(body)?;
+
+                let mut response = http_client.send(request).await?;
+
+                let mut body: Vec<u8> = Vec::new();
+                response.body_mut().read_to_end(&mut body).await?;
+
+                if !response.status().is_success() {
+                    anyhow::bail!(
+                        "Request failed with status: {:?}\nBody: {}",
+                        response.status(),
+                        String::from_utf8_lossy(&body),
+                    );
+                };
+
+                let response: AutocompleteResponse = serde_json::from_slice(&body)?;
+
+                let old_text = snapshot
+                    .text_for_range(response.start_index..response.end_index)
+                    .collect::<String>();
+                let edits = text_diff(&old_text, &response.completion)
+                    .into_iter()
+                    .map(|(range, text)| {
+                        (
+                            snapshot.anchor_after(response.start_index + range.start)
+                                ..snapshot.anchor_before(response.start_index + range.end),
+                            text,
+                        )
+                    })
+                    .collect::<Vec<_>>();
+
+                anyhow::Ok((response.autocomplete_id, edits, snapshot))
+            }
+        });
+
+        let buffer = active_buffer.clone();
+
+        cx.spawn(async move |_, cx| {
+            let (id, edits, old_snapshot) = result.await?;
+
+            if edits.is_empty() {
+                return anyhow::Ok(None);
+            }
+
+            let Some((edits, new_snapshot, preview_task)) =
+                buffer.read_with(cx, |buffer, cx| {
+                    let new_snapshot = buffer.snapshot();
+
+                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
+                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
+                            .into();
+                    let preview_task = buffer.preview_edits(edits.clone(), cx);
+
+                    Some((edits, new_snapshot, preview_task))
+                })?
+            else {
+                return anyhow::Ok(None);
+            };
+
+            let prediction = EditPrediction {
+                id: EditPredictionId(id),
+                path: full_path,
+                edits,
+                snapshot: new_snapshot,
+                edit_preview: preview_task.await,
+            };
+
+            anyhow::Ok(Some(prediction))
+        })
+    }
+
+    fn report_changes_for_buffer(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> BufferSnapshot {
+        let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
+        let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
+
+        let new_snapshot = buffer.read(cx).snapshot();
+        if new_snapshot.version != registered_buffer.snapshot.version {
+            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+            Self::push_event(
+                sweep_ai_project,
+                Event::BufferChange {
+                    old_snapshot,
+                    new_snapshot: new_snapshot.clone(),
+                    timestamp: Instant::now(),
+                },
+            );
+        }
+
+        new_snapshot
+    }
+}
+
+struct RegisteredBuffer {
+    snapshot: BufferSnapshot,
+    _subscriptions: [gpui::Subscription; 2],
+}
+
+#[derive(Clone)]
+pub enum Event {
+    BufferChange {
+        old_snapshot: BufferSnapshot,
+        new_snapshot: BufferSnapshot,
+        timestamp: Instant,
+    },
+}
+
+impl Display for Event {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Event::BufferChange {
+                old_snapshot,
+                new_snapshot,
+                ..
+            } => {
+                let old_path = old_snapshot
+                    .file()
+                    .map(|f| f.path().as_ref())
+                    .unwrap_or(RelPath::unix("untitled").unwrap());
+                let new_path = new_snapshot
+                    .file()
+                    .map(|f| f.path().as_ref())
+                    .unwrap_or(RelPath::unix("untitled").unwrap());
+                if old_path != new_path {
+                    // TODO confirm how to do this for sweep
+                    // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
+                }
+
+                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
+                if !diff.is_empty() {
+                    write!(
+                        f,
+                        "File: {}:\n{}\n",
+                        new_path.display(util::paths::PathStyle::Posix),
+                        diff
+                    )?
+                }
+
+                fmt::Result::Ok(())
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct CurrentEditPrediction {
+    buffer_id: EntityId,
+    completion: EditPrediction,
+}
+
+impl CurrentEditPrediction {
+    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
+        if self.buffer_id != old_completion.buffer_id {
+            return true;
+        }
+
+        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
+            return true;
+        };
+        let Some(new_edits) = self.completion.interpolate(snapshot) else {
+            return false;
+        };
+
+        if old_edits.len() == 1 && new_edits.len() == 1 {
+            let (old_range, old_text) = &old_edits[0];
+            let (new_range, new_text) = &new_edits[0];
+            new_range == old_range && new_text.starts_with(old_text.as_ref())
+        } else {
+            true
+        }
+    }
+}
+
+struct PendingCompletion {
+    id: usize,
+    _task: Task<()>,
+}
+
+pub struct SweepAiEditPredictionProvider {
+    workspace: WeakEntity<Workspace>,
+    sweep_ai: Entity<SweepAi>,
+    pending_completions: ArrayVec<PendingCompletion, 2>,
+    next_pending_completion_id: usize,
+    current_completion: Option<CurrentEditPrediction>,
+    last_request_timestamp: Instant,
+    project: Entity<Project>,
+}
+
+impl SweepAiEditPredictionProvider {
+    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+
+    pub fn new(
+        sweep_ai: Entity<SweepAi>,
+        workspace: WeakEntity<Workspace>,
+        project: Entity<Project>,
+    ) -> Self {
+        Self {
+            sweep_ai,
+            pending_completions: ArrayVec::new(),
+            next_pending_completion_id: 0,
+            current_completion: None,
+            last_request_timestamp: Instant::now(),
+            project,
+            workspace,
+        }
+    }
+}
+
+impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
+    fn name() -> &'static str {
+        "zed-predict"
+    }
+
+    fn display_name() -> &'static str {
+        "Zed's Edit Predictions"
+    }
+
+    fn show_completions_in_menu() -> bool {
+        true
+    }
+
+    fn show_tab_accept_marker() -> bool {
+        true
+    }
+
+    fn is_enabled(
+        &self,
+        _buffer: &Entity<Buffer>,
+        _cursor_position: language::Anchor,
+        _cx: &App,
+    ) -> bool {
+        true
+    }
+
+    fn is_refreshing(&self) -> bool {
+        !self.pending_completions.is_empty()
+    }
+
+    fn refresh(
+        &mut self,
+        buffer: Entity<Buffer>,
+        position: language::Anchor,
+        _debounce: bool,
+        cx: &mut Context<Self>,
+    ) {
+        if let Some(current_completion) = self.current_completion.as_ref() {
+            let snapshot = buffer.read(cx).snapshot();
+            if current_completion
+                .completion
+                .interpolate(&snapshot)
+                .is_some()
+            {
+                return;
+            }
+        }
+
+        let pending_completion_id = self.next_pending_completion_id;
+        self.next_pending_completion_id += 1;
+        let last_request_timestamp = self.last_request_timestamp;
+
+        let project = self.project.clone();
+        let workspace = self.workspace.clone();
+        let task = cx.spawn(async move |this, cx| {
+            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
+                .checked_duration_since(Instant::now())
+            {
+                cx.background_executor().timer(timeout).await;
+            }
+
+            let completion_request = this.update(cx, |this, cx| {
+                this.last_request_timestamp = Instant::now();
+                this.sweep_ai.update(cx, |sweep_ai, cx| {
+                    sweep_ai.request_completion(&workspace, &project, &buffer, position, cx)
+                })
+            });
+
+            let completion = match completion_request {
+                Ok(completion_request) => {
+                    let completion_request = completion_request.await;
+                    completion_request.map(|c| {
+                        c.map(|completion| CurrentEditPrediction {
+                            buffer_id: buffer.entity_id(),
+                            completion,
+                        })
+                    })
+                }
+                Err(error) => Err(error),
+            };
+
+            let Some(new_completion) = completion
+                .context("edit prediction failed")
+                .log_err()
+                .flatten()
+            else {
+                this.update(cx, |this, cx| {
+                    if this.pending_completions[0].id == pending_completion_id {
+                        this.pending_completions.remove(0);
+                    } else {
+                        this.pending_completions.clear();
+                    }
+
+                    cx.notify();
+                })
+                .ok();
+                return;
+            };
+
+            this.update(cx, |this, cx| {
+                if this.pending_completions[0].id == pending_completion_id {
+                    this.pending_completions.remove(0);
+                } else {
+                    this.pending_completions.clear();
+                }
+
+                if let Some(old_completion) = this.current_completion.as_ref() {
+                    let snapshot = buffer.read(cx).snapshot();
+                    if new_completion.should_replace_completion(old_completion, &snapshot) {
+                        this.current_completion = Some(new_completion);
+                    }
+                } else {
+                    this.current_completion = Some(new_completion);
+                }
+
+                cx.notify();
+            })
+            .ok();
+        });
+
+        // We always maintain at most two pending completions. When we already
+        // have two, we replace the newest one.
+        if self.pending_completions.len() <= 1 {
+            self.pending_completions.push(PendingCompletion {
+                id: pending_completion_id,
+                _task: task,
+            });
+        } else if self.pending_completions.len() == 2 {
+            self.pending_completions.pop();
+            self.pending_completions.push(PendingCompletion {
+                id: pending_completion_id,
+                _task: task,
+            });
+        }
+    }
+
+    fn cycle(
+        &mut self,
+        _buffer: Entity<Buffer>,
+        _cursor_position: language::Anchor,
+        _direction: edit_prediction::Direction,
+        _cx: &mut Context<Self>,
+    ) {
+        // Right now we don't support cycling.
+    }
+
+    fn accept(&mut self, _cx: &mut Context<Self>) {
+        self.pending_completions.clear();
+    }
+
+    fn discard(&mut self, _cx: &mut Context<Self>) {
+        self.pending_completions.clear();
+        self.current_completion.take();
+    }
+
+    fn suggest(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        cursor_position: language::Anchor,
+        cx: &mut Context<Self>,
+    ) -> Option<edit_prediction::EditPrediction> {
+        let CurrentEditPrediction {
+            buffer_id,
+            completion,
+            ..
+        } = self.current_completion.as_mut()?;
+
+        // Invalidate previous completion if it was generated for a different buffer.
+        if *buffer_id != buffer.entity_id() {
+            self.current_completion.take();
+            return None;
+        }
+
+        let buffer = buffer.read(cx);
+        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
+            self.current_completion.take();
+            return None;
+        };
+
+        let cursor_row = cursor_position.to_point(buffer).row;
+        let (closest_edit_ix, (closest_edit_range, _)) =
+            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
+                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
+                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
+                cmp::min(distance_from_start, distance_from_end)
+            })?;
+
+        let mut edit_start_ix = closest_edit_ix;
+        for (range, _) in edits[..edit_start_ix].iter().rev() {
+            let distance_from_closest_edit =
+                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
+            if distance_from_closest_edit <= 1 {
+                edit_start_ix -= 1;
+            } else {
+                break;
+            }
+        }
+
+        let mut edit_end_ix = closest_edit_ix + 1;
+        for (range, _) in &edits[edit_end_ix..] {
+            let distance_from_closest_edit =
+                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
+            if distance_from_closest_edit <= 1 {
+                edit_end_ix += 1;
+            } else {
+                break;
+            }
+        }
+
+        Some(edit_prediction::EditPrediction::Local {
+            id: Some(completion.id.to_string().into()),
+            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
+            edit_preview: Some(completion.edit_preview.clone()),
+        })
+    }
+}

crates/zed/Cargo.toml 🔗

@@ -133,6 +133,7 @@ snippet_provider.workspace = true
 snippets_ui.workspace = true
 supermaven.workspace = true
 svg_preview.workspace = true
+sweep_ai.workspace = true
 sysinfo.workspace = true
 tab_switcher.workspace = true
 task.workspace = true

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -7,9 +7,10 @@ use feature_flags::FeatureFlagAppExt;
 use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
 use language::language_settings::{EditPredictionProvider, all_language_settings};
 use language_models::MistralLanguageModelProvider;
-use settings::SettingsStore;
+use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
 use std::{cell::RefCell, rc::Rc, sync::Arc};
 use supermaven::{Supermaven, SupermavenCompletionProvider};
+use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag};
 use ui::Window;
 use zeta::ZetaEditPredictionProvider;
 use zeta2::Zeta2FeatureFlag;
@@ -202,6 +203,38 @@ fn assign_edit_prediction_provider(
             let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
             editor.set_edit_prediction_provider(Some(provider), window, cx);
         }
+        EditPredictionProvider::Experimental(name) => {
+            if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+                && cx.has_flag::<SweepFeatureFlag>()
+            {
+                if let Some(project) = editor.project()
+                    && let Some(workspace) = editor.workspace()
+                {
+                    let sweep_ai = sweep_ai::SweepAi::register(cx);
+
+                    if let Some(buffer) = &singleton_buffer
+                        && buffer.read(cx).file().is_some()
+                    {
+                        sweep_ai.update(cx, |sweep_ai, cx| {
+                            sweep_ai.register_buffer(buffer, project, cx);
+                        });
+                    }
+
+                    let provider = cx.new(|_| {
+                        sweep_ai::SweepAiEditPredictionProvider::new(
+                            sweep_ai,
+                            workspace.downgrade(),
+                            project.clone(),
+                        )
+                    });
+                    editor.set_edit_prediction_provider(Some(provider), window, cx);
+                }
+            } else {
+                editor.set_edit_prediction_provider::<SweepAiEditPredictionProvider>(
+                    None, window, cx,
+                );
+            }
+        }
         EditPredictionProvider::Zed => {
             if user_store.read(cx).current_user().is_some() {
                 let mut worktree = None;