Separate experimental edit prediction jumps feature from the Sweep AI prediction provider (#43481)

Max Brunsfeld and Ben Kunkle created

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

crates/zeta/src/provider.rs           |   2 
crates/zeta/src/sweep_ai.rs           | 283 ++++++++++++++++++++++++-
crates/zeta/src/zeta.rs               | 320 +++++-----------------------
crates/zeta/src/zeta1.rs              |   6 
crates/zeta2_tools/src/zeta2_tools.rs |  79 +++---
5 files changed, 381 insertions(+), 309 deletions(-)

Detailed changes

crates/zeta/src/provider.rs 🔗

@@ -77,7 +77,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
     ) -> bool {
         let zeta = self.zeta.read(cx);
         if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
-            zeta.sweep_api_token.is_some()
+            zeta.sweep_ai.api_token.is_some()
         } else {
             true
         }

crates/zeta/src/sweep_ai.rs 🔗

@@ -1,10 +1,269 @@
-use std::fmt;
-use std::{path::Path, sync::Arc};
-
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::Event;
+use futures::AsyncReadExt as _;
+use gpui::{
+    App, AppContext as _, Entity, Task,
+    http_client::{self, AsyncBody, Method},
+};
+use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
+use lsp::DiagnosticSeverity;
+use project::{Project, ProjectPath};
 use serde::{Deserialize, Serialize};
+use std::{
+    collections::VecDeque,
+    fmt::{self, Write as _},
+    ops::Range,
+    path::Path,
+    sync::Arc,
+    time::Instant,
+};
+use util::ResultExt as _;
+
+use crate::{EditPrediction, EditPredictionId, EditPredictionInputs};
+
+const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+
+pub struct SweepAi {
+    pub api_token: Option<String>,
+    pub debug_info: Arc<str>,
+}
+
+impl SweepAi {
+    pub fn new(cx: &App) -> Self {
+        SweepAi {
+            api_token: std::env::var("SWEEP_AI_TOKEN")
+                .context("No SWEEP_AI_TOKEN environment variable set")
+                .log_err(),
+            debug_info: debug_info(cx),
+        }
+    }
+
+    pub fn request_prediction_with_sweep(
+        &self,
+        project: &Entity<Project>,
+        active_buffer: &Entity<Buffer>,
+        snapshot: BufferSnapshot,
+        position: language::Anchor,
+        events: Vec<Arc<Event>>,
+        recent_paths: &VecDeque<ProjectPath>,
+        diagnostic_search_range: Range<Point>,
+        cx: &mut App,
+    ) -> Task<Result<Option<EditPrediction>>> {
+        let debug_info = self.debug_info.clone();
+        let Some(api_token) = self.api_token.clone() else {
+            return Task::ready(Ok(None));
+        };
+        let full_path: Arc<Path> = snapshot
+            .file()
+            .map(|file| file.full_path(cx))
+            .unwrap_or_else(|| "untitled".into())
+            .into();
+
+        let project_file = project::File::from_dyn(snapshot.file());
+        let repo_name = project_file
+            .map(|file| file.worktree.read(cx).root_name_str())
+            .unwrap_or("untitled")
+            .into();
+        let offset = position.to_offset(&snapshot);
+
+        let recent_buffers = recent_paths.iter().cloned();
+        let http_client = cx.http_client();
+
+        let recent_buffer_snapshots = recent_buffers
+            .filter_map(|project_path| {
+                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+                if active_buffer == &buffer {
+                    None
+                } else {
+                    Some(buffer.read(cx).snapshot())
+                }
+            })
+            .take(3)
+            .collect::<Vec<_>>();
+
+        let cursor_point = position.to_point(&snapshot);
+        let buffer_snapshotted_at = Instant::now();
+
+        let result = cx.background_spawn(async move {
+            let text = snapshot.text();
+
+            let mut recent_changes = String::new();
+            for event in &events {
+                write_event(event.as_ref(), &mut recent_changes).unwrap();
+            }
+
+            let mut file_chunks = recent_buffer_snapshots
+                .into_iter()
+                .map(|snapshot| {
+                    let end_point = Point::new(30, 0).min(snapshot.max_point());
+                    FileChunk {
+                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
+                        file_path: snapshot
+                            .file()
+                            .map(|f| f.path().as_unix_str())
+                            .unwrap_or("untitled")
+                            .to_string(),
+                        start_line: 0,
+                        end_line: end_point.row as usize,
+                        timestamp: snapshot.file().and_then(|file| {
+                            Some(
+                                file.disk_state()
+                                    .mtime()?
+                                    .to_seconds_and_nanos_for_persistence()?
+                                    .0,
+                            )
+                        }),
+                    }
+                })
+                .collect::<Vec<_>>();
+
+            let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
+            let mut diagnostic_content = String::new();
+            let mut diagnostic_count = 0;
+
+            for entry in diagnostic_entries {
+                let start_point: Point = entry.range.start;
+
+                let severity = match entry.diagnostic.severity {
+                    DiagnosticSeverity::ERROR => "error",
+                    DiagnosticSeverity::WARNING => "warning",
+                    DiagnosticSeverity::INFORMATION => "info",
+                    DiagnosticSeverity::HINT => "hint",
+                    _ => continue,
+                };
+
+                diagnostic_count += 1;
+
+                writeln!(
+                    &mut diagnostic_content,
+                    "{} at line {}: {}",
+                    severity,
+                    start_point.row + 1,
+                    entry.diagnostic.message
+                )?;
+            }
+
+            if !diagnostic_content.is_empty() {
+                file_chunks.push(FileChunk {
+                    file_path: format!("Diagnostics for {}", full_path.display()),
+                    start_line: 0,
+                    end_line: diagnostic_count,
+                    content: diagnostic_content,
+                    timestamp: None,
+                });
+            }
+
+            let request_body = AutocompleteRequest {
+                debug_info,
+                repo_name,
+                file_path: full_path.clone(),
+                file_contents: text.clone(),
+                original_file_contents: text,
+                cursor_position: offset,
+                recent_changes: recent_changes.clone(),
+                changes_above_cursor: true,
+                multiple_suggestions: false,
+                branch: None,
+                file_chunks,
+                retrieval_chunks: vec![],
+                recent_user_actions: vec![],
+                // TODO
+                privacy_mode_enabled: false,
+            };
+
+            let mut buf: Vec<u8> = Vec::new();
+            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+            serde_json::to_writer(writer, &request_body)?;
+            let body: AsyncBody = buf.into();
+
+            let inputs = EditPredictionInputs {
+                events,
+                included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+                    path: full_path.clone(),
+                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
+                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
+                        start_line: cloud_llm_client::predict_edits_v3::Line(0),
+                        text: request_body.file_contents.into(),
+                    }],
+                }],
+                cursor_point: cloud_llm_client::predict_edits_v3::Point {
+                    column: cursor_point.column,
+                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
+                },
+                cursor_path: full_path.clone(),
+            };
+
+            let request = http_client::Request::builder()
+                .uri(SWEEP_API_URL)
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {}", api_token))
+                .header("Connection", "keep-alive")
+                .header("Content-Encoding", "br")
+                .method(Method::POST)
+                .body(body)?;
+
+            let mut response = http_client.send(request).await?;
+
+            let mut body: Vec<u8> = Vec::new();
+            response.body_mut().read_to_end(&mut body).await?;
+
+            let response_received_at = Instant::now();
+            if !response.status().is_success() {
+                anyhow::bail!(
+                    "Request failed with status: {:?}\nBody: {}",
+                    response.status(),
+                    String::from_utf8_lossy(&body),
+                );
+            };
+
+            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
+
+            let old_text = snapshot
+                .text_for_range(response.start_index..response.end_index)
+                .collect::<String>();
+            let edits = language::text_diff(&old_text, &response.completion)
+                .into_iter()
+                .map(|(range, text)| {
+                    (
+                        snapshot.anchor_after(response.start_index + range.start)
+                            ..snapshot.anchor_before(response.start_index + range.end),
+                        text,
+                    )
+                })
+                .collect::<Vec<_>>();
+
+            anyhow::Ok((
+                response.autocomplete_id,
+                edits,
+                snapshot,
+                response_received_at,
+                inputs,
+            ))
+        });
+
+        let buffer = active_buffer.clone();
+
+        cx.spawn(async move |cx| {
+            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
+            anyhow::Ok(
+                EditPrediction::new(
+                    EditPredictionId(id.into()),
+                    &buffer,
+                    &old_snapshot,
+                    edits.into(),
+                    buffer_snapshotted_at,
+                    response_received_at,
+                    inputs,
+                    cx,
+                )
+                .await,
+            )
+        })
+    }
+}
 
 #[derive(Debug, Clone, Serialize)]
-pub struct AutocompleteRequest {
+struct AutocompleteRequest {
     pub debug_info: Arc<str>,
     pub repo_name: String,
     pub branch: Option<String>,
@@ -22,7 +281,7 @@ pub struct AutocompleteRequest {
 }
 
 #[derive(Debug, Clone, Serialize)]
-pub struct FileChunk {
+struct FileChunk {
     pub file_path: String,
     pub start_line: usize,
     pub end_line: usize,
@@ -31,7 +290,7 @@ pub struct FileChunk {
 }
 
 #[derive(Debug, Clone, Serialize)]
-pub struct RetrievalChunk {
+struct RetrievalChunk {
     pub file_path: String,
     pub start_line: usize,
     pub end_line: usize,
@@ -40,7 +299,7 @@ pub struct RetrievalChunk {
 }
 
 #[derive(Debug, Clone, Serialize)]
-pub struct UserAction {
+struct UserAction {
     pub action_type: ActionType,
     pub line_number: usize,
     pub offset: usize,
@@ -51,7 +310,7 @@ pub struct UserAction {
 #[allow(dead_code)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum ActionType {
+enum ActionType {
     CursorMovement,
     InsertChar,
     DeleteChar,
@@ -60,7 +319,7 @@ pub enum ActionType {
 }
 
 #[derive(Debug, Clone, Deserialize)]
-pub struct AutocompleteResponse {
+struct AutocompleteResponse {
     pub autocomplete_id: String,
     pub start_index: usize,
     pub end_index: usize,
@@ -80,7 +339,7 @@ pub struct AutocompleteResponse {
 
 #[allow(dead_code)]
 #[derive(Debug, Clone, Deserialize)]
-pub struct AdditionalCompletion {
+struct AdditionalCompletion {
     pub start_index: usize,
     pub end_index: usize,
     pub completion: String,
@@ -90,7 +349,7 @@ pub struct AdditionalCompletion {
     pub finish_reason: Option<String>,
 }
 
-pub(crate) fn write_event(
+fn write_event(
     event: &cloud_llm_client::predict_edits_v3::Event,
     f: &mut impl fmt::Write,
 ) -> fmt::Result {
@@ -115,7 +374,7 @@ pub(crate) fn write_event(
     }
 }
 
-pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
+fn debug_info(cx: &gpui::App) -> Arc<str> {
     format!(
         "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
         version = release_channel::AppVersion::global(cx),

crates/zeta/src/zeta.rs 🔗

@@ -30,7 +30,6 @@ use language::{
 };
 use language::{BufferSnapshot, OffsetRangeExt};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use lsp::DiagnosticSeverity;
 use open_ai::FunctionDefinition;
 use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
@@ -42,7 +41,6 @@ use std::collections::{VecDeque, hash_map};
 use telemetry_events::EditPredictionRating;
 use workspace::Workspace;
 
-use std::fmt::Write as _;
 use std::ops::Range;
 use std::path::Path;
 use std::rc::Rc;
@@ -80,6 +78,7 @@ use crate::rate_prediction_modal::{
     NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
     ThumbsUpActivePrediction,
 };
+use crate::sweep_ai::SweepAi;
 use crate::zeta1::request_prediction_with_zeta1;
 pub use provider::ZetaEditPredictionProvider;
 
@@ -171,7 +170,7 @@ impl FeatureFlag for Zeta2FeatureFlag {
     const NAME: &'static str = "zeta2";
 
     fn enabled_for_staff() -> bool {
-        false
+        true
     }
 }
 
@@ -192,8 +191,7 @@ pub struct Zeta {
     #[cfg(feature = "eval-support")]
     eval_cache: Option<Arc<dyn EvalCache>>,
     edit_prediction_model: ZetaEditPredictionModel,
-    sweep_api_token: Option<String>,
-    sweep_ai_debug_info: Arc<str>,
+    sweep_ai: SweepAi,
     data_collection_choice: DataCollectionChoice,
     rejected_predictions: Vec<EditPredictionRejection>,
     reject_predictions_tx: mpsc::UnboundedSender<()>,
@@ -202,7 +200,7 @@ pub struct Zeta {
     rated_predictions: HashSet<EditPredictionId>,
 }
 
-#[derive(Default, PartialEq, Eq)]
+#[derive(Copy, Clone, Default, PartialEq, Eq)]
 pub enum ZetaEditPredictionModel {
     #[default]
     Zeta1,
@@ -499,11 +497,8 @@ impl Zeta {
             #[cfg(feature = "eval-support")]
             eval_cache: None,
             edit_prediction_model: ZetaEditPredictionModel::Zeta2,
-            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
-                .context("No SWEEP_AI_TOKEN environment variable set")
-                .log_err(),
+            sweep_ai: SweepAi::new(cx),
             data_collection_choice,
-            sweep_ai_debug_info: sweep_ai::debug_info(cx),
             rejected_predictions: Vec::new(),
             reject_predictions_debounce_task: None,
             reject_predictions_tx: reject_tx,
@@ -517,7 +512,7 @@ impl Zeta {
     }
 
     pub fn has_sweep_api_token(&self) -> bool {
-        self.sweep_api_token.is_some()
+        self.sweep_ai.api_token.is_some()
     }
 
     #[cfg(feature = "eval-support")]
@@ -643,7 +638,9 @@ impl Zeta {
                 }
             }
             project::Event::DiagnosticsUpdated { .. } => {
-                self.refresh_prediction_from_diagnostics(project, cx);
+                if cx.has_flag::<Zeta2FeatureFlag>() {
+                    self.refresh_prediction_from_diagnostics(project, cx);
+                }
             }
             _ => (),
         }
@@ -1183,249 +1180,77 @@ impl Zeta {
         position: language::Anchor,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPrediction>>> {
-        match self.edit_prediction_model {
-            ZetaEditPredictionModel::Zeta1 => {
-                request_prediction_with_zeta1(self, project, active_buffer, position, cx)
-            }
-            ZetaEditPredictionModel::Zeta2 => {
-                self.request_prediction_with_zeta2(project, active_buffer, position, cx)
-            }
-            ZetaEditPredictionModel::Sweep => {
-                self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
-            }
-        }
+        self.request_prediction_internal(
+            project.clone(),
+            active_buffer.clone(),
+            position,
+            cx.has_flag::<Zeta2FeatureFlag>(),
+            cx,
+        )
     }
 
-    fn request_prediction_with_sweep(
+    fn request_prediction_internal(
         &mut self,
-        project: &Entity<Project>,
-        active_buffer: &Entity<Buffer>,
+        project: Entity<Project>,
+        active_buffer: Entity<Buffer>,
         position: language::Anchor,
         allow_jump: bool,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPrediction>>> {
-        let snapshot = active_buffer.read(cx).snapshot();
-        let debug_info = self.sweep_ai_debug_info.clone();
-        let Some(api_token) = self.sweep_api_token.clone() else {
-            return Task::ready(Ok(None));
-        };
-        let full_path: Arc<Path> = snapshot
-            .file()
-            .map(|file| file.full_path(cx))
-            .unwrap_or_else(|| "untitled".into())
-            .into();
-
-        let project_file = project::File::from_dyn(snapshot.file());
-        let repo_name = project_file
-            .map(|file| file.worktree.read(cx).root_name_str())
-            .unwrap_or("untitled")
-            .into();
-        let offset = position.to_offset(&snapshot);
+        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 
-        let project_state = self.get_or_init_zeta_project(project, cx);
-        let events = project_state.events(cx);
+        self.get_or_init_zeta_project(&project, cx);
+        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
+        let events = zeta_project.events(cx);
         let has_events = !events.is_empty();
-        let recent_buffers = project_state.recent_paths.iter().cloned();
-        let http_client = cx.http_client();
-
-        let recent_buffer_snapshots = recent_buffers
-            .filter_map(|project_path| {
-                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
-                if active_buffer == &buffer {
-                    None
-                } else {
-                    Some(buffer.read(cx).snapshot())
-                }
-            })
-            .take(3)
-            .collect::<Vec<_>>();
-
-        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 
+        let snapshot = active_buffer.read(cx).snapshot();
         let cursor_point = position.to_point(&snapshot);
         let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
         let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
         let diagnostic_search_range =
             Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
-        let buffer_snapshotted_at = Instant::now();
-
-        let result = cx.background_spawn({
-            let snapshot = snapshot.clone();
-            let diagnostic_search_range = diagnostic_search_range.clone();
-            async move {
-                let text = snapshot.text();
-
-                let mut recent_changes = String::new();
-                for event in &events {
-                    sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap();
-                }
-
-                let mut file_chunks = recent_buffer_snapshots
-                    .into_iter()
-                    .map(|snapshot| {
-                        let end_point = Point::new(30, 0).min(snapshot.max_point());
-                        sweep_ai::FileChunk {
-                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
-                            file_path: snapshot
-                                .file()
-                                .map(|f| f.path().as_unix_str())
-                                .unwrap_or("untitled")
-                                .to_string(),
-                            start_line: 0,
-                            end_line: end_point.row as usize,
-                            timestamp: snapshot.file().and_then(|file| {
-                                Some(
-                                    file.disk_state()
-                                        .mtime()?
-                                        .to_seconds_and_nanos_for_persistence()?
-                                        .0,
-                                )
-                            }),
-                        }
-                    })
-                    .collect::<Vec<_>>();
-
-                let diagnostic_entries =
-                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
-                let mut diagnostic_content = String::new();
-                let mut diagnostic_count = 0;
-
-                for entry in diagnostic_entries {
-                    let start_point: Point = entry.range.start;
-
-                    let severity = match entry.diagnostic.severity {
-                        DiagnosticSeverity::ERROR => "error",
-                        DiagnosticSeverity::WARNING => "warning",
-                        DiagnosticSeverity::INFORMATION => "info",
-                        DiagnosticSeverity::HINT => "hint",
-                        _ => continue,
-                    };
-
-                    diagnostic_count += 1;
-
-                    writeln!(
-                        &mut diagnostic_content,
-                        "{} at line {}: {}",
-                        severity,
-                        start_point.row + 1,
-                        entry.diagnostic.message
-                    )?;
-                }
-
-                if !diagnostic_content.is_empty() {
-                    file_chunks.push(sweep_ai::FileChunk {
-                        file_path: format!("Diagnostics for {}", full_path.display()),
-                        start_line: 0,
-                        end_line: diagnostic_count,
-                        content: diagnostic_content,
-                        timestamp: None,
-                    });
-                }
-
-                let request_body = sweep_ai::AutocompleteRequest {
-                    debug_info,
-                    repo_name,
-                    file_path: full_path.clone(),
-                    file_contents: text.clone(),
-                    original_file_contents: text,
-                    cursor_position: offset,
-                    recent_changes: recent_changes.clone(),
-                    changes_above_cursor: true,
-                    multiple_suggestions: false,
-                    branch: None,
-                    file_chunks,
-                    retrieval_chunks: vec![],
-                    recent_user_actions: vec![],
-                    // TODO
-                    privacy_mode_enabled: false,
-                };
 
-                let mut buf: Vec<u8> = Vec::new();
-                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
-                serde_json::to_writer(writer, &request_body)?;
-                let body: AsyncBody = buf.into();
-
-                let inputs = EditPredictionInputs {
-                    events,
-                    included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
-                        path: full_path.clone(),
-                        max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
-                        excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
-                            start_line: cloud_llm_client::predict_edits_v3::Line(0),
-                            text: request_body.file_contents.into(),
-                        }],
-                    }],
-                    cursor_point: cloud_llm_client::predict_edits_v3::Point {
-                        column: cursor_point.column,
-                        line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
-                    },
-                    cursor_path: full_path.clone(),
-                };
-
-                const SWEEP_API_URL: &str =
-                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
-
-                let request = http_client::Request::builder()
-                    .uri(SWEEP_API_URL)
-                    .header("Content-Type", "application/json")
-                    .header("Authorization", format!("Bearer {}", api_token))
-                    .header("Connection", "keep-alive")
-                    .header("Content-Encoding", "br")
-                    .method(Method::POST)
-                    .body(body)?;
-
-                let mut response = http_client.send(request).await?;
-
-                let mut body: Vec<u8> = Vec::new();
-                response.body_mut().read_to_end(&mut body).await?;
-
-                let response_received_at = Instant::now();
-                if !response.status().is_success() {
-                    anyhow::bail!(
-                        "Request failed with status: {:?}\nBody: {}",
-                        response.status(),
-                        String::from_utf8_lossy(&body),
-                    );
-                };
-
-                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
-
-                let old_text = snapshot
-                    .text_for_range(response.start_index..response.end_index)
-                    .collect::<String>();
-                let edits = language::text_diff(&old_text, &response.completion)
-                    .into_iter()
-                    .map(|(range, text)| {
-                        (
-                            snapshot.anchor_after(response.start_index + range.start)
-                                ..snapshot.anchor_before(response.start_index + range.end),
-                            text,
-                        )
-                    })
-                    .collect::<Vec<_>>();
-
-                anyhow::Ok((
-                    response.autocomplete_id,
-                    edits,
-                    snapshot,
-                    response_received_at,
-                    inputs,
-                ))
-            }
-        });
-
-        let buffer = active_buffer.clone();
-        let project = project.clone();
-        let active_buffer = active_buffer.clone();
+        let task = match self.edit_prediction_model {
+            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
+                self,
+                &project,
+                &active_buffer,
+                snapshot.clone(),
+                position,
+                events,
+                cx,
+            ),
+            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
+                &project,
+                &active_buffer,
+                snapshot.clone(),
+                position,
+                events,
+                cx,
+            ),
+            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
+                &project,
+                &active_buffer,
+                snapshot.clone(),
+                position,
+                events,
+                &zeta_project.recent_paths,
+                diagnostic_search_range.clone(),
+                cx,
+            ),
+        };
 
         cx.spawn(async move |this, cx| {
-            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
+            let prediction = task
+                .await?
+                .filter(|prediction| !prediction.edits.is_empty());
 
-            if edits.is_empty() {
+            if prediction.is_none() && allow_jump {
+                let cursor_point = position.to_point(&snapshot);
                 if has_events
-                    && allow_jump
                     && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
-                        active_buffer,
+                        active_buffer.clone(),
                         &snapshot,
                         diagnostic_search_range,
                         cursor_point,
@@ -1436,9 +1261,9 @@ impl Zeta {
                 {
                     return this
                         .update(cx, |this, cx| {
-                            this.request_prediction_with_sweep(
-                                &project,
-                                &jump_buffer,
+                            this.request_prediction_internal(
+                                project,
+                                jump_buffer,
                                 jump_position,
                                 false,
                                 cx,
@@ -1450,19 +1275,7 @@ impl Zeta {
                 return anyhow::Ok(None);
             }
 
-            anyhow::Ok(
-                EditPrediction::new(
-                    EditPredictionId(id.into()),
-                    &buffer,
-                    &old_snapshot,
-                    edits.into(),
-                    buffer_snapshotted_at,
-                    response_received_at,
-                    inputs,
-                    cx,
-                )
-                .await,
-            )
+            Ok(prediction)
         })
     }
 
@@ -1549,7 +1362,9 @@ impl Zeta {
         &mut self,
         project: &Entity<Project>,
         active_buffer: &Entity<Buffer>,
+        active_snapshot: BufferSnapshot,
         position: language::Anchor,
+        events: Vec<Arc<Event>>,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPrediction>>> {
         let project_state = self.projects.get(&project.entity_id());
@@ -1561,7 +1376,6 @@ impl Zeta {
                 .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
         });
         let options = self.options.clone();
-        let active_snapshot = active_buffer.read(cx).snapshot();
         let buffer_snapshotted_at = Instant::now();
         let Some(excerpt_path) = active_snapshot
             .file()
@@ -1579,10 +1393,6 @@ impl Zeta {
             .collect::<Vec<_>>();
         let debug_tx = self.debug_tx.clone();
 
-        let events = project_state
-            .map(|state| state.events(cx))
-            .unwrap_or_default();
-
         let diagnostics = active_snapshot.diagnostic_sets().clone();
 
         let file = active_buffer.read(cx).file();

crates/zeta/src/zeta1.rs 🔗

@@ -32,19 +32,17 @@ pub(crate) fn request_prediction_with_zeta1(
     zeta: &mut Zeta,
     project: &Entity<Project>,
     buffer: &Entity<Buffer>,
+    snapshot: BufferSnapshot,
     position: language::Anchor,
+    events: Vec<Arc<Event>>,
     cx: &mut Context<Zeta>,
 ) -> Task<Result<Option<EditPrediction>>> {
     let buffer = buffer.clone();
     let buffer_snapshotted_at = Instant::now();
-    let snapshot = buffer.read(cx).snapshot();
     let client = zeta.client.clone();
     let llm_token = zeta.llm_token.clone();
     let app_version = AppVersion::global(cx);
 
-    let zeta_project = zeta.get_or_init_zeta_project(project, cx);
-    let events = Arc::new(zeta_project.events(cx));
-
     let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
         let can_collect_file = zeta.can_collect_file(project, file, cx);
         let git_info = if can_collect_file {

crates/zeta2_tools/src/zeta2_tools.rs 🔗

@@ -42,43 +42,48 @@ actions!(
 
 pub fn init(cx: &mut App) {
     cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
-        workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
-            let project = workspace.project();
-            workspace.split_item(
-                SplitDirection::Right,
-                Box::new(cx.new(|cx| {
-                    Zeta2Inspector::new(
-                        &project,
-                        workspace.client(),
-                        workspace.user_store(),
-                        window,
-                        cx,
-                    )
-                })),
-                window,
-                cx,
-            );
-        });
-    })
-    .detach();
-
-    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
-        workspace.register_action(move |workspace, _: &OpenZeta2ContextView, window, cx| {
-            let project = workspace.project();
-            workspace.split_item(
-                SplitDirection::Right,
-                Box::new(cx.new(|cx| {
-                    Zeta2ContextView::new(
-                        project.clone(),
-                        workspace.client(),
-                        workspace.user_store(),
-                        window,
-                        cx,
-                    )
-                })),
-                window,
-                cx,
-            );
+        workspace.register_action_renderer(|div, _, _, cx| {
+            let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
+            div.when(has_flag, |div| {
+                div.on_action(
+                    cx.listener(move |workspace, _: &OpenZeta2Inspector, window, cx| {
+                        let project = workspace.project();
+                        workspace.split_item(
+                            SplitDirection::Right,
+                            Box::new(cx.new(|cx| {
+                                Zeta2Inspector::new(
+                                    &project,
+                                    workspace.client(),
+                                    workspace.user_store(),
+                                    window,
+                                    cx,
+                                )
+                            })),
+                            window,
+                            cx,
+                        )
+                    }),
+                )
+                .on_action(cx.listener(
+                    move |workspace, _: &OpenZeta2ContextView, window, cx| {
+                        let project = workspace.project();
+                        workspace.split_item(
+                            SplitDirection::Right,
+                            Box::new(cx.new(|cx| {
+                                Zeta2ContextView::new(
+                                    project.clone(),
+                                    workspace.client(),
+                                    workspace.user_store(),
+                                    window,
+                                    cx,
+                                )
+                            })),
+                            window,
+                            cx,
+                        );
+                    },
+                ))
+            })
         });
     })
     .detach();