diff --git a/Cargo.lock b/Cargo.lock index 4f9a3f26e9a20df498bd3b735cfec54aa77c77cd..873fcdbb63fcabee0f722ae27beac486d0ce8670 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5314,13 +5314,13 @@ dependencies = [ "serde_json", "settings", "supermaven", - "sweep_ai", "telemetry", "theme", "ui", "workspace", "zed_actions", "zeta", + "zeta2", ] [[package]] @@ -16590,33 +16590,6 @@ dependencies = [ "zeno", ] -[[package]] -name = "sweep_ai" -version = "0.1.0" -dependencies = [ - "anyhow", - "arrayvec", - "brotli", - "client", - "collections", - "edit_prediction", - "feature_flags", - "futures 0.3.31", - "gpui", - "http_client", - "indoc", - "language", - "project", - "release_channel", - "reqwest_client", - "serde", - "serde_json", - "tree-sitter-rust", - "util", - "workspace", - "zlog", -] - [[package]] name = "symphonia" version = "0.5.5" @@ -21343,7 +21316,6 @@ dependencies = [ "snippets_ui", "supermaven", "svg_preview", - "sweep_ai", "sysinfo 0.37.2", "system_specs", "tab_switcher", @@ -21754,6 +21726,7 @@ version = "0.1.0" dependencies = [ "anyhow", "arrayvec", + "brotli", "chrono", "client", "clock", @@ -21864,7 +21837,6 @@ dependencies = [ "shellexpand 2.1.2", "smol", "soa-rs", - "sweep_ai", "terminal_view", "toml 0.8.23", "util", diff --git a/Cargo.toml b/Cargo.toml index 03a86c9e25bd8f5a1bb8498b3cb0169055672ad4..a4c9caccd9539ffde7d57d36dcfaf4cf162c7e92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -165,7 +165,6 @@ members = [ "crates/sum_tree", "crates/supermaven", "crates/supermaven_api", - "crates/sweep_ai", "crates/codestral", "crates/svg_preview", "crates/system_specs", @@ -399,7 +398,6 @@ streaming_diff = { path = "crates/streaming_diff" } sum_tree = { path = "crates/sum_tree" } supermaven = { path = "crates/supermaven" } supermaven_api = { path = "crates/supermaven_api" } -sweep_ai = { path = "crates/sweep_ai" } codestral = { path = "crates/codestral" } system_specs = { path = "crates/system_specs" } tab_switcher = { path = "crates/tab_switcher" } diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml index 3ed3d9411510ad2d978b221d8cb3412465a66879..9877b70161b3fdd16a0f667d85085520c9fe4f86 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -30,12 +30,12 @@ project.workspace = true regex.workspace = true settings.workspace = true supermaven.workspace = true -sweep_ai.workspace = true telemetry.workspace = true ui.workspace = true workspace.workspace = true zed_actions.workspace = true zeta.workspace = true +zeta2.workspace = true [dev-dependencies] copilot = { workspace = true, features = ["test-support"] } diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 4f5f60d5a2328e5e56d65e87add7338b7e572346..ba00e95c488dc8e8704274638087c8334f96e1a3 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -28,7 +28,6 @@ use std::{ time::Duration, }; use supermaven::{AccountStatus, Supermaven}; -use sweep_ai::SweepFeatureFlag; use ui::{ Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton, IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*, @@ -39,6 +38,7 @@ use workspace::{ }; use zed_actions::OpenBrowser; use zeta::RateCompletions; +use zeta2::SweepFeatureFlag; actions!( edit_prediction, diff --git a/crates/sweep_ai/Cargo.toml b/crates/sweep_ai/Cargo.toml deleted file mode 100644 index 4edf7ea1bb6af9a6657ccfe310c0253b118ec2e7..0000000000000000000000000000000000000000 --- a/crates/sweep_ai/Cargo.toml +++ /dev/null @@ -1,43 +0,0 @@ -[package] -name = "sweep_ai" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" -exclude = ["fixtures"] - -[lints] -workspace = true - -[lib] -path = "src/sweep_ai.rs" -doctest = false - -[dependencies] -anyhow.workspace = true -arrayvec.workspace = true -brotli.workspace = true -client.workspace = true -collections.workspace = true -edit_prediction.workspace = true -feature_flags.workspace = true -futures.workspace = true -gpui.workspace = true -http_client.workspace = true -language.workspace = true -project.workspace = true -release_channel.workspace = true -serde.workspace = true -serde_json.workspace = true -util.workspace = true -workspace.workspace = true - -[dev-dependencies] -gpui = { workspace = true, features = ["test-support"] } -http_client = { workspace = true, features = ["test-support"] } -indoc.workspace = true -language = { workspace = true, features = ["test-support"] } -reqwest_client = { workspace = true, features = ["test-support"] } -tree-sitter-rust.workspace = true -workspace = { workspace = true, features = ["test-support"] } -zlog.workspace = true diff --git a/crates/sweep_ai/LICENSE-GPL b/crates/sweep_ai/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/sweep_ai/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/sweep_ai/src/sweep_ai.rs b/crates/sweep_ai/src/sweep_ai.rs deleted file mode 100644 index 1b4c92120d866a218987f36161e9520a0f3f703a..0000000000000000000000000000000000000000 --- a/crates/sweep_ai/src/sweep_ai.rs +++ /dev/null @@ -1,784 +0,0 @@ -mod api; - -use anyhow::{Context as _, Result}; -use arrayvec::ArrayVec; -use client::telemetry; -use collections::HashMap; -use feature_flags::FeatureFlag; -use futures::AsyncReadExt as _; -use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity}; -use http_client::{AsyncBody, Method}; -use language::{ - Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff, -}; -use project::{Project, ProjectPath}; -use release_channel::{AppCommitSha, AppVersion}; -use std::collections::{VecDeque, hash_map}; -use std::fmt::{self, Display}; -use std::mem; -use std::{ - cmp, - fmt::Write, - ops::Range, - path::Path, - sync::Arc, - time::{Duration, Instant}, -}; -use util::ResultExt; -use util::rel_path::RelPath; -use workspace::Workspace; - -use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk}; - -const CHANGE_GROUPING_LINE_SPAN: u32 = 8; -const MAX_EVENT_COUNT: usize = 6; - -const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; - -pub struct SweepFeatureFlag; - -impl FeatureFlag for SweepFeatureFlag { - const NAME: &str = "sweep-ai"; -} - -#[derive(Clone)] -struct SweepAiGlobal(Entity); - -impl Global for SweepAiGlobal {} - -#[derive(Clone)] -pub struct EditPrediction { - pub id: EditPredictionId, - pub path: Arc, - pub edits: Arc<[(Range, Arc)]>, - pub snapshot: BufferSnapshot, - pub edit_preview: EditPreview, -} - -impl EditPrediction { - fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> { - edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) - } -} - -impl fmt::Debug for EditPrediction { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("EditPrediction") - .field("path", &self.path) - .field("edits", &self.edits) - .finish_non_exhaustive() - } -} - -#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct EditPredictionId(String); - -impl Display for EditPredictionId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -pub struct SweepAi { - projects: HashMap, - debug_info: Arc, - api_token: Option, -} - -struct SweepAiProject { - events: VecDeque, - registered_buffers: HashMap, -} - -impl SweepAi { - pub fn global(cx: &mut App) -> Option> { - cx.try_global::() - .map(|global| global.0.clone()) - } - - pub fn register(cx: &mut App) -> Entity { - Self::global(cx).unwrap_or_else(|| { - let entity = cx.new(|cx| Self::new(cx)); - cx.set_global(SweepAiGlobal(entity.clone())); - entity - }) - } - - pub fn clear_history(&mut self) { - for sweep_ai_project in self.projects.values_mut() { - sweep_ai_project.events.clear(); - } - } - - pub fn new(cx: &mut Context) -> Self { - Self { - api_token: std::env::var("SWEEP_AI_TOKEN").ok(), - projects: HashMap::default(), - debug_info: format!( - "Zed v{version} ({sha}) - OS: {os} - Zed v{version}", - version = AppVersion::global(cx), - sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()), - os = telemetry::os_name(), - ) - .into(), - } - } - - fn get_or_init_sweep_ai_project( - &mut self, - project: &Entity, - cx: &mut Context, - ) -> &mut SweepAiProject { - let project_id = project.entity_id(); - match self.projects.entry(project_id) { - hash_map::Entry::Occupied(entry) => entry.into_mut(), - hash_map::Entry::Vacant(entry) => { - cx.observe_release(project, move |this, _, _cx| { - this.projects.remove(&project_id); - }) - .detach(); - entry.insert(SweepAiProject { - events: VecDeque::with_capacity(MAX_EVENT_COUNT), - registered_buffers: HashMap::default(), - }) - } - } - } - - pub fn register_buffer( - &mut self, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) { - let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx); - Self::register_buffer_impl(sweep_ai_project, buffer, project, cx); - } - - fn register_buffer_impl<'a>( - sweep_ai_project: &'a mut SweepAiProject, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) -> &'a mut RegisteredBuffer { - let buffer_id = buffer.entity_id(); - match sweep_ai_project.registered_buffers.entry(buffer_id) { - hash_map::Entry::Occupied(entry) => entry.into_mut(), - hash_map::Entry::Vacant(entry) => { - let snapshot = buffer.read(cx).snapshot(); - let project_entity_id = project.entity_id(); - entry.insert(RegisteredBuffer { - snapshot, - _subscriptions: [ - cx.subscribe(buffer, { - let project = project.downgrade(); - move |this, buffer, event, cx| { - if let language::BufferEvent::Edited = event - && let Some(project) = project.upgrade() - { - this.report_changes_for_buffer(&buffer, &project, cx); - } - } - }), - cx.observe_release(buffer, move |this, _buffer, _cx| { - let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id) - else { - return; - }; - sweep_ai_project.registered_buffers.remove(&buffer_id); - }), - ], - }) - } - } - } - - pub fn request_completion( - &mut self, - project: &Entity, - recent_buffers: impl Iterator, - active_buffer: &Entity, - position: language::Anchor, - cx: &mut Context, - ) -> Task>> { - let snapshot = active_buffer.read(cx).snapshot(); - let debug_info = self.debug_info.clone(); - let Some(api_token) = self.api_token.clone() else { - return Task::ready(Ok(None)); - }; - let full_path: Arc = snapshot - .file() - .map(|file| file.full_path(cx)) - .unwrap_or_else(|| "untitled".into()) - .into(); - - let project_file = project::File::from_dyn(snapshot.file()); - let repo_name = project_file - .map(|file| file.worktree.read(cx).root_name_str()) - .unwrap_or("untitled") - .into(); - let offset = position.to_offset(&snapshot); - - let project_state = self.get_or_init_sweep_ai_project(project, cx); - let events = project_state.events.clone(); - let http_client = cx.http_client(); - - let recent_buffer_snapshots = recent_buffers - .filter_map(|project_path| { - let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; - if active_buffer == &buffer { - None - } else { - Some(buffer.read(cx).snapshot()) - } - }) - .take(3) - .collect::>(); - - let result = cx.background_spawn({ - let full_path = full_path.clone(); - async move { - let text = snapshot.text(); - - let mut recent_changes = String::new(); - - for event in events { - writeln!(&mut recent_changes, "{event}")?; - } - - let file_chunks = recent_buffer_snapshots - .into_iter() - .map(|snapshot| { - let end_point = language::Point::new(30, 0).min(snapshot.max_point()); - FileChunk { - content: snapshot - .text_for_range(language::Point::zero()..end_point) - .collect(), - file_path: snapshot - .file() - .map(|f| f.path().as_unix_str()) - .unwrap_or("untitled") - .to_string(), - start_line: 0, - end_line: end_point.row as usize, - timestamp: snapshot.file().and_then(|file| { - Some( - file.disk_state() - .mtime()? - .to_seconds_and_nanos_for_persistence()? - .0, - ) - }), - } - }) - .collect(); - - eprintln!("{recent_changes}"); - - let request_body = AutocompleteRequest { - debug_info, - repo_name, - file_path: full_path.clone(), - file_contents: text.clone(), - original_file_contents: text, - cursor_position: offset, - recent_changes: recent_changes.clone(), - changes_above_cursor: true, - multiple_suggestions: false, - branch: None, - file_chunks, - retrieval_chunks: vec![], - recent_user_actions: vec![], - // TODO - privacy_mode_enabled: false, - }; - - let mut buf: Vec = Vec::new(); - let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); - serde_json::to_writer(writer, &request_body)?; - let body: AsyncBody = buf.into(); - - let request = http_client::Request::builder() - .uri(SWEEP_API_URL) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_token)) - .header("Connection", "keep-alive") - .header("Content-Encoding", "br") - .method(Method::POST) - .body(body)?; - - let mut response = http_client.send(request).await?; - - let mut body: Vec = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; - - if !response.status().is_success() { - anyhow::bail!( - "Request failed with status: {:?}\nBody: {}", - response.status(), - String::from_utf8_lossy(&body), - ); - }; - - let response: AutocompleteResponse = serde_json::from_slice(&body)?; - - let old_text = snapshot - .text_for_range(response.start_index..response.end_index) - .collect::(); - let edits = text_diff(&old_text, &response.completion) - .into_iter() - .map(|(range, text)| { - ( - snapshot.anchor_after(response.start_index + range.start) - ..snapshot.anchor_before(response.start_index + range.end), - text, - ) - }) - .collect::>(); - - anyhow::Ok((response.autocomplete_id, edits, snapshot)) - } - }); - - let buffer = active_buffer.clone(); - - cx.spawn(async move |_, cx| { - let (id, edits, old_snapshot) = result.await?; - - if edits.is_empty() { - return anyhow::Ok(None); - } - - let Some((edits, new_snapshot, preview_task)) = - buffer.read_with(cx, |buffer, cx| { - let new_snapshot = buffer.snapshot(); - - let edits: Arc<[(Range, Arc)]> = - edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)? - .into(); - let preview_task = buffer.preview_edits(edits.clone(), cx); - - Some((edits, new_snapshot, preview_task)) - })? - else { - return anyhow::Ok(None); - }; - - let prediction = EditPrediction { - id: EditPredictionId(id), - path: full_path, - edits, - snapshot: new_snapshot, - edit_preview: preview_task.await, - }; - - anyhow::Ok(Some(prediction)) - }) - } - - fn report_changes_for_buffer( - &mut self, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) { - let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx); - let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx); - - let new_snapshot = buffer.read(cx).snapshot(); - if new_snapshot.version == registered_buffer.snapshot.version { - return; - } - - let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); - let end_edit_anchor = new_snapshot - .anchored_edits_since::(&old_snapshot.version) - .last() - .map(|(_, range)| range.end); - let events = &mut sweep_ai_project.events; - - if let Some(Event::BufferChange { - new_snapshot: last_new_snapshot, - end_edit_anchor: last_end_edit_anchor, - .. - }) = events.back_mut() - { - let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() - == last_new_snapshot.remote_id() - && old_snapshot.version == last_new_snapshot.version; - - let should_coalesce = is_next_snapshot_of_same_buffer - && end_edit_anchor - .as_ref() - .zip(last_end_edit_anchor.as_ref()) - .is_some_and(|(a, b)| { - let a = a.to_point(&new_snapshot); - let b = b.to_point(&new_snapshot); - a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN - }); - - if should_coalesce { - *last_end_edit_anchor = end_edit_anchor; - *last_new_snapshot = new_snapshot; - return; - } - } - - if events.len() >= MAX_EVENT_COUNT { - events.pop_front(); - } - - events.push_back(Event::BufferChange { - old_snapshot, - new_snapshot, - end_edit_anchor, - }); - } -} - -struct RegisteredBuffer { - snapshot: BufferSnapshot, - _subscriptions: [gpui::Subscription; 2], -} - -#[derive(Clone)] -pub enum Event { - BufferChange { - old_snapshot: BufferSnapshot, - new_snapshot: BufferSnapshot, - end_edit_anchor: Option, - }, -} - -impl Display for Event { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Event::BufferChange { - old_snapshot, - new_snapshot, - .. - } => { - let old_path = old_snapshot - .file() - .map(|f| f.path().as_ref()) - .unwrap_or(RelPath::unix("untitled").unwrap()); - let new_path = new_snapshot - .file() - .map(|f| f.path().as_ref()) - .unwrap_or(RelPath::unix("untitled").unwrap()); - if old_path != new_path { - // TODO confirm how to do this for sweep - // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?; - } - - let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); - if !diff.is_empty() { - write!( - f, - "File: {}:\n{}\n", - new_path.display(util::paths::PathStyle::Posix), - diff - )? - } - - fmt::Result::Ok(()) - } - } - } -} - -#[derive(Debug, Clone)] -struct CurrentEditPrediction { - buffer_id: EntityId, - completion: EditPrediction, -} - -impl CurrentEditPrediction { - fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool { - if self.buffer_id != old_completion.buffer_id { - return true; - } - - let Some(old_edits) = old_completion.completion.interpolate(snapshot) else { - return true; - }; - let Some(new_edits) = self.completion.interpolate(snapshot) else { - return false; - }; - - if old_edits.len() == 1 && new_edits.len() == 1 { - let (old_range, old_text) = &old_edits[0]; - let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text.as_ref()) - } else { - true - } - } -} - -struct PendingCompletion { - id: usize, - _task: Task<()>, -} - -pub struct SweepAiEditPredictionProvider { - workspace: WeakEntity, - sweep_ai: Entity, - pending_completions: ArrayVec, - next_pending_completion_id: usize, - current_completion: Option, - last_request_timestamp: Instant, - project: Entity, -} - -impl SweepAiEditPredictionProvider { - pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - - pub fn new( - sweep_ai: Entity, - workspace: WeakEntity, - project: Entity, - ) -> Self { - Self { - sweep_ai, - pending_completions: ArrayVec::new(), - next_pending_completion_id: 0, - current_completion: None, - last_request_timestamp: Instant::now(), - project, - workspace, - } - } -} - -impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider { - fn name() -> &'static str { - "zed-predict" - } - - fn display_name() -> &'static str { - "Zed's Edit Predictions" - } - - fn show_completions_in_menu() -> bool { - true - } - - fn show_tab_accept_marker() -> bool { - true - } - - fn is_enabled( - &self, - _buffer: &Entity, - _cursor_position: language::Anchor, - cx: &App, - ) -> bool { - self.sweep_ai.read(cx).api_token.is_some() - } - - fn is_refreshing(&self) -> bool { - !self.pending_completions.is_empty() - } - - fn refresh( - &mut self, - buffer: Entity, - position: language::Anchor, - _debounce: bool, - cx: &mut Context, - ) { - if let Some(current_completion) = self.current_completion.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if current_completion - .completion - .interpolate(&snapshot) - .is_some() - { - return; - } - } - - let pending_completion_id = self.next_pending_completion_id; - self.next_pending_completion_id += 1; - let last_request_timestamp = self.last_request_timestamp; - - let project = self.project.clone(); - let workspace = self.workspace.clone(); - let task = cx.spawn(async move |this, cx| { - if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) - .checked_duration_since(Instant::now()) - { - cx.background_executor().timer(timeout).await; - } - - let completion_request = this.update(cx, |this, cx| { - this.last_request_timestamp = Instant::now(); - - this.sweep_ai.update(cx, |sweep_ai, cx| { - let Some(recent_buffers) = workspace - .read_with(cx, |workspace, cx| { - workspace.recent_navigation_history_iter(cx) - }) - .log_err() - else { - return Task::ready(Ok(None)); - }; - sweep_ai.request_completion( - &project, - recent_buffers.map(move |(project_path, _)| project_path), - &buffer, - position, - cx, - ) - }) - }); - - let completion = match completion_request { - Ok(completion_request) => { - let completion_request = completion_request.await; - completion_request.map(|c| { - c.map(|completion| CurrentEditPrediction { - buffer_id: buffer.entity_id(), - completion, - }) - }) - } - Err(error) => Err(error), - }; - - let Some(new_completion) = completion - .context("edit prediction failed") - .log_err() - .flatten() - else { - this.update(cx, |this, cx| { - if this.pending_completions[0].id == pending_completion_id { - this.pending_completions.remove(0); - } else { - this.pending_completions.clear(); - } - - cx.notify(); - }) - .ok(); - return; - }; - - this.update(cx, |this, cx| { - if this.pending_completions[0].id == pending_completion_id { - this.pending_completions.remove(0); - } else { - this.pending_completions.clear(); - } - - if let Some(old_completion) = this.current_completion.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if new_completion.should_replace_completion(old_completion, &snapshot) { - this.current_completion = Some(new_completion); - } - } else { - this.current_completion = Some(new_completion); - } - - cx.notify(); - }) - .ok(); - }); - - // We always maintain at most two pending completions. When we already - // have two, we replace the newest one. - if self.pending_completions.len() <= 1 { - self.pending_completions.push(PendingCompletion { - id: pending_completion_id, - _task: task, - }); - } else if self.pending_completions.len() == 2 { - self.pending_completions.pop(); - self.pending_completions.push(PendingCompletion { - id: pending_completion_id, - _task: task, - }); - } - } - - fn cycle( - &mut self, - _buffer: Entity, - _cursor_position: language::Anchor, - _direction: edit_prediction::Direction, - _cx: &mut Context, - ) { - // Right now we don't support cycling. - } - - fn accept(&mut self, _cx: &mut Context) { - self.pending_completions.clear(); - } - - fn discard(&mut self, _cx: &mut Context) { - self.pending_completions.clear(); - self.current_completion.take(); - } - - fn suggest( - &mut self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut Context, - ) -> Option { - let CurrentEditPrediction { - buffer_id, - completion, - .. - } = self.current_completion.as_mut()?; - - // Invalidate previous completion if it was generated for a different buffer. - if *buffer_id != buffer.entity_id() { - self.current_completion.take(); - return None; - } - - let buffer = buffer.read(cx); - let Some(edits) = completion.interpolate(&buffer.snapshot()) else { - self.current_completion.take(); - return None; - }; - - let cursor_row = cursor_position.to_point(buffer).row; - let (closest_edit_ix, (closest_edit_range, _)) = - edits.iter().enumerate().min_by_key(|(_, (range, _))| { - let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row); - let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row); - cmp::min(distance_from_start, distance_from_end) - })?; - - let mut edit_start_ix = closest_edit_ix; - for (range, _) in edits[..edit_start_ix].iter().rev() { - let distance_from_closest_edit = - closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row; - if distance_from_closest_edit <= 1 { - edit_start_ix -= 1; - } else { - break; - } - } - - let mut edit_end_ix = closest_edit_ix + 1; - for (range, _) in &edits[edit_end_ix..] { - let distance_from_closest_edit = - range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row; - if distance_from_closest_edit <= 1 { - edit_end_ix += 1; - } else { - break; - } - } - - Some(edit_prediction::EditPrediction::Local { - id: Some(completion.id.to_string().into()), - edits: edits[edit_start_ix..edit_end_ix].to_vec(), - edit_preview: Some(completion.edit_preview.clone()), - }) - } -} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 78d650793ebb98cdc9a52e50adc9fa57c7c24b4f..b0a4f6344c9a710af5cf6a391d7b2c0f03efe7b1 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -133,7 +133,6 @@ snippet_provider.workspace = true snippets_ui.workspace = true supermaven.workspace = true svg_preview.workspace = true -sweep_ai.workspace = true sysinfo.workspace = true tab_switcher.workspace = true task.workspace = true diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 1723ca91f143c8529e14e24e0bdd85dc7b1c14d4..250a2b5a0e585d5acad7658a25f89bce12f766d2 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -10,9 +10,9 @@ use language_models::MistralLanguageModelProvider; use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore}; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; -use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag}; use ui::Window; use zeta::ZetaEditPredictionProvider; +use zeta2::SweepFeatureFlag; use zeta2::Zeta2FeatureFlag; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { @@ -203,55 +203,41 @@ fn assign_edit_prediction_provider( let provider = cx.new(|_| CodestralCompletionProvider::new(http_client)); editor.set_edit_prediction_provider(Some(provider), window, cx); } - EditPredictionProvider::Experimental(name) => { - if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME - && cx.has_flag::() - { - if let Some(project) = editor.project() - && let Some(workspace) = editor.workspace() + value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { + if let Some(project) = editor.project() { + let mut worktree = None; + if let Some(buffer) = &singleton_buffer + && let Some(file) = buffer.read(cx).file() { - let sweep_ai = sweep_ai::SweepAi::register(cx); - - if let Some(buffer) = &singleton_buffer - && buffer.read(cx).file().is_some() - { - sweep_ai.update(cx, |sweep_ai, cx| { - sweep_ai.register_buffer(buffer, project, cx); - }); - } + let id = file.worktree_id(cx); + worktree = project.read(cx).worktree_for_id(id, cx); + } - let provider = cx.new(|_| { - sweep_ai::SweepAiEditPredictionProvider::new( - sweep_ai, - workspace.downgrade(), + if let EditPredictionProvider::Experimental(name) = value + && name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME + && cx.has_flag::() + { + let zeta2 = zeta2::Zeta::global(client, &user_store, cx); + let provider = cx.new(|cx| { + zeta2::ZetaEditPredictionProvider::new( project.clone(), + &client, + &user_store, + cx, ) }); - editor.set_edit_prediction_provider(Some(provider), window, cx); - } - } else { - editor.set_edit_prediction_provider::( - None, window, cx, - ); - } - } - EditPredictionProvider::Zed => { - if user_store.read(cx).current_user().is_some() { - let mut worktree = None; - if let Some(buffer) = &singleton_buffer - && let Some(file) = buffer.read(cx).file() - { - let id = file.worktree_id(cx); - if let Some(inner_worktree) = editor - .project() - .and_then(|project| project.read(cx).worktree_for_id(id, cx)) + if let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() { - worktree = Some(inner_worktree); + zeta2.update(cx, |zeta, cx| { + zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep); + zeta.register_buffer(buffer, project, cx); + }); } - } - if let Some(project) = editor.project() { + editor.set_edit_prediction_provider(Some(provider), window, cx); + } else if user_store.read(cx).current_user().is_some() { if cx.has_flag::() { let zeta = zeta2::Zeta::global(client, &user_store, cx); let provider = cx.new(|cx| { @@ -268,6 +254,9 @@ fn assign_edit_prediction_provider( && buffer.read(cx).file().is_some() { zeta.update(cx, |zeta, cx| { + zeta.set_edit_prediction_model( + zeta2::ZetaEditPredictionModel::ZedCloud, + ); zeta.register_buffer(buffer, project, cx); }); } diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index 1eef507e6def3d80560ff1515623d0c42687d74a..0f156f68fac881d65d76f178315f40df1dba9d7f 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -17,6 +17,7 @@ eval-support = [] [dependencies] anyhow.workspace = true arrayvec.workspace = true +brotli.workspace = true chrono.workspace = true client.workspace = true cloud_llm_client.workspace = true diff --git a/crates/zeta2/src/provider.rs b/crates/zeta2/src/provider.rs index a19e7f9a1da5e1808c48e3ce0469d8b390698760..1b82826f663b092b5763935d9a7a2d4bb9607ebf 100644 --- a/crates/zeta2/src/provider.rs +++ b/crates/zeta2/src/provider.rs @@ -12,7 +12,7 @@ use language::ToPoint as _; use project::Project; use util::ResultExt as _; -use crate::{BufferEditPrediction, Zeta}; +use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel}; pub struct ZetaEditPredictionProvider { zeta: Entity, @@ -85,9 +85,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { &self, _buffer: &Entity, _cursor_position: language::Anchor, - _cx: &App, + cx: &App, ) -> bool { - true + let zeta = self.zeta.read(cx); + if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep { + zeta.sweep_api_token.is_some() + } else { + true + } } fn is_refreshing(&self) -> bool { diff --git a/crates/sweep_ai/src/api.rs b/crates/zeta2/src/sweep_ai.rs similarity index 59% rename from crates/sweep_ai/src/api.rs rename to crates/zeta2/src/sweep_ai.rs index edb392885e476e3924d285613af1f0a4e8be8599..c56d7409fa212734c5f5a73a6b24319c27c7494f 100644 --- a/crates/sweep_ai/src/api.rs +++ b/crates/zeta2/src/sweep_ai.rs @@ -1,6 +1,8 @@ +use std::fmt; use std::{path::Path, sync::Arc}; use serde::{Deserialize, Serialize}; +use util::rel_path::RelPath; #[derive(Debug, Clone, Serialize)] pub struct AutocompleteRequest { @@ -88,3 +90,49 @@ pub struct AdditionalCompletion { pub logprobs: Option, pub finish_reason: Option, } + +pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result { + match event { + crate::Event::BufferChange { + old_snapshot, + new_snapshot, + .. + } => { + let old_path = old_snapshot + .file() + .map(|f| f.path().as_ref()) + .unwrap_or(RelPath::unix("untitled").unwrap()); + let new_path = new_snapshot + .file() + .map(|f| f.path().as_ref()) + .unwrap_or(RelPath::unix("untitled").unwrap()); + if old_path != new_path { + // TODO confirm how to do this for sweep + // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?; + } + + let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); + if !diff.is_empty() { + write!( + f, + "File: {}:\n{}\n", + new_path.display(util::paths::PathStyle::Posix), + diff + )? + } + + fmt::Result::Ok(()) + } + } +} + +pub(crate) fn debug_info(cx: &gpui::App) -> Arc { + format!( + "Zed v{version} ({sha}) - OS: {os} - Zed v{version}", + version = release_channel::AppVersion::global(cx), + sha = release_channel::AppCommitSha::try_global(cx) + .map_or("unknown".to_string(), |sha| sha.full()), + os = client::telemetry::os_name(), + ) + .into() +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 099cd95134ec3d1fd59bbc33306bc439c0a8ee1a..6eacc5190f403594ad20f7365512b011d2226719 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -22,30 +22,31 @@ use gpui::{ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, http_client, prelude::*, }; -use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; +use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use open_ai::FunctionDefinition; -use project::Project; +use project::{Project, ProjectPath}; use release_channel::AppVersion; use serde::de::DeserializeOwned; use std::collections::{VecDeque, hash_map}; -use std::env; use std::ops::Range; use std::path::Path; use std::str::FromStr as _; use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant}; +use std::{env, mem}; use thiserror::Error; use util::rel_path::RelPathBuf; -use util::{LogErrorFuture, TryFutureExt}; +use util::{LogErrorFuture, ResultExt as _, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; pub mod assemble_excerpts; mod prediction; mod provider; pub mod retrieval_search; +mod sweep_ai; pub mod udiff; mod xml_edits; @@ -55,8 +56,15 @@ pub use crate::prediction::EditPredictionId; pub use provider::ZetaEditPredictionProvider; /// Maximum number of events to track. -const MAX_EVENT_COUNT: usize = 16; +const EVENT_COUNT_MAX_SWEEP: usize = 6; +const EVENT_COUNT_MAX_ZETA: usize = 16; +const CHANGE_GROUPING_LINE_SPAN: u32 = 8; +pub struct SweepFeatureFlag; + +impl FeatureFlag for SweepFeatureFlag { + const NAME: &str = "sweep-ai"; +} pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { max_bytes: 512, min_bytes: 128, @@ -143,6 +151,15 @@ pub struct Zeta { debug_tx: Option>, #[cfg(feature = "eval-support")] eval_cache: Option>, + edit_prediction_model: ZetaEditPredictionModel, + sweep_api_token: Option, + sweep_ai_debug_info: Arc, +} + +#[derive(PartialEq, Eq)] +pub enum ZetaEditPredictionModel { + ZedCloud, + Sweep, } #[derive(Debug, Clone, PartialEq)] @@ -219,12 +236,14 @@ pub type RequestDebugInfo = predict_edits_v3::DebugInfo; struct ZetaProject { syntax_index: Option>, events: VecDeque, + recent_paths: VecDeque, registered_buffers: HashMap, current_prediction: Option, context: Option, Vec>>>, refresh_context_task: Option>>>, refresh_context_debounce_task: Option>>, refresh_context_timestamp: Option, + _subscription: gpui::Subscription, } #[derive(Debug, Clone)] @@ -287,6 +306,7 @@ pub enum Event { BufferChange { old_snapshot: BufferSnapshot, new_snapshot: BufferSnapshot, + end_edit_anchor: Option, timestamp: Instant, }, } @@ -381,7 +401,19 @@ impl Zeta { debug_tx: None, #[cfg(feature = "eval-support")] eval_cache: None, + edit_prediction_model: ZetaEditPredictionModel::ZedCloud, + sweep_api_token: None, + sweep_ai_debug_info: sweep_ai::debug_info(cx), + } + } + + pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) { + if model == ZetaEditPredictionModel::Sweep { + self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN") + .context("No SWEEP_AI_TOKEN environment variable set") + .log_err(); } + self.edit_prediction_model = model; } #[cfg(feature = "eval-support")] @@ -443,7 +475,7 @@ impl Zeta { self.user_store.read(cx).edit_prediction_usage() } - pub fn register_project(&mut self, project: &Entity, cx: &mut App) { + pub fn register_project(&mut self, project: &Entity, cx: &mut Context) { self.get_or_init_zeta_project(project, cx); } @@ -460,7 +492,7 @@ impl Zeta { fn get_or_init_zeta_project( &mut self, project: &Entity, - cx: &mut App, + cx: &mut Context, ) -> &mut ZetaProject { self.projects .entry(project.entity_id()) @@ -473,12 +505,31 @@ impl Zeta { None }, events: VecDeque::new(), + recent_paths: VecDeque::new(), registered_buffers: HashMap::default(), current_prediction: None, context: None, refresh_context_task: None, refresh_context_debounce_task: None, refresh_context_timestamp: None, + _subscription: cx.subscribe(&project, |this, project, event, cx| { + // TODO [zeta2] init with recent paths + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { + if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event { + let path = project.read(cx).path_for_entry(*active_entry_id, cx); + if let Some(path) = path { + if let Some(ix) = zeta_project + .recent_paths + .iter() + .position(|probe| probe == &path) + { + zeta_project.recent_paths.remove(ix); + } + zeta_project.recent_paths.push_front(path); + } + } + } + }), }) } @@ -525,66 +576,64 @@ impl Zeta { buffer: &Entity, project: &Entity, cx: &mut Context, - ) -> BufferSnapshot { - let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval; - let zeta_project = self.get_or_init_zeta_project(project, cx); - let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); + ) { + let event_count_max = match self.edit_prediction_model { + ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA, + ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP, + }; + + let sweep_ai_project = self.get_or_init_zeta_project(project, cx); + let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx); let new_snapshot = buffer.read(cx).snapshot(); - if new_snapshot.version != registered_buffer.snapshot.version { - let old_snapshot = - std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); - Self::push_event( - zeta_project, - buffer_change_grouping_interval, - Event::BufferChange { - old_snapshot, - new_snapshot: new_snapshot.clone(), - timestamp: Instant::now(), - }, - ); + if new_snapshot.version == registered_buffer.snapshot.version { + return; } - new_snapshot - } - - fn push_event( - zeta_project: &mut ZetaProject, - buffer_change_grouping_interval: Duration, - event: Event, - ) { - let events = &mut zeta_project.events; + let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); + let end_edit_anchor = new_snapshot + .anchored_edits_since::(&old_snapshot.version) + .last() + .map(|(_, range)| range.end); + let events = &mut sweep_ai_project.events; - if buffer_change_grouping_interval > Duration::ZERO - && let Some(Event::BufferChange { - new_snapshot: last_new_snapshot, - timestamp: last_timestamp, - .. - }) = events.back_mut() + if let Some(Event::BufferChange { + new_snapshot: last_new_snapshot, + end_edit_anchor: last_end_edit_anchor, + .. + }) = events.back_mut() { - // Coalesce edits for the same buffer when they happen one after the other. - let Event::BufferChange { - old_snapshot, - new_snapshot, - timestamp, - } = &event; - - if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval - && old_snapshot.remote_id() == last_new_snapshot.remote_id() - && old_snapshot.version == last_new_snapshot.version - { - *last_new_snapshot = new_snapshot.clone(); - *last_timestamp = *timestamp; + let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() + == last_new_snapshot.remote_id() + && old_snapshot.version == last_new_snapshot.version; + + let should_coalesce = is_next_snapshot_of_same_buffer + && end_edit_anchor + .as_ref() + .zip(last_end_edit_anchor.as_ref()) + .is_some_and(|(a, b)| { + let a = a.to_point(&new_snapshot); + let b = b.to_point(&new_snapshot); + a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN + }); + + if should_coalesce { + *last_end_edit_anchor = end_edit_anchor; + *last_new_snapshot = new_snapshot; return; } } - if events.len() >= MAX_EVENT_COUNT { - // These are halved instead of popping to improve prompt caching. - events.drain(..MAX_EVENT_COUNT / 2); + if events.len() >= event_count_max { + events.pop_front(); } - events.push_back(event); + events.push_back(Event::BufferChange { + old_snapshot, + new_snapshot, + end_edit_anchor, + timestamp: Instant::now(), + }); } fn current_prediction_for_buffer( @@ -706,6 +755,203 @@ impl Zeta { active_buffer: &Entity, position: language::Anchor, cx: &mut Context, + ) -> Task>> { + match self.edit_prediction_model { + ZetaEditPredictionModel::ZedCloud => { + self.request_prediction_with_zed_cloud(project, active_buffer, position, cx) + } + ZetaEditPredictionModel::Sweep => { + self.request_prediction_with_sweep(project, active_buffer, position, cx) + } + } + } + + fn request_prediction_with_sweep( + &mut self, + project: &Entity, + active_buffer: &Entity, + position: language::Anchor, + cx: &mut Context, + ) -> Task>> { + let snapshot = active_buffer.read(cx).snapshot(); + let debug_info = self.sweep_ai_debug_info.clone(); + let Some(api_token) = self.sweep_api_token.clone() else { + return Task::ready(Ok(None)); + }; + let full_path: Arc = snapshot + .file() + .map(|file| file.full_path(cx)) + .unwrap_or_else(|| "untitled".into()) + .into(); + + let project_file = project::File::from_dyn(snapshot.file()); + let repo_name = project_file + .map(|file| file.worktree.read(cx).root_name_str()) + .unwrap_or("untitled") + .into(); + let offset = position.to_offset(&snapshot); + + let project_state = self.get_or_init_zeta_project(project, cx); + let events = project_state.events.clone(); + let recent_buffers = project_state.recent_paths.iter().cloned(); + let http_client = cx.http_client(); + + let recent_buffer_snapshots = recent_buffers + .filter_map(|project_path| { + let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; + if active_buffer == &buffer { + None + } else { + Some(buffer.read(cx).snapshot()) + } + }) + .take(3) + .collect::>(); + + let result = cx.background_spawn(async move { + let text = snapshot.text(); + + let mut recent_changes = String::new(); + for event in events { + sweep_ai::write_event(event, &mut recent_changes).unwrap(); + } + + let file_chunks = recent_buffer_snapshots + .into_iter() + .map(|snapshot| { + let end_point = language::Point::new(30, 0).min(snapshot.max_point()); + sweep_ai::FileChunk { + content: snapshot + .text_for_range(language::Point::zero()..end_point) + .collect(), + file_path: snapshot + .file() + .map(|f| f.path().as_unix_str()) + .unwrap_or("untitled") + .to_string(), + start_line: 0, + end_line: end_point.row as usize, + timestamp: snapshot.file().and_then(|file| { + Some( + file.disk_state() + .mtime()? + .to_seconds_and_nanos_for_persistence()? + .0, + ) + }), + } + }) + .collect(); + + let request_body = sweep_ai::AutocompleteRequest { + debug_info, + repo_name, + file_path: full_path.clone(), + file_contents: text.clone(), + original_file_contents: text, + cursor_position: offset, + recent_changes: recent_changes.clone(), + changes_above_cursor: true, + multiple_suggestions: false, + branch: None, + file_chunks, + retrieval_chunks: vec![], + recent_user_actions: vec![], + // TODO + privacy_mode_enabled: false, + }; + + let mut buf: Vec = Vec::new(); + let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); + serde_json::to_writer(writer, &request_body)?; + let body: AsyncBody = buf.into(); + + const SWEEP_API_URL: &str = + "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; + + let request = http_client::Request::builder() + .uri(SWEEP_API_URL) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_token)) + .header("Connection", "keep-alive") + .header("Content-Encoding", "br") + .method(Method::POST) + .body(body)?; + + let mut response = http_client.send(request).await?; + + let mut body: Vec = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?; + + let old_text = snapshot + .text_for_range(response.start_index..response.end_index) + .collect::(); + let edits = language::text_diff(&old_text, &response.completion) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(response.start_index + range.start) + ..snapshot.anchor_before(response.start_index + range.end), + text, + ) + }) + .collect::>(); + + anyhow::Ok((response.autocomplete_id, edits, snapshot)) + }); + + let buffer = active_buffer.clone(); + + cx.spawn(async move |_, cx| { + let (id, edits, old_snapshot) = result.await?; + + if edits.is_empty() { + return anyhow::Ok(None); + } + + let Some((edits, new_snapshot, preview_task)) = + buffer.read_with(cx, |buffer, cx| { + let new_snapshot = buffer.snapshot(); + + let edits: Arc<[(Range, Arc)]> = + edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)? + .into(); + let preview_task = buffer.preview_edits(edits.clone(), cx); + + Some((edits, new_snapshot, preview_task)) + })? + else { + return anyhow::Ok(None); + }; + + let prediction = EditPrediction { + id: EditPredictionId(id.into()), + edits, + snapshot: new_snapshot, + edit_preview: preview_task.await, + buffer, + }; + + anyhow::Ok(Some(prediction)) + }) + } + + fn request_prediction_with_zed_cloud( + &mut self, + project: &Entity, + active_buffer: &Entity, + position: language::Anchor, + cx: &mut Context, ) -> Task>> { let project_state = self.projects.get(&project.entity_id()); @@ -1653,7 +1899,7 @@ impl Zeta { pub fn wait_for_initial_indexing( &mut self, project: &Entity, - cx: &mut App, + cx: &mut Context, ) -> Task> { let zeta_project = self.get_or_init_zeta_project(project, cx); if let Some(syntax_index) = &zeta_project.syntax_index { diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index 35fbcb1c61097156d2f0e172d700ed12d3d3894e..e18cf54787ca98e2be60db4977dd2de18e9c09e2 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -49,7 +49,6 @@ settings.workspace = true shellexpand.workspace = true smol.workspace = true soa-rs = "0.8.1" -sweep_ai.workspace = true terminal_view.workspace = true toml.workspace = true util.workspace = true diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index 09fbbb29dd6cf58910a2b6e6ff7fb4a31fc4a10a..a9d7acaee2287450eac828bd2d770b88a8150940 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -8,16 +8,15 @@ use anyhow::Result; use collections::HashSet; use gpui::{AsyncApp, Entity}; use project::Project; -use sweep_ai::SweepAi; use util::ResultExt as _; use zeta2::{Zeta, udiff::DiffLine}; use crate::{ - EvaluateArguments, PredictionOptions, PredictionProvider, + EvaluateArguments, PredictionOptions, example::{Example, NamedExample}, headless::ZetaCliAppState, paths::print_run_data_dir, - predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta}, + predict::{PredictionDetails, perform_predict, setup_zeta}, }; #[derive(Debug)] @@ -46,46 +45,35 @@ pub async fn run_evaluate( let project = example.setup_project(&app_state, cx).await.unwrap(); let providers = (0..args.repetitions) - .map(|_| { - ( - setup_zeta(&project, &app_state, cx).unwrap(), - if matches!(args.options.provider, PredictionProvider::Sweep) { - Some(setup_sweep(&project, cx).unwrap()) - } else { - None - }, - ) - }) + .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap()) .collect::>(); let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); - let tasks = - providers - .into_iter() - .enumerate() - .map(move |(repetition_ix, (zeta, sweep))| { - let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); - let example = example.clone(); - let project = project.clone(); - let options = options.clone(); - - cx.spawn(async move |cx| { - let name = example.name.clone(); - run_evaluate_one( - example, - repetition_ix, - project, - zeta, - sweep, - options, - !args.skip_prediction, - cx, - ) - .await - .map_err(|err| (err, name, repetition_ix)) - }) - }); + let tasks = providers + .into_iter() + .enumerate() + .map(move |(repetition_ix, zeta)| { + let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); + let example = example.clone(); + let project = project.clone(); + let options = options.clone(); + + cx.spawn(async move |cx| { + let name = example.name.clone(); + run_evaluate_one( + example, + repetition_ix, + project, + zeta, + options, + !args.skip_prediction, + cx, + ) + .await + .map_err(|err| (err, name, repetition_ix)) + }) + }); futures::future::join_all(tasks).await }) }); @@ -177,7 +165,6 @@ pub async fn run_evaluate_one( repetition_ix: Option, project: Entity, zeta: Entity, - sweep: Option>, prediction_options: PredictionOptions, predict: bool, cx: &mut AsyncApp, @@ -186,7 +173,6 @@ pub async fn run_evaluate_one( example.clone(), project, zeta, - sweep, repetition_ix, prediction_options, cx, diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 803e02b10cfb7533341a3009e0325a7bcf13df1e..53f231599b7d0449b1f2a9cdef8227a7c3e6bbd5 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -191,7 +191,7 @@ pub struct EvaluateArguments { skip_prediction: bool, } -#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)] +#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)] enum PredictionProvider { #[default] Zeta2, diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index 4505035eaf992751e85216a314b731a12ffbd342..c792b318cec6de42e518793ed5400df0010ae5ea 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -21,7 +21,6 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; use std::time::{Duration, Instant}; -use sweep_ai::SweepAi; use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta}; pub async fn run_predict( @@ -31,14 +30,9 @@ pub async fn run_predict( ) { let example = NamedExample::load(args.example_path).unwrap(); let project = example.setup_project(app_state, cx).await.unwrap(); - let zeta = setup_zeta(&project, app_state, cx).unwrap(); - let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) { - Some(setup_sweep(&project, cx).unwrap()) - } else { - None - }; + let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap(); let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); - let result = perform_predict(example, project, zeta, sweep, None, args.options, cx) + let result = perform_predict(example, project, zeta, None, args.options, cx) .await .unwrap(); result.write(args.format, std::io::stdout()).unwrap(); @@ -47,6 +41,7 @@ pub async fn run_predict( } pub fn setup_zeta( + provider: PredictionProvider, project: &Entity, app_state: &Arc, cx: &mut AsyncApp, @@ -54,6 +49,14 @@ pub fn setup_zeta( let zeta = cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?; + zeta.update(cx, |zeta, _cx| { + let model = match provider { + PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud, + PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep, + }; + zeta.set_edit_prediction_model(model); + })?; + let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; cx.subscribe(&buffer_store, { @@ -71,31 +74,10 @@ pub fn setup_zeta( anyhow::Ok(zeta) } -pub fn setup_sweep(project: &Entity, cx: &mut AsyncApp) -> Result> { - let sweep = cx.new(|cx| SweepAi::new(cx))?; - - let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; - - cx.subscribe(&buffer_store, { - let project = project.clone(); - let sweep = sweep.clone(); - move |_, event, cx| match event { - BufferStoreEvent::BufferAdded(buffer) => { - sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx)); - } - _ => {} - } - })? - .detach(); - - anyhow::Ok(sweep) -} - pub async fn perform_predict( example: NamedExample, project: Entity, zeta: Entity, - sweep: Option>, repetition_ix: Option, options: PredictionOptions, cx: &mut AsyncApp, @@ -147,194 +129,152 @@ pub async fn perform_predict( zeta.set_options(options); })?; - let prediction = match options.provider { - crate::PredictionProvider::Zeta2 => { - let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; - - let debug_task = cx.background_spawn({ - let result = result.clone(); - async move { - let mut start_time = None; - let mut search_queries_generated_at = None; - let mut search_queries_executed_at = None; - while let Some(event) = debug_rx.next().await { - match event { - zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { - start_time = Some(info.timestamp); - fs::write( - example_run_dir.join("search_prompt.md"), - &info.search_prompt, - )?; - } - zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { - search_queries_generated_at = Some(info.timestamp); - fs::write( - example_run_dir.join("search_queries.json"), - serde_json::to_string_pretty(&info.search_queries).unwrap(), - )?; - } - zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { - search_queries_executed_at = Some(info.timestamp); - } - zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} - zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { - let prediction_started_at = Instant::now(); - start_time.get_or_insert(prediction_started_at); - let prompt = request.local_prompt.unwrap_or_default(); - fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?; - - { - let mut result = result.lock().unwrap(); - result.prompt_len = prompt.chars().count(); - - for included_file in request.request.included_files { - let insertions = - vec![(request.request.cursor_point, CURSOR_MARKER)]; - result.excerpts.extend(included_file.excerpts.iter().map( - |excerpt| { - ActualExcerpt { - path: included_file - .path - .components() - .skip(1) - .collect(), - text: String::from(excerpt.text.as_ref()), - } - }, - )); - write_codeblock( - &included_file.path, - included_file.excerpts.iter(), - if included_file.path == request.request.excerpt_path { - &insertions - } else { - &[] - }, - included_file.max_row, - false, - &mut result.excerpts_text, - ); - } - } - - let response = - request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = - zeta2::text_from_response(response).unwrap_or_default(); - let prediction_finished_at = Instant::now(); - fs::write( - example_run_dir.join("prediction_response.md"), - &response, - )?; - + let mut debug_task = gpui::Task::ready(Ok(())); + + if options.provider == crate::PredictionProvider::Zeta2 { + let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; + + debug_task = cx.background_spawn({ + let result = result.clone(); + async move { + let mut start_time = None; + let mut search_queries_generated_at = None; + let mut search_queries_executed_at = None; + while let Some(event) = debug_rx.next().await { + match event { + zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + start_time = Some(info.timestamp); + fs::write( + example_run_dir.join("search_prompt.md"), + &info.search_prompt, + )?; + } + zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + search_queries_generated_at = Some(info.timestamp); + fs::write( + example_run_dir.join("search_queries.json"), + serde_json::to_string_pretty(&info.search_queries).unwrap(), + )?; + } + zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { + search_queries_executed_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} + zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { + let prediction_started_at = Instant::now(); + start_time.get_or_insert(prediction_started_at); + let prompt = request.local_prompt.unwrap_or_default(); + fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?; + + { let mut result = result.lock().unwrap(); - result.generated_len = response.chars().count(); - - if !options.use_expected_context { - result.planning_search_time = Some( - search_queries_generated_at.unwrap() - start_time.unwrap(), - ); - result.running_search_time = Some( - search_queries_executed_at.unwrap() - - search_queries_generated_at.unwrap(), + result.prompt_len = prompt.chars().count(); + + for included_file in request.request.included_files { + let insertions = + vec![(request.request.cursor_point, CURSOR_MARKER)]; + result.excerpts.extend(included_file.excerpts.iter().map( + |excerpt| ActualExcerpt { + path: included_file.path.components().skip(1).collect(), + text: String::from(excerpt.text.as_ref()), + }, + )); + write_codeblock( + &included_file.path, + included_file.excerpts.iter(), + if included_file.path == request.request.excerpt_path { + &insertions + } else { + &[] + }, + included_file.max_row, + false, + &mut result.excerpts_text, ); } - result.prediction_time = - prediction_finished_at - prediction_started_at; - result.total_time = prediction_finished_at - start_time.unwrap(); + } - break; + let response = + request.response_rx.await?.0.map_err(|err| anyhow!(err))?; + let response = zeta2::text_from_response(response).unwrap_or_default(); + let prediction_finished_at = Instant::now(); + fs::write(example_run_dir.join("prediction_response.md"), &response)?; + + let mut result = result.lock().unwrap(); + result.generated_len = response.chars().count(); + + if !options.use_expected_context { + result.planning_search_time = Some( + search_queries_generated_at.unwrap() - start_time.unwrap(), + ); + result.running_search_time = Some( + search_queries_executed_at.unwrap() + - search_queries_generated_at.unwrap(), + ); } + result.prediction_time = prediction_finished_at - prediction_started_at; + result.total_time = prediction_finished_at - start_time.unwrap(); + + break; } } - anyhow::Ok(()) } - }); - - if options.use_expected_context { - let context_excerpts_tasks = example - .example - .expected_context - .iter() - .flat_map(|section| { - section.alternatives[0].excerpts.iter().map(|excerpt| { - resolve_context_entry(project.clone(), excerpt.clone(), cx.clone()) - }) + anyhow::Ok(()) + } + }); + + if options.use_expected_context { + let context_excerpts_tasks = example + .example + .expected_context + .iter() + .flat_map(|section| { + section.alternatives[0].excerpts.iter().map(|excerpt| { + resolve_context_entry(project.clone(), excerpt.clone(), cx.clone()) }) - .collect::>(); - let context_excerpts_vec = - futures::future::try_join_all(context_excerpts_tasks).await?; - - let mut context_excerpts = HashMap::default(); - for (buffer, mut excerpts) in context_excerpts_vec { - context_excerpts - .entry(buffer) - .or_insert(Vec::new()) - .append(&mut excerpts); - } - - zeta.update(cx, |zeta, _cx| { - zeta.set_context(project.clone(), context_excerpts) - })?; - } else { - zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) - })? - .await?; + }) + .collect::>(); + let context_excerpts_vec = + futures::future::try_join_all(context_excerpts_tasks).await?; + + let mut context_excerpts = HashMap::default(); + for (buffer, mut excerpts) in context_excerpts_vec { + context_excerpts + .entry(buffer) + .or_insert(Vec::new()) + .append(&mut excerpts); } - let prediction = zeta - .update(cx, |zeta, cx| { - zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) - })? - .await? - .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits)); - - debug_task.await?; - - prediction + zeta.update(cx, |zeta, _cx| { + zeta.set_context(project.clone(), context_excerpts) + })?; + } else { + zeta.update(cx, |zeta, cx| { + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + })? + .await?; } - crate::PredictionProvider::Sweep => sweep - .unwrap() - .update(cx, |sweep, cx| { - let mut recent_paths = Vec::new(); - for path in zeta - .read(cx) - .history_for_project(&project) - .rev() - .filter_map(|event| event.project_path(cx)) - { - if !recent_paths.contains(&path) { - recent_paths.push(path); - } - } + } - sweep.request_completion( - &project, - recent_paths.into_iter(), - &cursor_buffer, - cursor_anchor, - cx, - ) - })? - .await? - .map( - |sweep_ai::EditPrediction { - edits, snapshot, .. - }| { (cursor_buffer.clone(), snapshot, edits) }, - ), - }; + let prediction = zeta + .update(cx, |zeta, cx| { + zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) + })? + .await?; + + debug_task.await?; let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap(); result.diff = prediction - .map(|(buffer, snapshot, edits)| { - let old_text = snapshot.text(); - let new_text = buffer + .map(|prediction| { + let old_text = prediction.snapshot.text(); + let new_text = prediction + .buffer .update(cx, |buffer, cx| { let branch = buffer.branch(cx); branch.update(cx, |branch, cx| { - branch.edit(edits.iter().cloned(), None, cx); + branch.edit(prediction.edits.iter().cloned(), None, cx); branch.text() }) })