zeta2: Merge Sweep and Zeta2 Providers (#43097)

Ben Kunkle and Max Brunsfeld created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

Cargo.lock                                                  |  32 
Cargo.toml                                                  |   2 
crates/edit_prediction_button/Cargo.toml                    |   2 
crates/edit_prediction_button/src/edit_prediction_button.rs |   2 
crates/sweep_ai/Cargo.toml                                  |  43 
crates/sweep_ai/LICENSE-GPL                                 |   1 
crates/sweep_ai/src/sweep_ai.rs                             | 784 -------
crates/zed/Cargo.toml                                       |   1 
crates/zed/src/zed/edit_prediction_registry.rs              |  71 
crates/zeta2/Cargo.toml                                     |   1 
crates/zeta2/src/provider.rs                                |  11 
crates/zeta2/src/sweep_ai.rs                                |  48 
crates/zeta2/src/zeta2.rs                                   | 358 ++
crates/zeta_cli/Cargo.toml                                  |   1 
crates/zeta_cli/src/evaluate.rs                             |  68 
crates/zeta_cli/src/main.rs                                 |   2 
crates/zeta_cli/src/predict.rs                              | 334 +-
17 files changed, 558 insertions(+), 1,203 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5314,13 +5314,13 @@ dependencies = [
  "serde_json",
  "settings",
  "supermaven",
- "sweep_ai",
  "telemetry",
  "theme",
  "ui",
  "workspace",
  "zed_actions",
  "zeta",
+ "zeta2",
 ]
 
 [[package]]
@@ -16590,33 +16590,6 @@ dependencies = [
  "zeno",
 ]
 
-[[package]]
-name = "sweep_ai"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "arrayvec",
- "brotli",
- "client",
- "collections",
- "edit_prediction",
- "feature_flags",
- "futures 0.3.31",
- "gpui",
- "http_client",
- "indoc",
- "language",
- "project",
- "release_channel",
- "reqwest_client",
- "serde",
- "serde_json",
- "tree-sitter-rust",
- "util",
- "workspace",
- "zlog",
-]
-
 [[package]]
 name = "symphonia"
 version = "0.5.5"
@@ -21343,7 +21316,6 @@ dependencies = [
  "snippets_ui",
  "supermaven",
  "svg_preview",
- "sweep_ai",
  "sysinfo 0.37.2",
  "system_specs",
  "tab_switcher",
@@ -21754,6 +21726,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "arrayvec",
+ "brotli",
  "chrono",
  "client",
  "clock",
@@ -21864,7 +21837,6 @@ dependencies = [
  "shellexpand 2.1.2",
  "smol",
  "soa-rs",
- "sweep_ai",
  "terminal_view",
  "toml 0.8.23",
  "util",

Cargo.toml 🔗

@@ -165,7 +165,6 @@ members = [
     "crates/sum_tree",
     "crates/supermaven",
     "crates/supermaven_api",
-    "crates/sweep_ai",
     "crates/codestral",
     "crates/svg_preview",
     "crates/system_specs",
@@ -399,7 +398,6 @@ streaming_diff = { path = "crates/streaming_diff" }
 sum_tree = { path = "crates/sum_tree" }
 supermaven = { path = "crates/supermaven" }
 supermaven_api = { path = "crates/supermaven_api" }
-sweep_ai = { path = "crates/sweep_ai" }
 codestral = { path = "crates/codestral" }
 system_specs = { path = "crates/system_specs" }
 tab_switcher = { path = "crates/tab_switcher" }

crates/edit_prediction_button/Cargo.toml 🔗

@@ -30,12 +30,12 @@ project.workspace = true
 regex.workspace = true
 settings.workspace = true
 supermaven.workspace = true
-sweep_ai.workspace = true
 telemetry.workspace = true
 ui.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
 zeta.workspace = true
+zeta2.workspace = true
 
 [dev-dependencies]
 copilot = { workspace = true, features = ["test-support"] }

crates/edit_prediction_button/src/edit_prediction_button.rs 🔗

@@ -28,7 +28,6 @@ use std::{
     time::Duration,
 };
 use supermaven::{AccountStatus, Supermaven};
-use sweep_ai::SweepFeatureFlag;
 use ui::{
     Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton,
     IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*,
@@ -39,6 +38,7 @@ use workspace::{
 };
 use zed_actions::OpenBrowser;
 use zeta::RateCompletions;
+use zeta2::SweepFeatureFlag;
 
 actions!(
     edit_prediction,

crates/sweep_ai/Cargo.toml 🔗

@@ -1,43 +0,0 @@
-[package]
-name = "sweep_ai"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-exclude = ["fixtures"]
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/sweep_ai.rs"
-doctest = false
-
-[dependencies]
-anyhow.workspace = true
-arrayvec.workspace = true
-brotli.workspace = true
-client.workspace = true
-collections.workspace = true
-edit_prediction.workspace = true
-feature_flags.workspace = true
-futures.workspace = true
-gpui.workspace = true
-http_client.workspace = true
-language.workspace = true
-project.workspace = true
-release_channel.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-util.workspace = true
-workspace.workspace = true
-
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
-http_client = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
-language = { workspace = true, features = ["test-support"] }
-reqwest_client = { workspace = true, features = ["test-support"] }
-tree-sitter-rust.workspace = true
-workspace = { workspace = true, features = ["test-support"] }
-zlog.workspace = true

crates/sweep_ai/src/sweep_ai.rs 🔗

@@ -1,784 +0,0 @@
-mod api;
-
-use anyhow::{Context as _, Result};
-use arrayvec::ArrayVec;
-use client::telemetry;
-use collections::HashMap;
-use feature_flags::FeatureFlag;
-use futures::AsyncReadExt as _;
-use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
-use http_client::{AsyncBody, Method};
-use language::{
-    Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff,
-};
-use project::{Project, ProjectPath};
-use release_channel::{AppCommitSha, AppVersion};
-use std::collections::{VecDeque, hash_map};
-use std::fmt::{self, Display};
-use std::mem;
-use std::{
-    cmp,
-    fmt::Write,
-    ops::Range,
-    path::Path,
-    sync::Arc,
-    time::{Duration, Instant},
-};
-use util::ResultExt;
-use util::rel_path::RelPath;
-use workspace::Workspace;
-
-use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
-
-const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
-const MAX_EVENT_COUNT: usize = 6;
-
-const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
-
-pub struct SweepFeatureFlag;
-
-impl FeatureFlag for SweepFeatureFlag {
-    const NAME: &str = "sweep-ai";
-}
-
-#[derive(Clone)]
-struct SweepAiGlobal(Entity<SweepAi>);
-
-impl Global for SweepAiGlobal {}
-
-#[derive(Clone)]
-pub struct EditPrediction {
-    pub id: EditPredictionId,
-    pub path: Arc<Path>,
-    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
-    pub snapshot: BufferSnapshot,
-    pub edit_preview: EditPreview,
-}
-
-impl EditPrediction {
-    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
-        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
-    }
-}
-
-impl fmt::Debug for EditPrediction {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.debug_struct("EditPrediction")
-            .field("path", &self.path)
-            .field("edits", &self.edits)
-            .finish_non_exhaustive()
-    }
-}
-
-#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct EditPredictionId(String);
-
-impl Display for EditPredictionId {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "{}", self.0)
-    }
-}
-
-pub struct SweepAi {
-    projects: HashMap<EntityId, SweepAiProject>,
-    debug_info: Arc<str>,
-    api_token: Option<String>,
-}
-
-struct SweepAiProject {
-    events: VecDeque<Event>,
-    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
-}
-
-impl SweepAi {
-    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
-        cx.try_global::<SweepAiGlobal>()
-            .map(|global| global.0.clone())
-    }
-
-    pub fn register(cx: &mut App) -> Entity<Self> {
-        Self::global(cx).unwrap_or_else(|| {
-            let entity = cx.new(|cx| Self::new(cx));
-            cx.set_global(SweepAiGlobal(entity.clone()));
-            entity
-        })
-    }
-
-    pub fn clear_history(&mut self) {
-        for sweep_ai_project in self.projects.values_mut() {
-            sweep_ai_project.events.clear();
-        }
-    }
-
-    pub fn new(cx: &mut Context<Self>) -> Self {
-        Self {
-            api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
-            projects: HashMap::default(),
-            debug_info: format!(
-                "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
-                version = AppVersion::global(cx),
-                sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
-                os = telemetry::os_name(),
-            )
-            .into(),
-        }
-    }
-
-    fn get_or_init_sweep_ai_project(
-        &mut self,
-        project: &Entity<Project>,
-        cx: &mut Context<Self>,
-    ) -> &mut SweepAiProject {
-        let project_id = project.entity_id();
-        match self.projects.entry(project_id) {
-            hash_map::Entry::Occupied(entry) => entry.into_mut(),
-            hash_map::Entry::Vacant(entry) => {
-                cx.observe_release(project, move |this, _, _cx| {
-                    this.projects.remove(&project_id);
-                })
-                .detach();
-                entry.insert(SweepAiProject {
-                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
-                    registered_buffers: HashMap::default(),
-                })
-            }
-        }
-    }
-
-    pub fn register_buffer(
-        &mut self,
-        buffer: &Entity<Buffer>,
-        project: &Entity<Project>,
-        cx: &mut Context<Self>,
-    ) {
-        let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
-        Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
-    }
-
-    fn register_buffer_impl<'a>(
-        sweep_ai_project: &'a mut SweepAiProject,
-        buffer: &Entity<Buffer>,
-        project: &Entity<Project>,
-        cx: &mut Context<Self>,
-    ) -> &'a mut RegisteredBuffer {
-        let buffer_id = buffer.entity_id();
-        match sweep_ai_project.registered_buffers.entry(buffer_id) {
-            hash_map::Entry::Occupied(entry) => entry.into_mut(),
-            hash_map::Entry::Vacant(entry) => {
-                let snapshot = buffer.read(cx).snapshot();
-                let project_entity_id = project.entity_id();
-                entry.insert(RegisteredBuffer {
-                    snapshot,
-                    _subscriptions: [
-                        cx.subscribe(buffer, {
-                            let project = project.downgrade();
-                            move |this, buffer, event, cx| {
-                                if let language::BufferEvent::Edited = event
-                                    && let Some(project) = project.upgrade()
-                                {
-                                    this.report_changes_for_buffer(&buffer, &project, cx);
-                                }
-                            }
-                        }),
-                        cx.observe_release(buffer, move |this, _buffer, _cx| {
-                            let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
-                            else {
-                                return;
-                            };
-                            sweep_ai_project.registered_buffers.remove(&buffer_id);
-                        }),
-                    ],
-                })
-            }
-        }
-    }
-
-    pub fn request_completion(
-        &mut self,
-        project: &Entity<Project>,
-        recent_buffers: impl Iterator<Item = ProjectPath>,
-        active_buffer: &Entity<Buffer>,
-        position: language::Anchor,
-        cx: &mut Context<Self>,
-    ) -> Task<Result<Option<EditPrediction>>> {
-        let snapshot = active_buffer.read(cx).snapshot();
-        let debug_info = self.debug_info.clone();
-        let Some(api_token) = self.api_token.clone() else {
-            return Task::ready(Ok(None));
-        };
-        let full_path: Arc<Path> = snapshot
-            .file()
-            .map(|file| file.full_path(cx))
-            .unwrap_or_else(|| "untitled".into())
-            .into();
-
-        let project_file = project::File::from_dyn(snapshot.file());
-        let repo_name = project_file
-            .map(|file| file.worktree.read(cx).root_name_str())
-            .unwrap_or("untitled")
-            .into();
-        let offset = position.to_offset(&snapshot);
-
-        let project_state = self.get_or_init_sweep_ai_project(project, cx);
-        let events = project_state.events.clone();
-        let http_client = cx.http_client();
-
-        let recent_buffer_snapshots = recent_buffers
-            .filter_map(|project_path| {
-                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
-                if active_buffer == &buffer {
-                    None
-                } else {
-                    Some(buffer.read(cx).snapshot())
-                }
-            })
-            .take(3)
-            .collect::<Vec<_>>();
-
-        let result = cx.background_spawn({
-            let full_path = full_path.clone();
-            async move {
-                let text = snapshot.text();
-
-                let mut recent_changes = String::new();
-
-                for event in events {
-                    writeln!(&mut recent_changes, "{event}")?;
-                }
-
-                let file_chunks = recent_buffer_snapshots
-                    .into_iter()
-                    .map(|snapshot| {
-                        let end_point = language::Point::new(30, 0).min(snapshot.max_point());
-                        FileChunk {
-                            content: snapshot
-                                .text_for_range(language::Point::zero()..end_point)
-                                .collect(),
-                            file_path: snapshot
-                                .file()
-                                .map(|f| f.path().as_unix_str())
-                                .unwrap_or("untitled")
-                                .to_string(),
-                            start_line: 0,
-                            end_line: end_point.row as usize,
-                            timestamp: snapshot.file().and_then(|file| {
-                                Some(
-                                    file.disk_state()
-                                        .mtime()?
-                                        .to_seconds_and_nanos_for_persistence()?
-                                        .0,
-                                )
-                            }),
-                        }
-                    })
-                    .collect();
-
-                eprintln!("{recent_changes}");
-
-                let request_body = AutocompleteRequest {
-                    debug_info,
-                    repo_name,
-                    file_path: full_path.clone(),
-                    file_contents: text.clone(),
-                    original_file_contents: text,
-                    cursor_position: offset,
-                    recent_changes: recent_changes.clone(),
-                    changes_above_cursor: true,
-                    multiple_suggestions: false,
-                    branch: None,
-                    file_chunks,
-                    retrieval_chunks: vec![],
-                    recent_user_actions: vec![],
-                    // TODO
-                    privacy_mode_enabled: false,
-                };
-
-                let mut buf: Vec<u8> = Vec::new();
-                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
-                serde_json::to_writer(writer, &request_body)?;
-                let body: AsyncBody = buf.into();
-
-                let request = http_client::Request::builder()
-                    .uri(SWEEP_API_URL)
-                    .header("Content-Type", "application/json")
-                    .header("Authorization", format!("Bearer {}", api_token))
-                    .header("Connection", "keep-alive")
-                    .header("Content-Encoding", "br")
-                    .method(Method::POST)
-                    .body(body)?;
-
-                let mut response = http_client.send(request).await?;
-
-                let mut body: Vec<u8> = Vec::new();
-                response.body_mut().read_to_end(&mut body).await?;
-
-                if !response.status().is_success() {
-                    anyhow::bail!(
-                        "Request failed with status: {:?}\nBody: {}",
-                        response.status(),
-                        String::from_utf8_lossy(&body),
-                    );
-                };
-
-                let response: AutocompleteResponse = serde_json::from_slice(&body)?;
-
-                let old_text = snapshot
-                    .text_for_range(response.start_index..response.end_index)
-                    .collect::<String>();
-                let edits = text_diff(&old_text, &response.completion)
-                    .into_iter()
-                    .map(|(range, text)| {
-                        (
-                            snapshot.anchor_after(response.start_index + range.start)
-                                ..snapshot.anchor_before(response.start_index + range.end),
-                            text,
-                        )
-                    })
-                    .collect::<Vec<_>>();
-
-                anyhow::Ok((response.autocomplete_id, edits, snapshot))
-            }
-        });
-
-        let buffer = active_buffer.clone();
-
-        cx.spawn(async move |_, cx| {
-            let (id, edits, old_snapshot) = result.await?;
-
-            if edits.is_empty() {
-                return anyhow::Ok(None);
-            }
-
-            let Some((edits, new_snapshot, preview_task)) =
-                buffer.read_with(cx, |buffer, cx| {
-                    let new_snapshot = buffer.snapshot();
-
-                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
-                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
-                            .into();
-                    let preview_task = buffer.preview_edits(edits.clone(), cx);
-
-                    Some((edits, new_snapshot, preview_task))
-                })?
-            else {
-                return anyhow::Ok(None);
-            };
-
-            let prediction = EditPrediction {
-                id: EditPredictionId(id),
-                path: full_path,
-                edits,
-                snapshot: new_snapshot,
-                edit_preview: preview_task.await,
-            };
-
-            anyhow::Ok(Some(prediction))
-        })
-    }
-
-    fn report_changes_for_buffer(
-        &mut self,
-        buffer: &Entity<Buffer>,
-        project: &Entity<Project>,
-        cx: &mut Context<Self>,
-    ) {
-        let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
-        let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
-
-        let new_snapshot = buffer.read(cx).snapshot();
-        if new_snapshot.version == registered_buffer.snapshot.version {
-            return;
-        }
-
-        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
-        let end_edit_anchor = new_snapshot
-            .anchored_edits_since::<Point>(&old_snapshot.version)
-            .last()
-            .map(|(_, range)| range.end);
-        let events = &mut sweep_ai_project.events;
-
-        if let Some(Event::BufferChange {
-            new_snapshot: last_new_snapshot,
-            end_edit_anchor: last_end_edit_anchor,
-            ..
-        }) = events.back_mut()
-        {
-            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
-                == last_new_snapshot.remote_id()
-                && old_snapshot.version == last_new_snapshot.version;
-
-            let should_coalesce = is_next_snapshot_of_same_buffer
-                && end_edit_anchor
-                    .as_ref()
-                    .zip(last_end_edit_anchor.as_ref())
-                    .is_some_and(|(a, b)| {
-                        let a = a.to_point(&new_snapshot);
-                        let b = b.to_point(&new_snapshot);
-                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
-                    });
-
-            if should_coalesce {
-                *last_end_edit_anchor = end_edit_anchor;
-                *last_new_snapshot = new_snapshot;
-                return;
-            }
-        }
-
-        if events.len() >= MAX_EVENT_COUNT {
-            events.pop_front();
-        }
-
-        events.push_back(Event::BufferChange {
-            old_snapshot,
-            new_snapshot,
-            end_edit_anchor,
-        });
-    }
-}
-
-struct RegisteredBuffer {
-    snapshot: BufferSnapshot,
-    _subscriptions: [gpui::Subscription; 2],
-}
-
-#[derive(Clone)]
-pub enum Event {
-    BufferChange {
-        old_snapshot: BufferSnapshot,
-        new_snapshot: BufferSnapshot,
-        end_edit_anchor: Option<Anchor>,
-    },
-}
-
-impl Display for Event {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self {
-            Event::BufferChange {
-                old_snapshot,
-                new_snapshot,
-                ..
-            } => {
-                let old_path = old_snapshot
-                    .file()
-                    .map(|f| f.path().as_ref())
-                    .unwrap_or(RelPath::unix("untitled").unwrap());
-                let new_path = new_snapshot
-                    .file()
-                    .map(|f| f.path().as_ref())
-                    .unwrap_or(RelPath::unix("untitled").unwrap());
-                if old_path != new_path {
-                    // TODO confirm how to do this for sweep
-                    // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
-                }
-
-                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
-                if !diff.is_empty() {
-                    write!(
-                        f,
-                        "File: {}:\n{}\n",
-                        new_path.display(util::paths::PathStyle::Posix),
-                        diff
-                    )?
-                }
-
-                fmt::Result::Ok(())
-            }
-        }
-    }
-}
-
-#[derive(Debug, Clone)]
-struct CurrentEditPrediction {
-    buffer_id: EntityId,
-    completion: EditPrediction,
-}
-
-impl CurrentEditPrediction {
-    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
-        if self.buffer_id != old_completion.buffer_id {
-            return true;
-        }
-
-        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
-            return true;
-        };
-        let Some(new_edits) = self.completion.interpolate(snapshot) else {
-            return false;
-        };
-
-        if old_edits.len() == 1 && new_edits.len() == 1 {
-            let (old_range, old_text) = &old_edits[0];
-            let (new_range, new_text) = &new_edits[0];
-            new_range == old_range && new_text.starts_with(old_text.as_ref())
-        } else {
-            true
-        }
-    }
-}
-
-struct PendingCompletion {
-    id: usize,
-    _task: Task<()>,
-}
-
-pub struct SweepAiEditPredictionProvider {
-    workspace: WeakEntity<Workspace>,
-    sweep_ai: Entity<SweepAi>,
-    pending_completions: ArrayVec<PendingCompletion, 2>,
-    next_pending_completion_id: usize,
-    current_completion: Option<CurrentEditPrediction>,
-    last_request_timestamp: Instant,
-    project: Entity<Project>,
-}
-
-impl SweepAiEditPredictionProvider {
-    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
-
-    pub fn new(
-        sweep_ai: Entity<SweepAi>,
-        workspace: WeakEntity<Workspace>,
-        project: Entity<Project>,
-    ) -> Self {
-        Self {
-            sweep_ai,
-            pending_completions: ArrayVec::new(),
-            next_pending_completion_id: 0,
-            current_completion: None,
-            last_request_timestamp: Instant::now(),
-            project,
-            workspace,
-        }
-    }
-}
-
-impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
-    fn name() -> &'static str {
-        "zed-predict"
-    }
-
-    fn display_name() -> &'static str {
-        "Zed's Edit Predictions"
-    }
-
-    fn show_completions_in_menu() -> bool {
-        true
-    }
-
-    fn show_tab_accept_marker() -> bool {
-        true
-    }
-
-    fn is_enabled(
-        &self,
-        _buffer: &Entity<Buffer>,
-        _cursor_position: language::Anchor,
-        cx: &App,
-    ) -> bool {
-        self.sweep_ai.read(cx).api_token.is_some()
-    }
-
-    fn is_refreshing(&self) -> bool {
-        !self.pending_completions.is_empty()
-    }
-
-    fn refresh(
-        &mut self,
-        buffer: Entity<Buffer>,
-        position: language::Anchor,
-        _debounce: bool,
-        cx: &mut Context<Self>,
-    ) {
-        if let Some(current_completion) = self.current_completion.as_ref() {
-            let snapshot = buffer.read(cx).snapshot();
-            if current_completion
-                .completion
-                .interpolate(&snapshot)
-                .is_some()
-            {
-                return;
-            }
-        }
-
-        let pending_completion_id = self.next_pending_completion_id;
-        self.next_pending_completion_id += 1;
-        let last_request_timestamp = self.last_request_timestamp;
-
-        let project = self.project.clone();
-        let workspace = self.workspace.clone();
-        let task = cx.spawn(async move |this, cx| {
-            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
-                .checked_duration_since(Instant::now())
-            {
-                cx.background_executor().timer(timeout).await;
-            }
-
-            let completion_request = this.update(cx, |this, cx| {
-                this.last_request_timestamp = Instant::now();
-
-                this.sweep_ai.update(cx, |sweep_ai, cx| {
-                    let Some(recent_buffers) = workspace
-                        .read_with(cx, |workspace, cx| {
-                            workspace.recent_navigation_history_iter(cx)
-                        })
-                        .log_err()
-                    else {
-                        return Task::ready(Ok(None));
-                    };
-                    sweep_ai.request_completion(
-                        &project,
-                        recent_buffers.map(move |(project_path, _)| project_path),
-                        &buffer,
-                        position,
-                        cx,
-                    )
-                })
-            });
-
-            let completion = match completion_request {
-                Ok(completion_request) => {
-                    let completion_request = completion_request.await;
-                    completion_request.map(|c| {
-                        c.map(|completion| CurrentEditPrediction {
-                            buffer_id: buffer.entity_id(),
-                            completion,
-                        })
-                    })
-                }
-                Err(error) => Err(error),
-            };
-
-            let Some(new_completion) = completion
-                .context("edit prediction failed")
-                .log_err()
-                .flatten()
-            else {
-                this.update(cx, |this, cx| {
-                    if this.pending_completions[0].id == pending_completion_id {
-                        this.pending_completions.remove(0);
-                    } else {
-                        this.pending_completions.clear();
-                    }
-
-                    cx.notify();
-                })
-                .ok();
-                return;
-            };
-
-            this.update(cx, |this, cx| {
-                if this.pending_completions[0].id == pending_completion_id {
-                    this.pending_completions.remove(0);
-                } else {
-                    this.pending_completions.clear();
-                }
-
-                if let Some(old_completion) = this.current_completion.as_ref() {
-                    let snapshot = buffer.read(cx).snapshot();
-                    if new_completion.should_replace_completion(old_completion, &snapshot) {
-                        this.current_completion = Some(new_completion);
-                    }
-                } else {
-                    this.current_completion = Some(new_completion);
-                }
-
-                cx.notify();
-            })
-            .ok();
-        });
-
-        // We always maintain at most two pending completions. When we already
-        // have two, we replace the newest one.
-        if self.pending_completions.len() <= 1 {
-            self.pending_completions.push(PendingCompletion {
-                id: pending_completion_id,
-                _task: task,
-            });
-        } else if self.pending_completions.len() == 2 {
-            self.pending_completions.pop();
-            self.pending_completions.push(PendingCompletion {
-                id: pending_completion_id,
-                _task: task,
-            });
-        }
-    }
-
-    fn cycle(
-        &mut self,
-        _buffer: Entity<Buffer>,
-        _cursor_position: language::Anchor,
-        _direction: edit_prediction::Direction,
-        _cx: &mut Context<Self>,
-    ) {
-        // Right now we don't support cycling.
-    }
-
-    fn accept(&mut self, _cx: &mut Context<Self>) {
-        self.pending_completions.clear();
-    }
-
-    fn discard(&mut self, _cx: &mut Context<Self>) {
-        self.pending_completions.clear();
-        self.current_completion.take();
-    }
-
-    fn suggest(
-        &mut self,
-        buffer: &Entity<Buffer>,
-        cursor_position: language::Anchor,
-        cx: &mut Context<Self>,
-    ) -> Option<edit_prediction::EditPrediction> {
-        let CurrentEditPrediction {
-            buffer_id,
-            completion,
-            ..
-        } = self.current_completion.as_mut()?;
-
-        // Invalidate previous completion if it was generated for a different buffer.
-        if *buffer_id != buffer.entity_id() {
-            self.current_completion.take();
-            return None;
-        }
-
-        let buffer = buffer.read(cx);
-        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
-            self.current_completion.take();
-            return None;
-        };
-
-        let cursor_row = cursor_position.to_point(buffer).row;
-        let (closest_edit_ix, (closest_edit_range, _)) =
-            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
-                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
-                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
-                cmp::min(distance_from_start, distance_from_end)
-            })?;
-
-        let mut edit_start_ix = closest_edit_ix;
-        for (range, _) in edits[..edit_start_ix].iter().rev() {
-            let distance_from_closest_edit =
-                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
-            if distance_from_closest_edit <= 1 {
-                edit_start_ix -= 1;
-            } else {
-                break;
-            }
-        }
-
-        let mut edit_end_ix = closest_edit_ix + 1;
-        for (range, _) in &edits[edit_end_ix..] {
-            let distance_from_closest_edit =
-                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
-            if distance_from_closest_edit <= 1 {
-                edit_end_ix += 1;
-            } else {
-                break;
-            }
-        }
-
-        Some(edit_prediction::EditPrediction::Local {
-            id: Some(completion.id.to_string().into()),
-            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
-            edit_preview: Some(completion.edit_preview.clone()),
-        })
-    }
-}

crates/zed/Cargo.toml 🔗

@@ -133,7 +133,6 @@ snippet_provider.workspace = true
 snippets_ui.workspace = true
 supermaven.workspace = true
 svg_preview.workspace = true
-sweep_ai.workspace = true
 sysinfo.workspace = true
 tab_switcher.workspace = true
 task.workspace = true

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -10,9 +10,9 @@ use language_models::MistralLanguageModelProvider;
 use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
 use std::{cell::RefCell, rc::Rc, sync::Arc};
 use supermaven::{Supermaven, SupermavenCompletionProvider};
-use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag};
 use ui::Window;
 use zeta::ZetaEditPredictionProvider;
+use zeta2::SweepFeatureFlag;
 use zeta2::Zeta2FeatureFlag;
 
 pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
@@ -203,55 +203,41 @@ fn assign_edit_prediction_provider(
             let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
             editor.set_edit_prediction_provider(Some(provider), window, cx);
         }
-        EditPredictionProvider::Experimental(name) => {
-            if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
-                && cx.has_flag::<SweepFeatureFlag>()
-            {
-                if let Some(project) = editor.project()
-                    && let Some(workspace) = editor.workspace()
+        value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
+            if let Some(project) = editor.project() {
+                let mut worktree = None;
+                if let Some(buffer) = &singleton_buffer
+                    && let Some(file) = buffer.read(cx).file()
                 {
-                    let sweep_ai = sweep_ai::SweepAi::register(cx);
-
-                    if let Some(buffer) = &singleton_buffer
-                        && buffer.read(cx).file().is_some()
-                    {
-                        sweep_ai.update(cx, |sweep_ai, cx| {
-                            sweep_ai.register_buffer(buffer, project, cx);
-                        });
-                    }
+                    let id = file.worktree_id(cx);
+                    worktree = project.read(cx).worktree_for_id(id, cx);
+                }
 
-                    let provider = cx.new(|_| {
-                        sweep_ai::SweepAiEditPredictionProvider::new(
-                            sweep_ai,
-                            workspace.downgrade(),
+                if let EditPredictionProvider::Experimental(name) = value
+                    && name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+                    && cx.has_flag::<SweepFeatureFlag>()
+                {
+                    let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
+                    let provider = cx.new(|cx| {
+                        zeta2::ZetaEditPredictionProvider::new(
                             project.clone(),
+                            &client,
+                            &user_store,
+                            cx,
                         )
                     });
-                    editor.set_edit_prediction_provider(Some(provider), window, cx);
-                }
-            } else {
-                editor.set_edit_prediction_provider::<SweepAiEditPredictionProvider>(
-                    None, window, cx,
-                );
-            }
-        }
-        EditPredictionProvider::Zed => {
-            if user_store.read(cx).current_user().is_some() {
-                let mut worktree = None;
 
-                if let Some(buffer) = &singleton_buffer
-                    && let Some(file) = buffer.read(cx).file()
-                {
-                    let id = file.worktree_id(cx);
-                    if let Some(inner_worktree) = editor
-                        .project()
-                        .and_then(|project| project.read(cx).worktree_for_id(id, cx))
+                    if let Some(buffer) = &singleton_buffer
+                        && buffer.read(cx).file().is_some()
                     {
-                        worktree = Some(inner_worktree);
+                        zeta2.update(cx, |zeta, cx| {
+                            zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep);
+                            zeta.register_buffer(buffer, project, cx);
+                        });
                     }
-                }
 
-                if let Some(project) = editor.project() {
+                    editor.set_edit_prediction_provider(Some(provider), window, cx);
+                } else if user_store.read(cx).current_user().is_some() {
                     if cx.has_flag::<Zeta2FeatureFlag>() {
                         let zeta = zeta2::Zeta::global(client, &user_store, cx);
                         let provider = cx.new(|cx| {
@@ -268,6 +254,9 @@ fn assign_edit_prediction_provider(
                             && buffer.read(cx).file().is_some()
                         {
                             zeta.update(cx, |zeta, cx| {
+                                zeta.set_edit_prediction_model(
+                                    zeta2::ZetaEditPredictionModel::ZedCloud,
+                                );
                                 zeta.register_buffer(buffer, project, cx);
                             });
                         }

crates/zeta2/Cargo.toml 🔗

@@ -17,6 +17,7 @@ eval-support = []
 [dependencies]
 anyhow.workspace = true
 arrayvec.workspace = true
+brotli.workspace = true
 chrono.workspace = true
 client.workspace = true
 cloud_llm_client.workspace = true

crates/zeta2/src/provider.rs 🔗

@@ -12,7 +12,7 @@ use language::ToPoint as _;
 use project::Project;
 use util::ResultExt as _;
 
-use crate::{BufferEditPrediction, Zeta};
+use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
 
 pub struct ZetaEditPredictionProvider {
     zeta: Entity<Zeta>,
@@ -85,9 +85,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
         &self,
         _buffer: &Entity<language::Buffer>,
         _cursor_position: language::Anchor,
-        _cx: &App,
+        cx: &App,
     ) -> bool {
-        true
+        let zeta = self.zeta.read(cx);
+        if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
+            zeta.sweep_api_token.is_some()
+        } else {
+            true
+        }
     }
 
     fn is_refreshing(&self) -> bool {

crates/sweep_ai/src/api.rs → crates/zeta2/src/sweep_ai.rs 🔗

@@ -1,6 +1,8 @@
+use std::fmt;
 use std::{path::Path, sync::Arc};
 
 use serde::{Deserialize, Serialize};
+use util::rel_path::RelPath;
 
 #[derive(Debug, Clone, Serialize)]
 pub struct AutocompleteRequest {
@@ -88,3 +90,49 @@ pub struct AdditionalCompletion {
     pub logprobs: Option<serde_json::Value>,
     pub finish_reason: Option<String>,
 }
+
+pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result {
+    match event {
+        crate::Event::BufferChange {
+            old_snapshot,
+            new_snapshot,
+            ..
+        } => {
+            let old_path = old_snapshot
+                .file()
+                .map(|f| f.path().as_ref())
+                .unwrap_or(RelPath::unix("untitled").unwrap());
+            let new_path = new_snapshot
+                .file()
+                .map(|f| f.path().as_ref())
+                .unwrap_or(RelPath::unix("untitled").unwrap());
+            if old_path != new_path {
+                // TODO confirm how to do this for sweep
+                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
+            }
+
+            let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
+            if !diff.is_empty() {
+                write!(
+                    f,
+                    "File: {}:\n{}\n",
+                    new_path.display(util::paths::PathStyle::Posix),
+                    diff
+                )?
+            }
+
+            fmt::Result::Ok(())
+        }
+    }
+}
+
+pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
+    format!(
+        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
+        version = release_channel::AppVersion::global(cx),
+        sha = release_channel::AppCommitSha::try_global(cx)
+            .map_or("unknown".to_string(), |sha| sha.full()),
+        os = client::telemetry::os_name(),
+    )
+    .into()
+}

crates/zeta2/src/zeta2.rs 🔗

@@ -22,30 +22,31 @@ use gpui::{
     App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
     http_client, prelude::*,
 };
-use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
+use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use open_ai::FunctionDefinition;
-use project::Project;
+use project::{Project, ProjectPath};
 use release_channel::AppVersion;
 use serde::de::DeserializeOwned;
 use std::collections::{VecDeque, hash_map};
 
-use std::env;
 use std::ops::Range;
 use std::path::Path;
 use std::str::FromStr as _;
 use std::sync::{Arc, LazyLock};
 use std::time::{Duration, Instant};
+use std::{env, mem};
 use thiserror::Error;
 use util::rel_path::RelPathBuf;
-use util::{LogErrorFuture, TryFutureExt};
+use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 
 pub mod assemble_excerpts;
 mod prediction;
 mod provider;
 pub mod retrieval_search;
+mod sweep_ai;
 pub mod udiff;
 mod xml_edits;
 
@@ -55,8 +56,15 @@ pub use crate::prediction::EditPredictionId;
 pub use provider::ZetaEditPredictionProvider;
 
 /// Maximum number of events to track.
-const MAX_EVENT_COUNT: usize = 16;
+const EVENT_COUNT_MAX_SWEEP: usize = 6;
+const EVENT_COUNT_MAX_ZETA: usize = 16;
+const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 
+pub struct SweepFeatureFlag;
+
+impl FeatureFlag for SweepFeatureFlag {
+    const NAME: &str = "sweep-ai";
+}
 pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
     max_bytes: 512,
     min_bytes: 128,
@@ -143,6 +151,15 @@ pub struct Zeta {
     debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
     #[cfg(feature = "eval-support")]
     eval_cache: Option<Arc<dyn EvalCache>>,
+    edit_prediction_model: ZetaEditPredictionModel,
+    sweep_api_token: Option<String>,
+    sweep_ai_debug_info: Arc<str>,
+}
+
+#[derive(PartialEq, Eq)]
+pub enum ZetaEditPredictionModel {
+    ZedCloud,
+    Sweep,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -219,12 +236,14 @@ pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 struct ZetaProject {
     syntax_index: Option<Entity<SyntaxIndex>>,
     events: VecDeque<Event>,
+    recent_paths: VecDeque<ProjectPath>,
     registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
     current_prediction: Option<CurrentEditPrediction>,
     context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
     refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
     refresh_context_debounce_task: Option<Task<Option<()>>>,
     refresh_context_timestamp: Option<Instant>,
+    _subscription: gpui::Subscription,
 }
 
 #[derive(Debug, Clone)]
@@ -287,6 +306,7 @@ pub enum Event {
     BufferChange {
         old_snapshot: BufferSnapshot,
         new_snapshot: BufferSnapshot,
+        end_edit_anchor: Option<Anchor>,
         timestamp: Instant,
     },
 }
@@ -381,7 +401,19 @@ impl Zeta {
             debug_tx: None,
             #[cfg(feature = "eval-support")]
             eval_cache: None,
+            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
+            sweep_api_token: None,
+            sweep_ai_debug_info: sweep_ai::debug_info(cx),
+        }
+    }
+
+    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
+        if model == ZetaEditPredictionModel::Sweep {
+            self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN")
+                .context("No SWEEP_AI_TOKEN environment variable set")
+                .log_err();
         }
+        self.edit_prediction_model = model;
     }
 
     #[cfg(feature = "eval-support")]
@@ -443,7 +475,7 @@ impl Zeta {
         self.user_store.read(cx).edit_prediction_usage()
     }
 
-    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
+    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
         self.get_or_init_zeta_project(project, cx);
     }
 
@@ -460,7 +492,7 @@ impl Zeta {
     fn get_or_init_zeta_project(
         &mut self,
         project: &Entity<Project>,
-        cx: &mut App,
+        cx: &mut Context<Self>,
     ) -> &mut ZetaProject {
         self.projects
             .entry(project.entity_id())
@@ -473,12 +505,31 @@ impl Zeta {
                     None
                 },
                 events: VecDeque::new(),
+                recent_paths: VecDeque::new(),
                 registered_buffers: HashMap::default(),
                 current_prediction: None,
                 context: None,
                 refresh_context_task: None,
                 refresh_context_debounce_task: None,
                 refresh_context_timestamp: None,
+                _subscription: cx.subscribe(&project, |this, project, event, cx| {
+                    // TODO [zeta2] init with recent paths
+                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+                        if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
+                            let path = project.read(cx).path_for_entry(*active_entry_id, cx);
+                            if let Some(path) = path {
+                                if let Some(ix) = zeta_project
+                                    .recent_paths
+                                    .iter()
+                                    .position(|probe| probe == &path)
+                                {
+                                    zeta_project.recent_paths.remove(ix);
+                                }
+                                zeta_project.recent_paths.push_front(path);
+                            }
+                        }
+                    }
+                }),
             })
     }
 
@@ -525,66 +576,64 @@ impl Zeta {
         buffer: &Entity<Buffer>,
         project: &Entity<Project>,
         cx: &mut Context<Self>,
-    ) -> BufferSnapshot {
-        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
-        let zeta_project = self.get_or_init_zeta_project(project, cx);
-        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
+    ) {
+        let event_count_max = match self.edit_prediction_model {
+            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
+            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
+        };
+
+        let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
+        let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
 
         let new_snapshot = buffer.read(cx).snapshot();
-        if new_snapshot.version != registered_buffer.snapshot.version {
-            let old_snapshot =
-                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
-            Self::push_event(
-                zeta_project,
-                buffer_change_grouping_interval,
-                Event::BufferChange {
-                    old_snapshot,
-                    new_snapshot: new_snapshot.clone(),
-                    timestamp: Instant::now(),
-                },
-            );
+        if new_snapshot.version == registered_buffer.snapshot.version {
+            return;
         }
 
-        new_snapshot
-    }
-
-    fn push_event(
-        zeta_project: &mut ZetaProject,
-        buffer_change_grouping_interval: Duration,
-        event: Event,
-    ) {
-        let events = &mut zeta_project.events;
+        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+        let end_edit_anchor = new_snapshot
+            .anchored_edits_since::<Point>(&old_snapshot.version)
+            .last()
+            .map(|(_, range)| range.end);
+        let events = &mut sweep_ai_project.events;
 
-        if buffer_change_grouping_interval > Duration::ZERO
-            && let Some(Event::BufferChange {
-                new_snapshot: last_new_snapshot,
-                timestamp: last_timestamp,
-                ..
-            }) = events.back_mut()
+        if let Some(Event::BufferChange {
+            new_snapshot: last_new_snapshot,
+            end_edit_anchor: last_end_edit_anchor,
+            ..
+        }) = events.back_mut()
         {
-            // Coalesce edits for the same buffer when they happen one after the other.
-            let Event::BufferChange {
-                old_snapshot,
-                new_snapshot,
-                timestamp,
-            } = &event;
-
-            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
-                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
-                && old_snapshot.version == last_new_snapshot.version
-            {
-                *last_new_snapshot = new_snapshot.clone();
-                *last_timestamp = *timestamp;
+            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
+                == last_new_snapshot.remote_id()
+                && old_snapshot.version == last_new_snapshot.version;
+
+            let should_coalesce = is_next_snapshot_of_same_buffer
+                && end_edit_anchor
+                    .as_ref()
+                    .zip(last_end_edit_anchor.as_ref())
+                    .is_some_and(|(a, b)| {
+                        let a = a.to_point(&new_snapshot);
+                        let b = b.to_point(&new_snapshot);
+                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
+                    });
+
+            if should_coalesce {
+                *last_end_edit_anchor = end_edit_anchor;
+                *last_new_snapshot = new_snapshot;
                 return;
             }
         }
 
-        if events.len() >= MAX_EVENT_COUNT {
-            // These are halved instead of popping to improve prompt caching.
-            events.drain(..MAX_EVENT_COUNT / 2);
+        if events.len() >= event_count_max {
+            events.pop_front();
         }
 
-        events.push_back(event);
+        events.push_back(Event::BufferChange {
+            old_snapshot,
+            new_snapshot,
+            end_edit_anchor,
+            timestamp: Instant::now(),
+        });
     }
 
     fn current_prediction_for_buffer(
@@ -706,6 +755,203 @@ impl Zeta {
         active_buffer: &Entity<Buffer>,
         position: language::Anchor,
         cx: &mut Context<Self>,
+    ) -> Task<Result<Option<EditPrediction>>> {
+        match self.edit_prediction_model {
+            ZetaEditPredictionModel::ZedCloud => {
+                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
+            }
+            ZetaEditPredictionModel::Sweep => {
+                self.request_prediction_with_sweep(project, active_buffer, position, cx)
+            }
+        }
+    }
+
+    fn request_prediction_with_sweep(
+        &mut self,
+        project: &Entity<Project>,
+        active_buffer: &Entity<Buffer>,
+        position: language::Anchor,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<Option<EditPrediction>>> {
+        let snapshot = active_buffer.read(cx).snapshot();
+        let debug_info = self.sweep_ai_debug_info.clone();
+        let Some(api_token) = self.sweep_api_token.clone() else {
+            return Task::ready(Ok(None));
+        };
+        let full_path: Arc<Path> = snapshot
+            .file()
+            .map(|file| file.full_path(cx))
+            .unwrap_or_else(|| "untitled".into())
+            .into();
+
+        let project_file = project::File::from_dyn(snapshot.file());
+        let repo_name = project_file
+            .map(|file| file.worktree.read(cx).root_name_str())
+            .unwrap_or("untitled")
+            .into();
+        let offset = position.to_offset(&snapshot);
+
+        let project_state = self.get_or_init_zeta_project(project, cx);
+        let events = project_state.events.clone();
+        let recent_buffers = project_state.recent_paths.iter().cloned();
+        let http_client = cx.http_client();
+
+        let recent_buffer_snapshots = recent_buffers
+            .filter_map(|project_path| {
+                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+                if active_buffer == &buffer {
+                    None
+                } else {
+                    Some(buffer.read(cx).snapshot())
+                }
+            })
+            .take(3)
+            .collect::<Vec<_>>();
+
+        let result = cx.background_spawn(async move {
+            let text = snapshot.text();
+
+            let mut recent_changes = String::new();
+            for event in events {
+                sweep_ai::write_event(event, &mut recent_changes).unwrap();
+            }
+
+            let file_chunks = recent_buffer_snapshots
+                .into_iter()
+                .map(|snapshot| {
+                    let end_point = language::Point::new(30, 0).min(snapshot.max_point());
+                    sweep_ai::FileChunk {
+                        content: snapshot
+                            .text_for_range(language::Point::zero()..end_point)
+                            .collect(),
+                        file_path: snapshot
+                            .file()
+                            .map(|f| f.path().as_unix_str())
+                            .unwrap_or("untitled")
+                            .to_string(),
+                        start_line: 0,
+                        end_line: end_point.row as usize,
+                        timestamp: snapshot.file().and_then(|file| {
+                            Some(
+                                file.disk_state()
+                                    .mtime()?
+                                    .to_seconds_and_nanos_for_persistence()?
+                                    .0,
+                            )
+                        }),
+                    }
+                })
+                .collect();
+
+            let request_body = sweep_ai::AutocompleteRequest {
+                debug_info,
+                repo_name,
+                file_path: full_path.clone(),
+                file_contents: text.clone(),
+                original_file_contents: text,
+                cursor_position: offset,
+                recent_changes: recent_changes.clone(),
+                changes_above_cursor: true,
+                multiple_suggestions: false,
+                branch: None,
+                file_chunks,
+                retrieval_chunks: vec![],
+                recent_user_actions: vec![],
+                // TODO
+                privacy_mode_enabled: false,
+            };
+
+            let mut buf: Vec<u8> = Vec::new();
+            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+            serde_json::to_writer(writer, &request_body)?;
+            let body: AsyncBody = buf.into();
+
+            const SWEEP_API_URL: &str =
+                "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+
+            let request = http_client::Request::builder()
+                .uri(SWEEP_API_URL)
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {}", api_token))
+                .header("Connection", "keep-alive")
+                .header("Content-Encoding", "br")
+                .method(Method::POST)
+                .body(body)?;
+
+            let mut response = http_client.send(request).await?;
+
+            let mut body: Vec<u8> = Vec::new();
+            response.body_mut().read_to_end(&mut body).await?;
+
+            if !response.status().is_success() {
+                anyhow::bail!(
+                    "Request failed with status: {:?}\nBody: {}",
+                    response.status(),
+                    String::from_utf8_lossy(&body),
+                );
+            };
+
+            let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
+
+            let old_text = snapshot
+                .text_for_range(response.start_index..response.end_index)
+                .collect::<String>();
+            let edits = language::text_diff(&old_text, &response.completion)
+                .into_iter()
+                .map(|(range, text)| {
+                    (
+                        snapshot.anchor_after(response.start_index + range.start)
+                            ..snapshot.anchor_before(response.start_index + range.end),
+                        text,
+                    )
+                })
+                .collect::<Vec<_>>();
+
+            anyhow::Ok((response.autocomplete_id, edits, snapshot))
+        });
+
+        let buffer = active_buffer.clone();
+
+        cx.spawn(async move |_, cx| {
+            let (id, edits, old_snapshot) = result.await?;
+
+            if edits.is_empty() {
+                return anyhow::Ok(None);
+            }
+
+            let Some((edits, new_snapshot, preview_task)) =
+                buffer.read_with(cx, |buffer, cx| {
+                    let new_snapshot = buffer.snapshot();
+
+                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
+                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
+                            .into();
+                    let preview_task = buffer.preview_edits(edits.clone(), cx);
+
+                    Some((edits, new_snapshot, preview_task))
+                })?
+            else {
+                return anyhow::Ok(None);
+            };
+
+            let prediction = EditPrediction {
+                id: EditPredictionId(id.into()),
+                edits,
+                snapshot: new_snapshot,
+                edit_preview: preview_task.await,
+                buffer,
+            };
+
+            anyhow::Ok(Some(prediction))
+        })
+    }
+
+    fn request_prediction_with_zed_cloud(
+        &mut self,
+        project: &Entity<Project>,
+        active_buffer: &Entity<Buffer>,
+        position: language::Anchor,
+        cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPrediction>>> {
         let project_state = self.projects.get(&project.entity_id());
 
@@ -1653,7 +1899,7 @@ impl Zeta {
     pub fn wait_for_initial_indexing(
         &mut self,
         project: &Entity<Project>,
-        cx: &mut App,
+        cx: &mut Context<Self>,
     ) -> Task<Result<()>> {
         let zeta_project = self.get_or_init_zeta_project(project, cx);
         if let Some(syntax_index) = &zeta_project.syntax_index {

crates/zeta_cli/Cargo.toml 🔗

@@ -49,7 +49,6 @@ settings.workspace = true
 shellexpand.workspace = true
 smol.workspace = true
 soa-rs = "0.8.1"
-sweep_ai.workspace = true
 terminal_view.workspace = true
 toml.workspace = true
 util.workspace = true

crates/zeta_cli/src/evaluate.rs 🔗

@@ -8,16 +8,15 @@ use anyhow::Result;
 use collections::HashSet;
 use gpui::{AsyncApp, Entity};
 use project::Project;
-use sweep_ai::SweepAi;
 use util::ResultExt as _;
 use zeta2::{Zeta, udiff::DiffLine};
 
 use crate::{
-    EvaluateArguments, PredictionOptions, PredictionProvider,
+    EvaluateArguments, PredictionOptions,
     example::{Example, NamedExample},
     headless::ZetaCliAppState,
     paths::print_run_data_dir,
-    predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta},
+    predict::{PredictionDetails, perform_predict, setup_zeta},
 };
 
 #[derive(Debug)]
@@ -46,46 +45,35 @@ pub async fn run_evaluate(
             let project = example.setup_project(&app_state, cx).await.unwrap();
 
             let providers = (0..args.repetitions)
-                .map(|_| {
-                    (
-                        setup_zeta(&project, &app_state, cx).unwrap(),
-                        if matches!(args.options.provider, PredictionProvider::Sweep) {
-                            Some(setup_sweep(&project, cx).unwrap())
-                        } else {
-                            None
-                        },
-                    )
-                })
+                .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
                 .collect::<Vec<_>>();
 
             let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
 
-            let tasks =
-                providers
-                    .into_iter()
-                    .enumerate()
-                    .map(move |(repetition_ix, (zeta, sweep))| {
-                        let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
-                        let example = example.clone();
-                        let project = project.clone();
-                        let options = options.clone();
-
-                        cx.spawn(async move |cx| {
-                            let name = example.name.clone();
-                            run_evaluate_one(
-                                example,
-                                repetition_ix,
-                                project,
-                                zeta,
-                                sweep,
-                                options,
-                                !args.skip_prediction,
-                                cx,
-                            )
-                            .await
-                            .map_err(|err| (err, name, repetition_ix))
-                        })
-                    });
+            let tasks = providers
+                .into_iter()
+                .enumerate()
+                .map(move |(repetition_ix, zeta)| {
+                    let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
+                    let example = example.clone();
+                    let project = project.clone();
+                    let options = options.clone();
+
+                    cx.spawn(async move |cx| {
+                        let name = example.name.clone();
+                        run_evaluate_one(
+                            example,
+                            repetition_ix,
+                            project,
+                            zeta,
+                            options,
+                            !args.skip_prediction,
+                            cx,
+                        )
+                        .await
+                        .map_err(|err| (err, name, repetition_ix))
+                    })
+                });
             futures::future::join_all(tasks).await
         })
     });
@@ -177,7 +165,6 @@ pub async fn run_evaluate_one(
     repetition_ix: Option<u16>,
     project: Entity<Project>,
     zeta: Entity<Zeta>,
-    sweep: Option<Entity<SweepAi>>,
     prediction_options: PredictionOptions,
     predict: bool,
     cx: &mut AsyncApp,
@@ -186,7 +173,6 @@ pub async fn run_evaluate_one(
         example.clone(),
         project,
         zeta,
-        sweep,
         repetition_ix,
         prediction_options,
         cx,

crates/zeta_cli/src/main.rs 🔗

@@ -191,7 +191,7 @@ pub struct EvaluateArguments {
     skip_prediction: bool,
 }
 
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
 enum PredictionProvider {
     #[default]
     Zeta2,

crates/zeta_cli/src/predict.rs 🔗

@@ -21,7 +21,6 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::time::{Duration, Instant};
-use sweep_ai::SweepAi;
 use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
 
 pub async fn run_predict(
@@ -31,14 +30,9 @@ pub async fn run_predict(
 ) {
     let example = NamedExample::load(args.example_path).unwrap();
     let project = example.setup_project(app_state, cx).await.unwrap();
-    let zeta = setup_zeta(&project, app_state, cx).unwrap();
-    let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
-        Some(setup_sweep(&project, cx).unwrap())
-    } else {
-        None
-    };
+    let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
     let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
-    let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
+    let result = perform_predict(example, project, zeta, None, args.options, cx)
         .await
         .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
@@ -47,6 +41,7 @@ pub async fn run_predict(
 }
 
 pub fn setup_zeta(
+    provider: PredictionProvider,
     project: &Entity<Project>,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
@@ -54,6 +49,14 @@ pub fn setup_zeta(
     let zeta =
         cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
 
+    zeta.update(cx, |zeta, _cx| {
+        let model = match provider {
+            PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
+            PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
+        };
+        zeta.set_edit_prediction_model(model);
+    })?;
+
     let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
 
     cx.subscribe(&buffer_store, {
@@ -71,31 +74,10 @@ pub fn setup_zeta(
     anyhow::Ok(zeta)
 }
 
-pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
-    let sweep = cx.new(|cx| SweepAi::new(cx))?;
-
-    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
-
-    cx.subscribe(&buffer_store, {
-        let project = project.clone();
-        let sweep = sweep.clone();
-        move |_, event, cx| match event {
-            BufferStoreEvent::BufferAdded(buffer) => {
-                sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
-            }
-            _ => {}
-        }
-    })?
-    .detach();
-
-    anyhow::Ok(sweep)
-}
-
 pub async fn perform_predict(
     example: NamedExample,
     project: Entity<Project>,
     zeta: Entity<Zeta>,
-    sweep: Option<Entity<SweepAi>>,
     repetition_ix: Option<u16>,
     options: PredictionOptions,
     cx: &mut AsyncApp,
@@ -147,194 +129,152 @@ pub async fn perform_predict(
         zeta.set_options(options);
     })?;
 
-    let prediction = match options.provider {
-        crate::PredictionProvider::Zeta2 => {
-            let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
-
-            let debug_task = cx.background_spawn({
-                let result = result.clone();
-                async move {
-                    let mut start_time = None;
-                    let mut search_queries_generated_at = None;
-                    let mut search_queries_executed_at = None;
-                    while let Some(event) = debug_rx.next().await {
-                        match event {
-                            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
-                                start_time = Some(info.timestamp);
-                                fs::write(
-                                    example_run_dir.join("search_prompt.md"),
-                                    &info.search_prompt,
-                                )?;
-                            }
-                            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
-                                search_queries_generated_at = Some(info.timestamp);
-                                fs::write(
-                                    example_run_dir.join("search_queries.json"),
-                                    serde_json::to_string_pretty(&info.search_queries).unwrap(),
-                                )?;
-                            }
-                            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
-                                search_queries_executed_at = Some(info.timestamp);
-                            }
-                            zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
-                            zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
-                                let prediction_started_at = Instant::now();
-                                start_time.get_or_insert(prediction_started_at);
-                                let prompt = request.local_prompt.unwrap_or_default();
-                                fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
-
-                                {
-                                    let mut result = result.lock().unwrap();
-                                    result.prompt_len = prompt.chars().count();
-
-                                    for included_file in request.request.included_files {
-                                        let insertions =
-                                            vec![(request.request.cursor_point, CURSOR_MARKER)];
-                                        result.excerpts.extend(included_file.excerpts.iter().map(
-                                            |excerpt| {
-                                                ActualExcerpt {
-                                                    path: included_file
-                                                        .path
-                                                        .components()
-                                                        .skip(1)
-                                                        .collect(),
-                                                    text: String::from(excerpt.text.as_ref()),
-                                                }
-                                            },
-                                        ));
-                                        write_codeblock(
-                                            &included_file.path,
-                                            included_file.excerpts.iter(),
-                                            if included_file.path == request.request.excerpt_path {
-                                                &insertions
-                                            } else {
-                                                &[]
-                                            },
-                                            included_file.max_row,
-                                            false,
-                                            &mut result.excerpts_text,
-                                        );
-                                    }
-                                }
-
-                                let response =
-                                    request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
-                                let response =
-                                    zeta2::text_from_response(response).unwrap_or_default();
-                                let prediction_finished_at = Instant::now();
-                                fs::write(
-                                    example_run_dir.join("prediction_response.md"),
-                                    &response,
-                                )?;
-
+    let mut debug_task = gpui::Task::ready(Ok(()));
+
+    if options.provider == crate::PredictionProvider::Zeta2 {
+        let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+
+        debug_task = cx.background_spawn({
+            let result = result.clone();
+            async move {
+                let mut start_time = None;
+                let mut search_queries_generated_at = None;
+                let mut search_queries_executed_at = None;
+                while let Some(event) = debug_rx.next().await {
+                    match event {
+                        zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+                            start_time = Some(info.timestamp);
+                            fs::write(
+                                example_run_dir.join("search_prompt.md"),
+                                &info.search_prompt,
+                            )?;
+                        }
+                        zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+                            search_queries_generated_at = Some(info.timestamp);
+                            fs::write(
+                                example_run_dir.join("search_queries.json"),
+                                serde_json::to_string_pretty(&info.search_queries).unwrap(),
+                            )?;
+                        }
+                        zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+                            search_queries_executed_at = Some(info.timestamp);
+                        }
+                        zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
+                        zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+                            let prediction_started_at = Instant::now();
+                            start_time.get_or_insert(prediction_started_at);
+                            let prompt = request.local_prompt.unwrap_or_default();
+                            fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
+
+                            {
                                 let mut result = result.lock().unwrap();
-                                result.generated_len = response.chars().count();
-
-                                if !options.use_expected_context {
-                                    result.planning_search_time = Some(
-                                        search_queries_generated_at.unwrap() - start_time.unwrap(),
-                                    );
-                                    result.running_search_time = Some(
-                                        search_queries_executed_at.unwrap()
-                                            - search_queries_generated_at.unwrap(),
+                                result.prompt_len = prompt.chars().count();
+
+                                for included_file in request.request.included_files {
+                                    let insertions =
+                                        vec![(request.request.cursor_point, CURSOR_MARKER)];
+                                    result.excerpts.extend(included_file.excerpts.iter().map(
+                                        |excerpt| ActualExcerpt {
+                                            path: included_file.path.components().skip(1).collect(),
+                                            text: String::from(excerpt.text.as_ref()),
+                                        },
+                                    ));
+                                    write_codeblock(
+                                        &included_file.path,
+                                        included_file.excerpts.iter(),
+                                        if included_file.path == request.request.excerpt_path {
+                                            &insertions
+                                        } else {
+                                            &[]
+                                        },
+                                        included_file.max_row,
+                                        false,
+                                        &mut result.excerpts_text,
                                     );
                                 }
-                                result.prediction_time =
-                                    prediction_finished_at - prediction_started_at;
-                                result.total_time = prediction_finished_at - start_time.unwrap();
+                            }
 
-                                break;
+                            let response =
+                                request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+                            let response = zeta2::text_from_response(response).unwrap_or_default();
+                            let prediction_finished_at = Instant::now();
+                            fs::write(example_run_dir.join("prediction_response.md"), &response)?;
+
+                            let mut result = result.lock().unwrap();
+                            result.generated_len = response.chars().count();
+
+                            if !options.use_expected_context {
+                                result.planning_search_time = Some(
+                                    search_queries_generated_at.unwrap() - start_time.unwrap(),
+                                );
+                                result.running_search_time = Some(
+                                    search_queries_executed_at.unwrap()
+                                        - search_queries_generated_at.unwrap(),
+                                );
                             }
+                            result.prediction_time = prediction_finished_at - prediction_started_at;
+                            result.total_time = prediction_finished_at - start_time.unwrap();
+
+                            break;
                         }
                     }
-                    anyhow::Ok(())
                 }
-            });
-
-            if options.use_expected_context {
-                let context_excerpts_tasks = example
-                    .example
-                    .expected_context
-                    .iter()
-                    .flat_map(|section| {
-                        section.alternatives[0].excerpts.iter().map(|excerpt| {
-                            resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
-                        })
+                anyhow::Ok(())
+            }
+        });
+
+        if options.use_expected_context {
+            let context_excerpts_tasks = example
+                .example
+                .expected_context
+                .iter()
+                .flat_map(|section| {
+                    section.alternatives[0].excerpts.iter().map(|excerpt| {
+                        resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                     })
-                    .collect::<Vec<_>>();
-                let context_excerpts_vec =
-                    futures::future::try_join_all(context_excerpts_tasks).await?;
-
-                let mut context_excerpts = HashMap::default();
-                for (buffer, mut excerpts) in context_excerpts_vec {
-                    context_excerpts
-                        .entry(buffer)
-                        .or_insert(Vec::new())
-                        .append(&mut excerpts);
-                }
-
-                zeta.update(cx, |zeta, _cx| {
-                    zeta.set_context(project.clone(), context_excerpts)
-                })?;
-            } else {
-                zeta.update(cx, |zeta, cx| {
-                    zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
-                })?
-                .await?;
+                })
+                .collect::<Vec<_>>();
+            let context_excerpts_vec =
+                futures::future::try_join_all(context_excerpts_tasks).await?;
+
+            let mut context_excerpts = HashMap::default();
+            for (buffer, mut excerpts) in context_excerpts_vec {
+                context_excerpts
+                    .entry(buffer)
+                    .or_insert(Vec::new())
+                    .append(&mut excerpts);
             }
 
-            let prediction = zeta
-                .update(cx, |zeta, cx| {
-                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
-                })?
-                .await?
-                .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));
-
-            debug_task.await?;
-
-            prediction
+            zeta.update(cx, |zeta, _cx| {
+                zeta.set_context(project.clone(), context_excerpts)
+            })?;
+        } else {
+            zeta.update(cx, |zeta, cx| {
+                zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+            })?
+            .await?;
         }
-        crate::PredictionProvider::Sweep => sweep
-            .unwrap()
-            .update(cx, |sweep, cx| {
-                let mut recent_paths = Vec::new();
-                for path in zeta
-                    .read(cx)
-                    .history_for_project(&project)
-                    .rev()
-                    .filter_map(|event| event.project_path(cx))
-                {
-                    if !recent_paths.contains(&path) {
-                        recent_paths.push(path);
-                    }
-                }
+    }
 
-                sweep.request_completion(
-                    &project,
-                    recent_paths.into_iter(),
-                    &cursor_buffer,
-                    cursor_anchor,
-                    cx,
-                )
-            })?
-            .await?
-            .map(
-                |sweep_ai::EditPrediction {
-                     edits, snapshot, ..
-                 }| { (cursor_buffer.clone(), snapshot, edits) },
-            ),
-    };
+    let prediction = zeta
+        .update(cx, |zeta, cx| {
+            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+        })?
+        .await?;
+
+    debug_task.await?;
 
     let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
 
     result.diff = prediction
-        .map(|(buffer, snapshot, edits)| {
-            let old_text = snapshot.text();
-            let new_text = buffer
+        .map(|prediction| {
+            let old_text = prediction.snapshot.text();
+            let new_text = prediction
+                .buffer
                 .update(cx, |buffer, cx| {
                     let branch = buffer.branch(cx);
                     branch.update(cx, |branch, cx| {
-                        branch.edit(edits.iter().cloned(), None, cx);
+                        branch.edit(prediction.edits.iter().cloned(), None, cx);
                         branch.text()
                     })
                 })