Detailed changes
@@ -5314,13 +5314,13 @@ dependencies = [
"serde_json",
"settings",
"supermaven",
- "sweep_ai",
"telemetry",
"theme",
"ui",
"workspace",
"zed_actions",
"zeta",
+ "zeta2",
]
[[package]]
@@ -16590,33 +16590,6 @@ dependencies = [
"zeno",
]
-[[package]]
-name = "sweep_ai"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "arrayvec",
- "brotli",
- "client",
- "collections",
- "edit_prediction",
- "feature_flags",
- "futures 0.3.31",
- "gpui",
- "http_client",
- "indoc",
- "language",
- "project",
- "release_channel",
- "reqwest_client",
- "serde",
- "serde_json",
- "tree-sitter-rust",
- "util",
- "workspace",
- "zlog",
-]
-
[[package]]
name = "symphonia"
version = "0.5.5"
@@ -21343,7 +21316,6 @@ dependencies = [
"snippets_ui",
"supermaven",
"svg_preview",
- "sweep_ai",
"sysinfo 0.37.2",
"system_specs",
"tab_switcher",
@@ -21754,6 +21726,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
+ "brotli",
"chrono",
"client",
"clock",
@@ -21864,7 +21837,6 @@ dependencies = [
"shellexpand 2.1.2",
"smol",
"soa-rs",
- "sweep_ai",
"terminal_view",
"toml 0.8.23",
"util",
@@ -165,7 +165,6 @@ members = [
"crates/sum_tree",
"crates/supermaven",
"crates/supermaven_api",
- "crates/sweep_ai",
"crates/codestral",
"crates/svg_preview",
"crates/system_specs",
@@ -399,7 +398,6 @@ streaming_diff = { path = "crates/streaming_diff" }
sum_tree = { path = "crates/sum_tree" }
supermaven = { path = "crates/supermaven" }
supermaven_api = { path = "crates/supermaven_api" }
-sweep_ai = { path = "crates/sweep_ai" }
codestral = { path = "crates/codestral" }
system_specs = { path = "crates/system_specs" }
tab_switcher = { path = "crates/tab_switcher" }
@@ -30,12 +30,12 @@ project.workspace = true
regex.workspace = true
settings.workspace = true
supermaven.workspace = true
-sweep_ai.workspace = true
telemetry.workspace = true
ui.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta.workspace = true
+zeta2.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -28,7 +28,6 @@ use std::{
time::Duration,
};
use supermaven::{AccountStatus, Supermaven};
-use sweep_ai::SweepFeatureFlag;
use ui::{
Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton,
IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*,
@@ -39,6 +38,7 @@ use workspace::{
};
use zed_actions::OpenBrowser;
use zeta::RateCompletions;
+use zeta2::SweepFeatureFlag;
actions!(
edit_prediction,
@@ -1,43 +0,0 @@
-[package]
-name = "sweep_ai"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-exclude = ["fixtures"]
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/sweep_ai.rs"
-doctest = false
-
-[dependencies]
-anyhow.workspace = true
-arrayvec.workspace = true
-brotli.workspace = true
-client.workspace = true
-collections.workspace = true
-edit_prediction.workspace = true
-feature_flags.workspace = true
-futures.workspace = true
-gpui.workspace = true
-http_client.workspace = true
-language.workspace = true
-project.workspace = true
-release_channel.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-util.workspace = true
-workspace.workspace = true
-
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
-http_client = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
-language = { workspace = true, features = ["test-support"] }
-reqwest_client = { workspace = true, features = ["test-support"] }
-tree-sitter-rust.workspace = true
-workspace = { workspace = true, features = ["test-support"] }
-zlog.workspace = true
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,784 +0,0 @@
-mod api;
-
-use anyhow::{Context as _, Result};
-use arrayvec::ArrayVec;
-use client::telemetry;
-use collections::HashMap;
-use feature_flags::FeatureFlag;
-use futures::AsyncReadExt as _;
-use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
-use http_client::{AsyncBody, Method};
-use language::{
- Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff,
-};
-use project::{Project, ProjectPath};
-use release_channel::{AppCommitSha, AppVersion};
-use std::collections::{VecDeque, hash_map};
-use std::fmt::{self, Display};
-use std::mem;
-use std::{
- cmp,
- fmt::Write,
- ops::Range,
- path::Path,
- sync::Arc,
- time::{Duration, Instant},
-};
-use util::ResultExt;
-use util::rel_path::RelPath;
-use workspace::Workspace;
-
-use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
-
-const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
-const MAX_EVENT_COUNT: usize = 6;
-
-const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
-
-pub struct SweepFeatureFlag;
-
-impl FeatureFlag for SweepFeatureFlag {
- const NAME: &str = "sweep-ai";
-}
-
-#[derive(Clone)]
-struct SweepAiGlobal(Entity<SweepAi>);
-
-impl Global for SweepAiGlobal {}
-
-#[derive(Clone)]
-pub struct EditPrediction {
- pub id: EditPredictionId,
- pub path: Arc<Path>,
- pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
- pub snapshot: BufferSnapshot,
- pub edit_preview: EditPreview,
-}
-
-impl EditPrediction {
- fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
- }
-}
-
-impl fmt::Debug for EditPrediction {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("EditPrediction")
- .field("path", &self.path)
- .field("edits", &self.edits)
- .finish_non_exhaustive()
- }
-}
-
-#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct EditPredictionId(String);
-
-impl Display for EditPredictionId {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-pub struct SweepAi {
- projects: HashMap<EntityId, SweepAiProject>,
- debug_info: Arc<str>,
- api_token: Option<String>,
-}
-
-struct SweepAiProject {
- events: VecDeque<Event>,
- registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
-}
-
-impl SweepAi {
- pub fn global(cx: &mut App) -> Option<Entity<Self>> {
- cx.try_global::<SweepAiGlobal>()
- .map(|global| global.0.clone())
- }
-
- pub fn register(cx: &mut App) -> Entity<Self> {
- Self::global(cx).unwrap_or_else(|| {
- let entity = cx.new(|cx| Self::new(cx));
- cx.set_global(SweepAiGlobal(entity.clone()));
- entity
- })
- }
-
- pub fn clear_history(&mut self) {
- for sweep_ai_project in self.projects.values_mut() {
- sweep_ai_project.events.clear();
- }
- }
-
- pub fn new(cx: &mut Context<Self>) -> Self {
- Self {
- api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
- projects: HashMap::default(),
- debug_info: format!(
- "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
- version = AppVersion::global(cx),
- sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
- os = telemetry::os_name(),
- )
- .into(),
- }
- }
-
- fn get_or_init_sweep_ai_project(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &mut SweepAiProject {
- let project_id = project.entity_id();
- match self.projects.entry(project_id) {
- hash_map::Entry::Occupied(entry) => entry.into_mut(),
- hash_map::Entry::Vacant(entry) => {
- cx.observe_release(project, move |this, _, _cx| {
- this.projects.remove(&project_id);
- })
- .detach();
- entry.insert(SweepAiProject {
- events: VecDeque::with_capacity(MAX_EVENT_COUNT),
- registered_buffers: HashMap::default(),
- })
- }
- }
- }
-
- pub fn register_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
- Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
- }
-
- fn register_buffer_impl<'a>(
- sweep_ai_project: &'a mut SweepAiProject,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &'a mut RegisteredBuffer {
- let buffer_id = buffer.entity_id();
- match sweep_ai_project.registered_buffers.entry(buffer_id) {
- hash_map::Entry::Occupied(entry) => entry.into_mut(),
- hash_map::Entry::Vacant(entry) => {
- let snapshot = buffer.read(cx).snapshot();
- let project_entity_id = project.entity_id();
- entry.insert(RegisteredBuffer {
- snapshot,
- _subscriptions: [
- cx.subscribe(buffer, {
- let project = project.downgrade();
- move |this, buffer, event, cx| {
- if let language::BufferEvent::Edited = event
- && let Some(project) = project.upgrade()
- {
- this.report_changes_for_buffer(&buffer, &project, cx);
- }
- }
- }),
- cx.observe_release(buffer, move |this, _buffer, _cx| {
- let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
- else {
- return;
- };
- sweep_ai_project.registered_buffers.remove(&buffer_id);
- }),
- ],
- })
- }
- }
- }
-
- pub fn request_completion(
- &mut self,
- project: &Entity<Project>,
- recent_buffers: impl Iterator<Item = ProjectPath>,
- active_buffer: &Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
- let snapshot = active_buffer.read(cx).snapshot();
- let debug_info = self.debug_info.clone();
- let Some(api_token) = self.api_token.clone() else {
- return Task::ready(Ok(None));
- };
- let full_path: Arc<Path> = snapshot
- .file()
- .map(|file| file.full_path(cx))
- .unwrap_or_else(|| "untitled".into())
- .into();
-
- let project_file = project::File::from_dyn(snapshot.file());
- let repo_name = project_file
- .map(|file| file.worktree.read(cx).root_name_str())
- .unwrap_or("untitled")
- .into();
- let offset = position.to_offset(&snapshot);
-
- let project_state = self.get_or_init_sweep_ai_project(project, cx);
- let events = project_state.events.clone();
- let http_client = cx.http_client();
-
- let recent_buffer_snapshots = recent_buffers
- .filter_map(|project_path| {
- let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
- if active_buffer == &buffer {
- None
- } else {
- Some(buffer.read(cx).snapshot())
- }
- })
- .take(3)
- .collect::<Vec<_>>();
-
- let result = cx.background_spawn({
- let full_path = full_path.clone();
- async move {
- let text = snapshot.text();
-
- let mut recent_changes = String::new();
-
- for event in events {
- writeln!(&mut recent_changes, "{event}")?;
- }
-
- let file_chunks = recent_buffer_snapshots
- .into_iter()
- .map(|snapshot| {
- let end_point = language::Point::new(30, 0).min(snapshot.max_point());
- FileChunk {
- content: snapshot
- .text_for_range(language::Point::zero()..end_point)
- .collect(),
- file_path: snapshot
- .file()
- .map(|f| f.path().as_unix_str())
- .unwrap_or("untitled")
- .to_string(),
- start_line: 0,
- end_line: end_point.row as usize,
- timestamp: snapshot.file().and_then(|file| {
- Some(
- file.disk_state()
- .mtime()?
- .to_seconds_and_nanos_for_persistence()?
- .0,
- )
- }),
- }
- })
- .collect();
-
- eprintln!("{recent_changes}");
-
- let request_body = AutocompleteRequest {
- debug_info,
- repo_name,
- file_path: full_path.clone(),
- file_contents: text.clone(),
- original_file_contents: text,
- cursor_position: offset,
- recent_changes: recent_changes.clone(),
- changes_above_cursor: true,
- multiple_suggestions: false,
- branch: None,
- file_chunks,
- retrieval_chunks: vec![],
- recent_user_actions: vec![],
- // TODO
- privacy_mode_enabled: false,
- };
-
- let mut buf: Vec<u8> = Vec::new();
- let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
- serde_json::to_writer(writer, &request_body)?;
- let body: AsyncBody = buf.into();
-
- let request = http_client::Request::builder()
- .uri(SWEEP_API_URL)
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_token))
- .header("Connection", "keep-alive")
- .header("Content-Encoding", "br")
- .method(Method::POST)
- .body(body)?;
-
- let mut response = http_client.send(request).await?;
-
- let mut body: Vec<u8> = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
-
- if !response.status().is_success() {
- anyhow::bail!(
- "Request failed with status: {:?}\nBody: {}",
- response.status(),
- String::from_utf8_lossy(&body),
- );
- };
-
- let response: AutocompleteResponse = serde_json::from_slice(&body)?;
-
- let old_text = snapshot
- .text_for_range(response.start_index..response.end_index)
- .collect::<String>();
- let edits = text_diff(&old_text, &response.completion)
- .into_iter()
- .map(|(range, text)| {
- (
- snapshot.anchor_after(response.start_index + range.start)
- ..snapshot.anchor_before(response.start_index + range.end),
- text,
- )
- })
- .collect::<Vec<_>>();
-
- anyhow::Ok((response.autocomplete_id, edits, snapshot))
- }
- });
-
- let buffer = active_buffer.clone();
-
- cx.spawn(async move |_, cx| {
- let (id, edits, old_snapshot) = result.await?;
-
- if edits.is_empty() {
- return anyhow::Ok(None);
- }
-
- let Some((edits, new_snapshot, preview_task)) =
- buffer.read_with(cx, |buffer, cx| {
- let new_snapshot = buffer.snapshot();
-
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
- edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
- .into();
- let preview_task = buffer.preview_edits(edits.clone(), cx);
-
- Some((edits, new_snapshot, preview_task))
- })?
- else {
- return anyhow::Ok(None);
- };
-
- let prediction = EditPrediction {
- id: EditPredictionId(id),
- path: full_path,
- edits,
- snapshot: new_snapshot,
- edit_preview: preview_task.await,
- };
-
- anyhow::Ok(Some(prediction))
- })
- }
-
- fn report_changes_for_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
- let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
-
- let new_snapshot = buffer.read(cx).snapshot();
- if new_snapshot.version == registered_buffer.snapshot.version {
- return;
- }
-
- let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- let end_edit_anchor = new_snapshot
- .anchored_edits_since::<Point>(&old_snapshot.version)
- .last()
- .map(|(_, range)| range.end);
- let events = &mut sweep_ai_project.events;
-
- if let Some(Event::BufferChange {
- new_snapshot: last_new_snapshot,
- end_edit_anchor: last_end_edit_anchor,
- ..
- }) = events.back_mut()
- {
- let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
- == last_new_snapshot.remote_id()
- && old_snapshot.version == last_new_snapshot.version;
-
- let should_coalesce = is_next_snapshot_of_same_buffer
- && end_edit_anchor
- .as_ref()
- .zip(last_end_edit_anchor.as_ref())
- .is_some_and(|(a, b)| {
- let a = a.to_point(&new_snapshot);
- let b = b.to_point(&new_snapshot);
- a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
- });
-
- if should_coalesce {
- *last_end_edit_anchor = end_edit_anchor;
- *last_new_snapshot = new_snapshot;
- return;
- }
- }
-
- if events.len() >= MAX_EVENT_COUNT {
- events.pop_front();
- }
-
- events.push_back(Event::BufferChange {
- old_snapshot,
- new_snapshot,
- end_edit_anchor,
- });
- }
-}
-
-struct RegisteredBuffer {
- snapshot: BufferSnapshot,
- _subscriptions: [gpui::Subscription; 2],
-}
-
-#[derive(Clone)]
-pub enum Event {
- BufferChange {
- old_snapshot: BufferSnapshot,
- new_snapshot: BufferSnapshot,
- end_edit_anchor: Option<Anchor>,
- },
-}
-
-impl Display for Event {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- Event::BufferChange {
- old_snapshot,
- new_snapshot,
- ..
- } => {
- let old_path = old_snapshot
- .file()
- .map(|f| f.path().as_ref())
- .unwrap_or(RelPath::unix("untitled").unwrap());
- let new_path = new_snapshot
- .file()
- .map(|f| f.path().as_ref())
- .unwrap_or(RelPath::unix("untitled").unwrap());
- if old_path != new_path {
- // TODO confirm how to do this for sweep
- // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
- }
-
- let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
- if !diff.is_empty() {
- write!(
- f,
- "File: {}:\n{}\n",
- new_path.display(util::paths::PathStyle::Posix),
- diff
- )?
- }
-
- fmt::Result::Ok(())
- }
- }
- }
-}
-
-#[derive(Debug, Clone)]
-struct CurrentEditPrediction {
- buffer_id: EntityId,
- completion: EditPrediction,
-}
-
-impl CurrentEditPrediction {
- fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
- if self.buffer_id != old_completion.buffer_id {
- return true;
- }
-
- let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
- return true;
- };
- let Some(new_edits) = self.completion.interpolate(snapshot) else {
- return false;
- };
-
- if old_edits.len() == 1 && new_edits.len() == 1 {
- let (old_range, old_text) = &old_edits[0];
- let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text.as_ref())
- } else {
- true
- }
- }
-}
-
-struct PendingCompletion {
- id: usize,
- _task: Task<()>,
-}
-
-pub struct SweepAiEditPredictionProvider {
- workspace: WeakEntity<Workspace>,
- sweep_ai: Entity<SweepAi>,
- pending_completions: ArrayVec<PendingCompletion, 2>,
- next_pending_completion_id: usize,
- current_completion: Option<CurrentEditPrediction>,
- last_request_timestamp: Instant,
- project: Entity<Project>,
-}
-
-impl SweepAiEditPredictionProvider {
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
-
- pub fn new(
- sweep_ai: Entity<SweepAi>,
- workspace: WeakEntity<Workspace>,
- project: Entity<Project>,
- ) -> Self {
- Self {
- sweep_ai,
- pending_completions: ArrayVec::new(),
- next_pending_completion_id: 0,
- current_completion: None,
- last_request_timestamp: Instant::now(),
- project,
- workspace,
- }
- }
-}
-
-impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
- fn name() -> &'static str {
- "zed-predict"
- }
-
- fn display_name() -> &'static str {
- "Zed's Edit Predictions"
- }
-
- fn show_completions_in_menu() -> bool {
- true
- }
-
- fn show_tab_accept_marker() -> bool {
- true
- }
-
- fn is_enabled(
- &self,
- _buffer: &Entity<Buffer>,
- _cursor_position: language::Anchor,
- cx: &App,
- ) -> bool {
- self.sweep_ai.read(cx).api_token.is_some()
- }
-
- fn is_refreshing(&self) -> bool {
- !self.pending_completions.is_empty()
- }
-
- fn refresh(
- &mut self,
- buffer: Entity<Buffer>,
- position: language::Anchor,
- _debounce: bool,
- cx: &mut Context<Self>,
- ) {
- if let Some(current_completion) = self.current_completion.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if current_completion
- .completion
- .interpolate(&snapshot)
- .is_some()
- {
- return;
- }
- }
-
- let pending_completion_id = self.next_pending_completion_id;
- self.next_pending_completion_id += 1;
- let last_request_timestamp = self.last_request_timestamp;
-
- let project = self.project.clone();
- let workspace = self.workspace.clone();
- let task = cx.spawn(async move |this, cx| {
- if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
- .checked_duration_since(Instant::now())
- {
- cx.background_executor().timer(timeout).await;
- }
-
- let completion_request = this.update(cx, |this, cx| {
- this.last_request_timestamp = Instant::now();
-
- this.sweep_ai.update(cx, |sweep_ai, cx| {
- let Some(recent_buffers) = workspace
- .read_with(cx, |workspace, cx| {
- workspace.recent_navigation_history_iter(cx)
- })
- .log_err()
- else {
- return Task::ready(Ok(None));
- };
- sweep_ai.request_completion(
- &project,
- recent_buffers.map(move |(project_path, _)| project_path),
- &buffer,
- position,
- cx,
- )
- })
- });
-
- let completion = match completion_request {
- Ok(completion_request) => {
- let completion_request = completion_request.await;
- completion_request.map(|c| {
- c.map(|completion| CurrentEditPrediction {
- buffer_id: buffer.entity_id(),
- completion,
- })
- })
- }
- Err(error) => Err(error),
- };
-
- let Some(new_completion) = completion
- .context("edit prediction failed")
- .log_err()
- .flatten()
- else {
- this.update(cx, |this, cx| {
- if this.pending_completions[0].id == pending_completion_id {
- this.pending_completions.remove(0);
- } else {
- this.pending_completions.clear();
- }
-
- cx.notify();
- })
- .ok();
- return;
- };
-
- this.update(cx, |this, cx| {
- if this.pending_completions[0].id == pending_completion_id {
- this.pending_completions.remove(0);
- } else {
- this.pending_completions.clear();
- }
-
- if let Some(old_completion) = this.current_completion.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if new_completion.should_replace_completion(old_completion, &snapshot) {
- this.current_completion = Some(new_completion);
- }
- } else {
- this.current_completion = Some(new_completion);
- }
-
- cx.notify();
- })
- .ok();
- });
-
- // We always maintain at most two pending completions. When we already
- // have two, we replace the newest one.
- if self.pending_completions.len() <= 1 {
- self.pending_completions.push(PendingCompletion {
- id: pending_completion_id,
- _task: task,
- });
- } else if self.pending_completions.len() == 2 {
- self.pending_completions.pop();
- self.pending_completions.push(PendingCompletion {
- id: pending_completion_id,
- _task: task,
- });
- }
- }
-
- fn cycle(
- &mut self,
- _buffer: Entity<Buffer>,
- _cursor_position: language::Anchor,
- _direction: edit_prediction::Direction,
- _cx: &mut Context<Self>,
- ) {
- // Right now we don't support cycling.
- }
-
- fn accept(&mut self, _cx: &mut Context<Self>) {
- self.pending_completions.clear();
- }
-
- fn discard(&mut self, _cx: &mut Context<Self>) {
- self.pending_completions.clear();
- self.current_completion.take();
- }
-
- fn suggest(
- &mut self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
- let CurrentEditPrediction {
- buffer_id,
- completion,
- ..
- } = self.current_completion.as_mut()?;
-
- // Invalidate previous completion if it was generated for a different buffer.
- if *buffer_id != buffer.entity_id() {
- self.current_completion.take();
- return None;
- }
-
- let buffer = buffer.read(cx);
- let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
- self.current_completion.take();
- return None;
- };
-
- let cursor_row = cursor_position.to_point(buffer).row;
- let (closest_edit_ix, (closest_edit_range, _)) =
- edits.iter().enumerate().min_by_key(|(_, (range, _))| {
- let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
- let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
- cmp::min(distance_from_start, distance_from_end)
- })?;
-
- let mut edit_start_ix = closest_edit_ix;
- for (range, _) in edits[..edit_start_ix].iter().rev() {
- let distance_from_closest_edit =
- closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
- if distance_from_closest_edit <= 1 {
- edit_start_ix -= 1;
- } else {
- break;
- }
- }
-
- let mut edit_end_ix = closest_edit_ix + 1;
- for (range, _) in &edits[edit_end_ix..] {
- let distance_from_closest_edit =
- range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
- if distance_from_closest_edit <= 1 {
- edit_end_ix += 1;
- } else {
- break;
- }
- }
-
- Some(edit_prediction::EditPrediction::Local {
- id: Some(completion.id.to_string().into()),
- edits: edits[edit_start_ix..edit_end_ix].to_vec(),
- edit_preview: Some(completion.edit_preview.clone()),
- })
- }
-}
@@ -133,7 +133,6 @@ snippet_provider.workspace = true
snippets_ui.workspace = true
supermaven.workspace = true
svg_preview.workspace = true
-sweep_ai.workspace = true
sysinfo.workspace = true
tab_switcher.workspace = true
task.workspace = true
@@ -10,9 +10,9 @@ use language_models::MistralLanguageModelProvider;
use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenCompletionProvider};
-use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag};
use ui::Window;
use zeta::ZetaEditPredictionProvider;
+use zeta2::SweepFeatureFlag;
use zeta2::Zeta2FeatureFlag;
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
@@ -203,55 +203,41 @@ fn assign_edit_prediction_provider(
let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
- EditPredictionProvider::Experimental(name) => {
- if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
- && cx.has_flag::<SweepFeatureFlag>()
- {
- if let Some(project) = editor.project()
- && let Some(workspace) = editor.workspace()
+ value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
+ if let Some(project) = editor.project() {
+ let mut worktree = None;
+ if let Some(buffer) = &singleton_buffer
+ && let Some(file) = buffer.read(cx).file()
{
- let sweep_ai = sweep_ai::SweepAi::register(cx);
-
- if let Some(buffer) = &singleton_buffer
- && buffer.read(cx).file().is_some()
- {
- sweep_ai.update(cx, |sweep_ai, cx| {
- sweep_ai.register_buffer(buffer, project, cx);
- });
- }
+ let id = file.worktree_id(cx);
+ worktree = project.read(cx).worktree_for_id(id, cx);
+ }
- let provider = cx.new(|_| {
- sweep_ai::SweepAiEditPredictionProvider::new(
- sweep_ai,
- workspace.downgrade(),
+ if let EditPredictionProvider::Experimental(name) = value
+ && name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::<SweepFeatureFlag>()
+ {
+ let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
+ let provider = cx.new(|cx| {
+ zeta2::ZetaEditPredictionProvider::new(
project.clone(),
+ &client,
+ &user_store,
+ cx,
)
});
- editor.set_edit_prediction_provider(Some(provider), window, cx);
- }
- } else {
- editor.set_edit_prediction_provider::<SweepAiEditPredictionProvider>(
- None, window, cx,
- );
- }
- }
- EditPredictionProvider::Zed => {
- if user_store.read(cx).current_user().is_some() {
- let mut worktree = None;
- if let Some(buffer) = &singleton_buffer
- && let Some(file) = buffer.read(cx).file()
- {
- let id = file.worktree_id(cx);
- if let Some(inner_worktree) = editor
- .project()
- .and_then(|project| project.read(cx).worktree_for_id(id, cx))
+ if let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
{
- worktree = Some(inner_worktree);
+ zeta2.update(cx, |zeta, cx| {
+ zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep);
+ zeta.register_buffer(buffer, project, cx);
+ });
}
- }
- if let Some(project) = editor.project() {
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
+ } else if user_store.read(cx).current_user().is_some() {
if cx.has_flag::<Zeta2FeatureFlag>() {
let zeta = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
@@ -268,6 +254,9 @@ fn assign_edit_prediction_provider(
&& buffer.read(cx).file().is_some()
{
zeta.update(cx, |zeta, cx| {
+ zeta.set_edit_prediction_model(
+ zeta2::ZetaEditPredictionModel::ZedCloud,
+ );
zeta.register_buffer(buffer, project, cx);
});
}
@@ -17,6 +17,7 @@ eval-support = []
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
+brotli.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
@@ -12,7 +12,7 @@ use language::ToPoint as _;
use project::Project;
use util::ResultExt as _;
-use crate::{BufferEditPrediction, Zeta};
+use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
@@ -85,9 +85,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
&self,
_buffer: &Entity<language::Buffer>,
_cursor_position: language::Anchor,
- _cx: &App,
+ cx: &App,
) -> bool {
- true
+ let zeta = self.zeta.read(cx);
+ if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
+ zeta.sweep_api_token.is_some()
+ } else {
+ true
+ }
}
fn is_refreshing(&self) -> bool {
@@ -1,6 +1,8 @@
+use std::fmt;
use std::{path::Path, sync::Arc};
use serde::{Deserialize, Serialize};
+use util::rel_path::RelPath;
#[derive(Debug, Clone, Serialize)]
pub struct AutocompleteRequest {
@@ -88,3 +90,49 @@ pub struct AdditionalCompletion {
pub logprobs: Option<serde_json::Value>,
pub finish_reason: Option<String>,
}
+
+pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result {
+ match event {
+ crate::Event::BufferChange {
+ old_snapshot,
+ new_snapshot,
+ ..
+ } => {
+ let old_path = old_snapshot
+ .file()
+ .map(|f| f.path().as_ref())
+ .unwrap_or(RelPath::unix("untitled").unwrap());
+ let new_path = new_snapshot
+ .file()
+ .map(|f| f.path().as_ref())
+ .unwrap_or(RelPath::unix("untitled").unwrap());
+ if old_path != new_path {
+ // TODO confirm how to do this for sweep
+ // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
+ }
+
+ let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
+ if !diff.is_empty() {
+ write!(
+ f,
+ "File: {}:\n{}\n",
+ new_path.display(util::paths::PathStyle::Posix),
+ diff
+ )?
+ }
+
+ fmt::Result::Ok(())
+ }
+ }
+}
+
+pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
+ format!(
+ "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
+ version = release_channel::AppVersion::global(cx),
+ sha = release_channel::AppCommitSha::try_global(cx)
+ .map_or("unknown".to_string(), |sha| sha.full()),
+ os = client::telemetry::os_name(),
+ )
+ .into()
+}
@@ -22,30 +22,31 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
-use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
+use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
-use project::Project;
+use project::{Project, ProjectPath};
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
-use std::env;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
+use std::{env, mem};
use thiserror::Error;
use util::rel_path::RelPathBuf;
-use util::{LogErrorFuture, TryFutureExt};
+use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
pub mod assemble_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
+mod sweep_ai;
pub mod udiff;
mod xml_edits;
@@ -55,8 +56,15 @@ pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;
/// Maximum number of events to track.
-const MAX_EVENT_COUNT: usize = 16;
+const EVENT_COUNT_MAX_SWEEP: usize = 6;
+const EVENT_COUNT_MAX_ZETA: usize = 16;
+const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
+pub struct SweepFeatureFlag;
+
+impl FeatureFlag for SweepFeatureFlag {
+ const NAME: &str = "sweep-ai";
+}
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
max_bytes: 512,
min_bytes: 128,
@@ -143,6 +151,15 @@ pub struct Zeta {
debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
+ edit_prediction_model: ZetaEditPredictionModel,
+ sweep_api_token: Option<String>,
+ sweep_ai_debug_info: Arc<str>,
+}
+
+#[derive(PartialEq, Eq)]
+pub enum ZetaEditPredictionModel {
+ ZedCloud,
+ Sweep,
}
#[derive(Debug, Clone, PartialEq)]
@@ -219,12 +236,14 @@ pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ZetaProject {
syntax_index: Option<Entity<SyntaxIndex>>,
events: VecDeque<Event>,
+ recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
refresh_context_debounce_task: Option<Task<Option<()>>>,
refresh_context_timestamp: Option<Instant>,
+ _subscription: gpui::Subscription,
}
#[derive(Debug, Clone)]
@@ -287,6 +306,7 @@ pub enum Event {
BufferChange {
old_snapshot: BufferSnapshot,
new_snapshot: BufferSnapshot,
+ end_edit_anchor: Option<Anchor>,
timestamp: Instant,
},
}
@@ -381,7 +401,19 @@ impl Zeta {
debug_tx: None,
#[cfg(feature = "eval-support")]
eval_cache: None,
+ edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
+ sweep_api_token: None,
+ sweep_ai_debug_info: sweep_ai::debug_info(cx),
+ }
+ }
+
+ pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
+ if model == ZetaEditPredictionModel::Sweep {
+ self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN")
+ .context("No SWEEP_AI_TOKEN environment variable set")
+ .log_err();
}
+ self.edit_prediction_model = model;
}
#[cfg(feature = "eval-support")]
@@ -443,7 +475,7 @@ impl Zeta {
self.user_store.read(cx).edit_prediction_usage()
}
- pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
+ pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
self.get_or_init_zeta_project(project, cx);
}
@@ -460,7 +492,7 @@ impl Zeta {
fn get_or_init_zeta_project(
&mut self,
project: &Entity<Project>,
- cx: &mut App,
+ cx: &mut Context<Self>,
) -> &mut ZetaProject {
self.projects
.entry(project.entity_id())
@@ -473,12 +505,31 @@ impl Zeta {
None
},
events: VecDeque::new(),
+ recent_paths: VecDeque::new(),
registered_buffers: HashMap::default(),
current_prediction: None,
context: None,
refresh_context_task: None,
refresh_context_debounce_task: None,
refresh_context_timestamp: None,
+ _subscription: cx.subscribe(&project, |this, project, event, cx| {
+ // TODO [zeta2] init with recent paths
+ if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+ if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
+ let path = project.read(cx).path_for_entry(*active_entry_id, cx);
+ if let Some(path) = path {
+ if let Some(ix) = zeta_project
+ .recent_paths
+ .iter()
+ .position(|probe| probe == &path)
+ {
+ zeta_project.recent_paths.remove(ix);
+ }
+ zeta_project.recent_paths.push_front(path);
+ }
+ }
+ }
+ }),
})
}
@@ -525,66 +576,64 @@ impl Zeta {
buffer: &Entity<Buffer>,
project: &Entity<Project>,
cx: &mut Context<Self>,
- ) -> BufferSnapshot {
- let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
+ ) {
+ let event_count_max = match self.edit_prediction_model {
+ ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
+ ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
+ };
+
+ let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
let new_snapshot = buffer.read(cx).snapshot();
- if new_snapshot.version != registered_buffer.snapshot.version {
- let old_snapshot =
- std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- Self::push_event(
- zeta_project,
- buffer_change_grouping_interval,
- Event::BufferChange {
- old_snapshot,
- new_snapshot: new_snapshot.clone(),
- timestamp: Instant::now(),
- },
- );
+ if new_snapshot.version == registered_buffer.snapshot.version {
+ return;
}
- new_snapshot
- }
-
- fn push_event(
- zeta_project: &mut ZetaProject,
- buffer_change_grouping_interval: Duration,
- event: Event,
- ) {
- let events = &mut zeta_project.events;
+ let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+ let end_edit_anchor = new_snapshot
+ .anchored_edits_since::<Point>(&old_snapshot.version)
+ .last()
+ .map(|(_, range)| range.end);
+ let events = &mut sweep_ai_project.events;
- if buffer_change_grouping_interval > Duration::ZERO
- && let Some(Event::BufferChange {
- new_snapshot: last_new_snapshot,
- timestamp: last_timestamp,
- ..
- }) = events.back_mut()
+ if let Some(Event::BufferChange {
+ new_snapshot: last_new_snapshot,
+ end_edit_anchor: last_end_edit_anchor,
+ ..
+ }) = events.back_mut()
{
- // Coalesce edits for the same buffer when they happen one after the other.
- let Event::BufferChange {
- old_snapshot,
- new_snapshot,
- timestamp,
- } = &event;
-
- if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
- && old_snapshot.remote_id() == last_new_snapshot.remote_id()
- && old_snapshot.version == last_new_snapshot.version
- {
- *last_new_snapshot = new_snapshot.clone();
- *last_timestamp = *timestamp;
+ let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
+ == last_new_snapshot.remote_id()
+ && old_snapshot.version == last_new_snapshot.version;
+
+ let should_coalesce = is_next_snapshot_of_same_buffer
+ && end_edit_anchor
+ .as_ref()
+ .zip(last_end_edit_anchor.as_ref())
+ .is_some_and(|(a, b)| {
+ let a = a.to_point(&new_snapshot);
+ let b = b.to_point(&new_snapshot);
+ a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
+ });
+
+ if should_coalesce {
+ *last_end_edit_anchor = end_edit_anchor;
+ *last_new_snapshot = new_snapshot;
return;
}
}
- if events.len() >= MAX_EVENT_COUNT {
- // These are halved instead of popping to improve prompt caching.
- events.drain(..MAX_EVENT_COUNT / 2);
+ if events.len() >= event_count_max {
+ events.pop_front();
}
- events.push_back(event);
+ events.push_back(Event::BufferChange {
+ old_snapshot,
+ new_snapshot,
+ end_edit_anchor,
+ timestamp: Instant::now(),
+ });
}
fn current_prediction_for_buffer(
@@ -706,6 +755,203 @@ impl Zeta {
active_buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
+ ) -> Task<Result<Option<EditPrediction>>> {
+ match self.edit_prediction_model {
+ ZetaEditPredictionModel::ZedCloud => {
+ self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
+ }
+ ZetaEditPredictionModel::Sweep => {
+ self.request_prediction_with_sweep(project, active_buffer, position, cx)
+ }
+ }
+ }
+
+ fn request_prediction_with_sweep(
+ &mut self,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Option<EditPrediction>>> {
+ let snapshot = active_buffer.read(cx).snapshot();
+ let debug_info = self.sweep_ai_debug_info.clone();
+ let Some(api_token) = self.sweep_api_token.clone() else {
+ return Task::ready(Ok(None));
+ };
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|file| file.full_path(cx))
+ .unwrap_or_else(|| "untitled".into())
+ .into();
+
+ let project_file = project::File::from_dyn(snapshot.file());
+ let repo_name = project_file
+ .map(|file| file.worktree.read(cx).root_name_str())
+ .unwrap_or("untitled")
+ .into();
+ let offset = position.to_offset(&snapshot);
+
+ let project_state = self.get_or_init_zeta_project(project, cx);
+ let events = project_state.events.clone();
+ let recent_buffers = project_state.recent_paths.iter().cloned();
+ let http_client = cx.http_client();
+
+ let recent_buffer_snapshots = recent_buffers
+ .filter_map(|project_path| {
+ let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+ if active_buffer == &buffer {
+ None
+ } else {
+ Some(buffer.read(cx).snapshot())
+ }
+ })
+ .take(3)
+ .collect::<Vec<_>>();
+
+ let result = cx.background_spawn(async move {
+ let text = snapshot.text();
+
+ let mut recent_changes = String::new();
+ for event in events {
+ sweep_ai::write_event(event, &mut recent_changes).unwrap();
+ }
+
+ let file_chunks = recent_buffer_snapshots
+ .into_iter()
+ .map(|snapshot| {
+ let end_point = language::Point::new(30, 0).min(snapshot.max_point());
+ sweep_ai::FileChunk {
+ content: snapshot
+ .text_for_range(language::Point::zero()..end_point)
+ .collect(),
+ file_path: snapshot
+ .file()
+ .map(|f| f.path().as_unix_str())
+ .unwrap_or("untitled")
+ .to_string(),
+ start_line: 0,
+ end_line: end_point.row as usize,
+ timestamp: snapshot.file().and_then(|file| {
+ Some(
+ file.disk_state()
+ .mtime()?
+ .to_seconds_and_nanos_for_persistence()?
+ .0,
+ )
+ }),
+ }
+ })
+ .collect();
+
+ let request_body = sweep_ai::AutocompleteRequest {
+ debug_info,
+ repo_name,
+ file_path: full_path.clone(),
+ file_contents: text.clone(),
+ original_file_contents: text,
+ cursor_position: offset,
+ recent_changes: recent_changes.clone(),
+ changes_above_cursor: true,
+ multiple_suggestions: false,
+ branch: None,
+ file_chunks,
+ retrieval_chunks: vec![],
+ recent_user_actions: vec![],
+ // TODO
+ privacy_mode_enabled: false,
+ };
+
+ let mut buf: Vec<u8> = Vec::new();
+ let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+ serde_json::to_writer(writer, &request_body)?;
+ let body: AsyncBody = buf.into();
+
+ const SWEEP_API_URL: &str =
+ "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+
+ let request = http_client::Request::builder()
+ .uri(SWEEP_API_URL)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_token))
+ .header("Connection", "keep-alive")
+ .header("Content-Encoding", "br")
+ .method(Method::POST)
+ .body(body)?;
+
+ let mut response = http_client.send(request).await?;
+
+ let mut body: Vec<u8> = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+
+ if !response.status().is_success() {
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ String::from_utf8_lossy(&body),
+ );
+ };
+
+ let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
+
+ let old_text = snapshot
+ .text_for_range(response.start_index..response.end_index)
+ .collect::<String>();
+ let edits = language::text_diff(&old_text, &response.completion)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(response.start_index + range.start)
+ ..snapshot.anchor_before(response.start_index + range.end),
+ text,
+ )
+ })
+ .collect::<Vec<_>>();
+
+ anyhow::Ok((response.autocomplete_id, edits, snapshot))
+ });
+
+ let buffer = active_buffer.clone();
+
+ cx.spawn(async move |_, cx| {
+ let (id, edits, old_snapshot) = result.await?;
+
+ if edits.is_empty() {
+ return anyhow::Ok(None);
+ }
+
+ let Some((edits, new_snapshot, preview_task)) =
+ buffer.read_with(cx, |buffer, cx| {
+ let new_snapshot = buffer.snapshot();
+
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
+ edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
+ .into();
+ let preview_task = buffer.preview_edits(edits.clone(), cx);
+
+ Some((edits, new_snapshot, preview_task))
+ })?
+ else {
+ return anyhow::Ok(None);
+ };
+
+ let prediction = EditPrediction {
+ id: EditPredictionId(id.into()),
+ edits,
+ snapshot: new_snapshot,
+ edit_preview: preview_task.await,
+ buffer,
+ };
+
+ anyhow::Ok(Some(prediction))
+ })
+ }
+
+ fn request_prediction_with_zed_cloud(
+ &mut self,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let project_state = self.projects.get(&project.entity_id());
@@ -1653,7 +1899,7 @@ impl Zeta {
pub fn wait_for_initial_indexing(
&mut self,
project: &Entity<Project>,
- cx: &mut App,
+ cx: &mut Context<Self>,
) -> Task<Result<()>> {
let zeta_project = self.get_or_init_zeta_project(project, cx);
if let Some(syntax_index) = &zeta_project.syntax_index {
@@ -49,7 +49,6 @@ settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
soa-rs = "0.8.1"
-sweep_ai.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
@@ -8,16 +8,15 @@ use anyhow::Result;
use collections::HashSet;
use gpui::{AsyncApp, Entity};
use project::Project;
-use sweep_ai::SweepAi;
use util::ResultExt as _;
use zeta2::{Zeta, udiff::DiffLine};
use crate::{
- EvaluateArguments, PredictionOptions, PredictionProvider,
+ EvaluateArguments, PredictionOptions,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
- predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta},
+ predict::{PredictionDetails, perform_predict, setup_zeta},
};
#[derive(Debug)]
@@ -46,46 +45,35 @@ pub async fn run_evaluate(
let project = example.setup_project(&app_state, cx).await.unwrap();
let providers = (0..args.repetitions)
- .map(|_| {
- (
- setup_zeta(&project, &app_state, cx).unwrap(),
- if matches!(args.options.provider, PredictionProvider::Sweep) {
- Some(setup_sweep(&project, cx).unwrap())
- } else {
- None
- },
- )
- })
+ .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
.collect::<Vec<_>>();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
- let tasks =
- providers
- .into_iter()
- .enumerate()
- .map(move |(repetition_ix, (zeta, sweep))| {
- let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
- let example = example.clone();
- let project = project.clone();
- let options = options.clone();
-
- cx.spawn(async move |cx| {
- let name = example.name.clone();
- run_evaluate_one(
- example,
- repetition_ix,
- project,
- zeta,
- sweep,
- options,
- !args.skip_prediction,
- cx,
- )
- .await
- .map_err(|err| (err, name, repetition_ix))
- })
- });
+ let tasks = providers
+ .into_iter()
+ .enumerate()
+ .map(move |(repetition_ix, zeta)| {
+ let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
+ let example = example.clone();
+ let project = project.clone();
+ let options = options.clone();
+
+ cx.spawn(async move |cx| {
+ let name = example.name.clone();
+ run_evaluate_one(
+ example,
+ repetition_ix,
+ project,
+ zeta,
+ options,
+ !args.skip_prediction,
+ cx,
+ )
+ .await
+ .map_err(|err| (err, name, repetition_ix))
+ })
+ });
futures::future::join_all(tasks).await
})
});
@@ -177,7 +165,6 @@ pub async fn run_evaluate_one(
repetition_ix: Option<u16>,
project: Entity<Project>,
zeta: Entity<Zeta>,
- sweep: Option<Entity<SweepAi>>,
prediction_options: PredictionOptions,
predict: bool,
cx: &mut AsyncApp,
@@ -186,7 +173,6 @@ pub async fn run_evaluate_one(
example.clone(),
project,
zeta,
- sweep,
repetition_ix,
prediction_options,
cx,
@@ -191,7 +191,7 @@ pub struct EvaluateArguments {
skip_prediction: bool,
}
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
#[default]
Zeta2,
@@ -21,7 +21,6 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
-use sweep_ai::SweepAi;
use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
pub async fn run_predict(
@@ -31,14 +30,9 @@ pub async fn run_predict(
) {
let example = NamedExample::load(args.example_path).unwrap();
let project = example.setup_project(app_state, cx).await.unwrap();
- let zeta = setup_zeta(&project, app_state, cx).unwrap();
- let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
- Some(setup_sweep(&project, cx).unwrap())
- } else {
- None
- };
+ let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
- let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
+ let result = perform_predict(example, project, zeta, None, args.options, cx)
.await
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
@@ -47,6 +41,7 @@ pub async fn run_predict(
}
pub fn setup_zeta(
+ provider: PredictionProvider,
project: &Entity<Project>,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
@@ -54,6 +49,14 @@ pub fn setup_zeta(
let zeta =
cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
+ zeta.update(cx, |zeta, _cx| {
+ let model = match provider {
+ PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
+ PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
+ };
+ zeta.set_edit_prediction_model(model);
+ })?;
+
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
@@ -71,31 +74,10 @@ pub fn setup_zeta(
anyhow::Ok(zeta)
}
-pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
- let sweep = cx.new(|cx| SweepAi::new(cx))?;
-
- let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
-
- cx.subscribe(&buffer_store, {
- let project = project.clone();
- let sweep = sweep.clone();
- move |_, event, cx| match event {
- BufferStoreEvent::BufferAdded(buffer) => {
- sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
- }
- _ => {}
- }
- })?
- .detach();
-
- anyhow::Ok(sweep)
-}
-
pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
zeta: Entity<Zeta>,
- sweep: Option<Entity<SweepAi>>,
repetition_ix: Option<u16>,
options: PredictionOptions,
cx: &mut AsyncApp,
@@ -147,194 +129,152 @@ pub async fn perform_predict(
zeta.set_options(options);
})?;
- let prediction = match options.provider {
- crate::PredictionProvider::Zeta2 => {
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
-
- let debug_task = cx.background_spawn({
- let result = result.clone();
- async move {
- let mut start_time = None;
- let mut search_queries_generated_at = None;
- let mut search_queries_executed_at = None;
- while let Some(event) = debug_rx.next().await {
- match event {
- zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
- start_time = Some(info.timestamp);
- fs::write(
- example_run_dir.join("search_prompt.md"),
- &info.search_prompt,
- )?;
- }
- zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
- search_queries_generated_at = Some(info.timestamp);
- fs::write(
- example_run_dir.join("search_queries.json"),
- serde_json::to_string_pretty(&info.search_queries).unwrap(),
- )?;
- }
- zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
- search_queries_executed_at = Some(info.timestamp);
- }
- zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
- zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
- let prediction_started_at = Instant::now();
- start_time.get_or_insert(prediction_started_at);
- let prompt = request.local_prompt.unwrap_or_default();
- fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
-
- {
- let mut result = result.lock().unwrap();
- result.prompt_len = prompt.chars().count();
-
- for included_file in request.request.included_files {
- let insertions =
- vec![(request.request.cursor_point, CURSOR_MARKER)];
- result.excerpts.extend(included_file.excerpts.iter().map(
- |excerpt| {
- ActualExcerpt {
- path: included_file
- .path
- .components()
- .skip(1)
- .collect(),
- text: String::from(excerpt.text.as_ref()),
- }
- },
- ));
- write_codeblock(
- &included_file.path,
- included_file.excerpts.iter(),
- if included_file.path == request.request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut result.excerpts_text,
- );
- }
- }
-
- let response =
- request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response =
- zeta2::text_from_response(response).unwrap_or_default();
- let prediction_finished_at = Instant::now();
- fs::write(
- example_run_dir.join("prediction_response.md"),
- &response,
- )?;
-
+ let mut debug_task = gpui::Task::ready(Ok(()));
+
+ if options.provider == crate::PredictionProvider::Zeta2 {
+ let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+
+ debug_task = cx.background_spawn({
+ let result = result.clone();
+ async move {
+ let mut start_time = None;
+ let mut search_queries_generated_at = None;
+ let mut search_queries_executed_at = None;
+ while let Some(event) = debug_rx.next().await {
+ match event {
+ zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ start_time = Some(info.timestamp);
+ fs::write(
+ example_run_dir.join("search_prompt.md"),
+ &info.search_prompt,
+ )?;
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+ search_queries_generated_at = Some(info.timestamp);
+ fs::write(
+ example_run_dir.join("search_queries.json"),
+ serde_json::to_string_pretty(&info.search_queries).unwrap(),
+ )?;
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+ search_queries_executed_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
+ zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+ let prediction_started_at = Instant::now();
+ start_time.get_or_insert(prediction_started_at);
+ let prompt = request.local_prompt.unwrap_or_default();
+ fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
+
+ {
let mut result = result.lock().unwrap();
- result.generated_len = response.chars().count();
-
- if !options.use_expected_context {
- result.planning_search_time = Some(
- search_queries_generated_at.unwrap() - start_time.unwrap(),
- );
- result.running_search_time = Some(
- search_queries_executed_at.unwrap()
- - search_queries_generated_at.unwrap(),
+ result.prompt_len = prompt.chars().count();
+
+ for included_file in request.request.included_files {
+ let insertions =
+ vec![(request.request.cursor_point, CURSOR_MARKER)];
+ result.excerpts.extend(included_file.excerpts.iter().map(
+ |excerpt| ActualExcerpt {
+ path: included_file.path.components().skip(1).collect(),
+ text: String::from(excerpt.text.as_ref()),
+ },
+ ));
+ write_codeblock(
+ &included_file.path,
+ included_file.excerpts.iter(),
+ if included_file.path == request.request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ included_file.max_row,
+ false,
+ &mut result.excerpts_text,
);
}
- result.prediction_time =
- prediction_finished_at - prediction_started_at;
- result.total_time = prediction_finished_at - start_time.unwrap();
+ }
- break;
+ let response =
+ request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+ let response = zeta2::text_from_response(response).unwrap_or_default();
+ let prediction_finished_at = Instant::now();
+ fs::write(example_run_dir.join("prediction_response.md"), &response)?;
+
+ let mut result = result.lock().unwrap();
+ result.generated_len = response.chars().count();
+
+ if !options.use_expected_context {
+ result.planning_search_time = Some(
+ search_queries_generated_at.unwrap() - start_time.unwrap(),
+ );
+ result.running_search_time = Some(
+ search_queries_executed_at.unwrap()
+ - search_queries_generated_at.unwrap(),
+ );
}
+ result.prediction_time = prediction_finished_at - prediction_started_at;
+ result.total_time = prediction_finished_at - start_time.unwrap();
+
+ break;
}
}
- anyhow::Ok(())
}
- });
-
- if options.use_expected_context {
- let context_excerpts_tasks = example
- .example
- .expected_context
- .iter()
- .flat_map(|section| {
- section.alternatives[0].excerpts.iter().map(|excerpt| {
- resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
- })
+ anyhow::Ok(())
+ }
+ });
+
+ if options.use_expected_context {
+ let context_excerpts_tasks = example
+ .example
+ .expected_context
+ .iter()
+ .flat_map(|section| {
+ section.alternatives[0].excerpts.iter().map(|excerpt| {
+ resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
})
- .collect::<Vec<_>>();
- let context_excerpts_vec =
- futures::future::try_join_all(context_excerpts_tasks).await?;
-
- let mut context_excerpts = HashMap::default();
- for (buffer, mut excerpts) in context_excerpts_vec {
- context_excerpts
- .entry(buffer)
- .or_insert(Vec::new())
- .append(&mut excerpts);
- }
-
- zeta.update(cx, |zeta, _cx| {
- zeta.set_context(project.clone(), context_excerpts)
- })?;
- } else {
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
- })?
- .await?;
+ })
+ .collect::<Vec<_>>();
+ let context_excerpts_vec =
+ futures::future::try_join_all(context_excerpts_tasks).await?;
+
+ let mut context_excerpts = HashMap::default();
+ for (buffer, mut excerpts) in context_excerpts_vec {
+ context_excerpts
+ .entry(buffer)
+ .or_insert(Vec::new())
+ .append(&mut excerpts);
}
- let prediction = zeta
- .update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
- })?
- .await?
- .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));
-
- debug_task.await?;
-
- prediction
+ zeta.update(cx, |zeta, _cx| {
+ zeta.set_context(project.clone(), context_excerpts)
+ })?;
+ } else {
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+ })?
+ .await?;
}
- crate::PredictionProvider::Sweep => sweep
- .unwrap()
- .update(cx, |sweep, cx| {
- let mut recent_paths = Vec::new();
- for path in zeta
- .read(cx)
- .history_for_project(&project)
- .rev()
- .filter_map(|event| event.project_path(cx))
- {
- if !recent_paths.contains(&path) {
- recent_paths.push(path);
- }
- }
+ }
- sweep.request_completion(
- &project,
- recent_paths.into_iter(),
- &cursor_buffer,
- cursor_anchor,
- cx,
- )
- })?
- .await?
- .map(
- |sweep_ai::EditPrediction {
- edits, snapshot, ..
- }| { (cursor_buffer.clone(), snapshot, edits) },
- ),
- };
+ let prediction = zeta
+ .update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+ })?
+ .await?;
+
+ debug_task.await?;
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
- .map(|(buffer, snapshot, edits)| {
- let old_text = snapshot.text();
- let new_text = buffer
+ .map(|prediction| {
+ let old_text = prediction.snapshot.text();
+ let new_text = prediction
+ .buffer
.update(cx, |buffer, cx| {
let branch = buffer.branch(cx);
branch.update(cx, |branch, cx| {
- branch.edit(edits.iter().cloned(), None, cx);
+ branch.edit(prediction.edits.iter().cloned(), None, cx);
branch.text()
})
})