Detailed changes
@@ -77,7 +77,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
) -> bool {
let zeta = self.zeta.read(cx);
if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
- zeta.sweep_api_token.is_some()
+ zeta.sweep_ai.api_token.is_some()
} else {
true
}
@@ -1,10 +1,269 @@
-use std::fmt;
-use std::{path::Path, sync::Arc};
-
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::Event;
+use futures::AsyncReadExt as _;
+use gpui::{
+ App, AppContext as _, Entity, Task,
+ http_client::{self, AsyncBody, Method},
+};
+use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
+use lsp::DiagnosticSeverity;
+use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
+use std::{
+ collections::VecDeque,
+ fmt::{self, Write as _},
+ ops::Range,
+ path::Path,
+ sync::Arc,
+ time::Instant,
+};
+use util::ResultExt as _;
+
+use crate::{EditPrediction, EditPredictionId, EditPredictionInputs};
+
+const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+
+pub struct SweepAi {
+ pub api_token: Option<String>,
+ pub debug_info: Arc<str>,
+}
+
+impl SweepAi {
+ pub fn new(cx: &App) -> Self {
+ SweepAi {
+ api_token: std::env::var("SWEEP_AI_TOKEN")
+ .context("No SWEEP_AI_TOKEN environment variable set")
+ .log_err(),
+ debug_info: debug_info(cx),
+ }
+ }
+
+ pub fn request_prediction_with_sweep(
+ &self,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ snapshot: BufferSnapshot,
+ position: language::Anchor,
+ events: Vec<Arc<Event>>,
+ recent_paths: &VecDeque<ProjectPath>,
+ diagnostic_search_range: Range<Point>,
+ cx: &mut App,
+ ) -> Task<Result<Option<EditPrediction>>> {
+ let debug_info = self.debug_info.clone();
+ let Some(api_token) = self.api_token.clone() else {
+ return Task::ready(Ok(None));
+ };
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|file| file.full_path(cx))
+ .unwrap_or_else(|| "untitled".into())
+ .into();
+
+ let project_file = project::File::from_dyn(snapshot.file());
+ let repo_name = project_file
+ .map(|file| file.worktree.read(cx).root_name_str())
+ .unwrap_or("untitled")
+ .into();
+ let offset = position.to_offset(&snapshot);
+
+ let recent_buffers = recent_paths.iter().cloned();
+ let http_client = cx.http_client();
+
+ let recent_buffer_snapshots = recent_buffers
+ .filter_map(|project_path| {
+ let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+ if active_buffer == &buffer {
+ None
+ } else {
+ Some(buffer.read(cx).snapshot())
+ }
+ })
+ .take(3)
+ .collect::<Vec<_>>();
+
+ let cursor_point = position.to_point(&snapshot);
+ let buffer_snapshotted_at = Instant::now();
+
+ let result = cx.background_spawn(async move {
+ let text = snapshot.text();
+
+ let mut recent_changes = String::new();
+ for event in &events {
+ write_event(event.as_ref(), &mut recent_changes).unwrap();
+ }
+
+ let mut file_chunks = recent_buffer_snapshots
+ .into_iter()
+ .map(|snapshot| {
+ let end_point = Point::new(30, 0).min(snapshot.max_point());
+ FileChunk {
+ content: snapshot.text_for_range(Point::zero()..end_point).collect(),
+ file_path: snapshot
+ .file()
+ .map(|f| f.path().as_unix_str())
+ .unwrap_or("untitled")
+ .to_string(),
+ start_line: 0,
+ end_line: end_point.row as usize,
+ timestamp: snapshot.file().and_then(|file| {
+ Some(
+ file.disk_state()
+ .mtime()?
+ .to_seconds_and_nanos_for_persistence()?
+ .0,
+ )
+ }),
+ }
+ })
+ .collect::<Vec<_>>();
+
+ let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
+ let mut diagnostic_content = String::new();
+ let mut diagnostic_count = 0;
+
+ for entry in diagnostic_entries {
+ let start_point: Point = entry.range.start;
+
+ let severity = match entry.diagnostic.severity {
+ DiagnosticSeverity::ERROR => "error",
+ DiagnosticSeverity::WARNING => "warning",
+ DiagnosticSeverity::INFORMATION => "info",
+ DiagnosticSeverity::HINT => "hint",
+ _ => continue,
+ };
+
+ diagnostic_count += 1;
+
+ writeln!(
+ &mut diagnostic_content,
+ "{} at line {}: {}",
+ severity,
+ start_point.row + 1,
+ entry.diagnostic.message
+ )?;
+ }
+
+ if !diagnostic_content.is_empty() {
+ file_chunks.push(FileChunk {
+ file_path: format!("Diagnostics for {}", full_path.display()),
+ start_line: 0,
+ end_line: diagnostic_count,
+ content: diagnostic_content,
+ timestamp: None,
+ });
+ }
+
+ let request_body = AutocompleteRequest {
+ debug_info,
+ repo_name,
+ file_path: full_path.clone(),
+ file_contents: text.clone(),
+ original_file_contents: text,
+ cursor_position: offset,
+ recent_changes: recent_changes.clone(),
+ changes_above_cursor: true,
+ multiple_suggestions: false,
+ branch: None,
+ file_chunks,
+ retrieval_chunks: vec![],
+ recent_user_actions: vec![],
+ // TODO
+ privacy_mode_enabled: false,
+ };
+
+ let mut buf: Vec<u8> = Vec::new();
+ let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+ serde_json::to_writer(writer, &request_body)?;
+ let body: AsyncBody = buf.into();
+
+ let inputs = EditPredictionInputs {
+ events,
+ included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+ path: full_path.clone(),
+ max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
+ excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
+ start_line: cloud_llm_client::predict_edits_v3::Line(0),
+ text: request_body.file_contents.into(),
+ }],
+ }],
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ column: cursor_point.column,
+ line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
+ },
+ cursor_path: full_path.clone(),
+ };
+
+ let request = http_client::Request::builder()
+ .uri(SWEEP_API_URL)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_token))
+ .header("Connection", "keep-alive")
+ .header("Content-Encoding", "br")
+ .method(Method::POST)
+ .body(body)?;
+
+ let mut response = http_client.send(request).await?;
+
+ let mut body: Vec<u8> = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+
+ let response_received_at = Instant::now();
+ if !response.status().is_success() {
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ String::from_utf8_lossy(&body),
+ );
+ };
+
+ let response: AutocompleteResponse = serde_json::from_slice(&body)?;
+
+ let old_text = snapshot
+ .text_for_range(response.start_index..response.end_index)
+ .collect::<String>();
+ let edits = language::text_diff(&old_text, &response.completion)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(response.start_index + range.start)
+ ..snapshot.anchor_before(response.start_index + range.end),
+ text,
+ )
+ })
+ .collect::<Vec<_>>();
+
+ anyhow::Ok((
+ response.autocomplete_id,
+ edits,
+ snapshot,
+ response_received_at,
+ inputs,
+ ))
+ });
+
+ let buffer = active_buffer.clone();
+
+ cx.spawn(async move |cx| {
+ let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
+ anyhow::Ok(
+ EditPrediction::new(
+ EditPredictionId(id.into()),
+ &buffer,
+ &old_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ response_received_at,
+ inputs,
+ cx,
+ )
+ .await,
+ )
+ })
+ }
+}
#[derive(Debug, Clone, Serialize)]
-pub struct AutocompleteRequest {
+struct AutocompleteRequest {
pub debug_info: Arc<str>,
pub repo_name: String,
pub branch: Option<String>,
@@ -22,7 +281,7 @@ pub struct AutocompleteRequest {
}
#[derive(Debug, Clone, Serialize)]
-pub struct FileChunk {
+struct FileChunk {
pub file_path: String,
pub start_line: usize,
pub end_line: usize,
@@ -31,7 +290,7 @@ pub struct FileChunk {
}
#[derive(Debug, Clone, Serialize)]
-pub struct RetrievalChunk {
+struct RetrievalChunk {
pub file_path: String,
pub start_line: usize,
pub end_line: usize,
@@ -40,7 +299,7 @@ pub struct RetrievalChunk {
}
#[derive(Debug, Clone, Serialize)]
-pub struct UserAction {
+struct UserAction {
pub action_type: ActionType,
pub line_number: usize,
pub offset: usize,
@@ -51,7 +310,7 @@ pub struct UserAction {
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum ActionType {
+enum ActionType {
CursorMovement,
InsertChar,
DeleteChar,
@@ -60,7 +319,7 @@ pub enum ActionType {
}
#[derive(Debug, Clone, Deserialize)]
-pub struct AutocompleteResponse {
+struct AutocompleteResponse {
pub autocomplete_id: String,
pub start_index: usize,
pub end_index: usize,
@@ -80,7 +339,7 @@ pub struct AutocompleteResponse {
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
-pub struct AdditionalCompletion {
+struct AdditionalCompletion {
pub start_index: usize,
pub end_index: usize,
pub completion: String,
@@ -90,7 +349,7 @@ pub struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
-pub(crate) fn write_event(
+fn write_event(
event: &cloud_llm_client::predict_edits_v3::Event,
f: &mut impl fmt::Write,
) -> fmt::Result {
@@ -115,7 +374,7 @@ pub(crate) fn write_event(
}
}
-pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
+fn debug_info(cx: &gpui::App) -> Arc<str> {
format!(
"Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
version = release_channel::AppVersion::global(cx),
@@ -30,7 +30,6 @@ use language::{
};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use lsp::DiagnosticSeverity;
use open_ai::FunctionDefinition;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
@@ -42,7 +41,6 @@ use std::collections::{VecDeque, hash_map};
use telemetry_events::EditPredictionRating;
use workspace::Workspace;
-use std::fmt::Write as _;
use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
@@ -80,6 +78,7 @@ use crate::rate_prediction_modal::{
NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
ThumbsUpActivePrediction,
};
+use crate::sweep_ai::SweepAi;
use crate::zeta1::request_prediction_with_zeta1;
pub use provider::ZetaEditPredictionProvider;
@@ -171,7 +170,7 @@ impl FeatureFlag for Zeta2FeatureFlag {
const NAME: &'static str = "zeta2";
fn enabled_for_staff() -> bool {
- false
+ true
}
}
@@ -192,8 +191,7 @@ pub struct Zeta {
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: ZetaEditPredictionModel,
- sweep_api_token: Option<String>,
- sweep_ai_debug_info: Arc<str>,
+ sweep_ai: SweepAi,
data_collection_choice: DataCollectionChoice,
rejected_predictions: Vec<EditPredictionRejection>,
reject_predictions_tx: mpsc::UnboundedSender<()>,
@@ -202,7 +200,7 @@ pub struct Zeta {
rated_predictions: HashSet<EditPredictionId>,
}
-#[derive(Default, PartialEq, Eq)]
+#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
#[default]
Zeta1,
@@ -499,11 +497,8 @@ impl Zeta {
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: ZetaEditPredictionModel::Zeta2,
- sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
- .context("No SWEEP_AI_TOKEN environment variable set")
- .log_err(),
+ sweep_ai: SweepAi::new(cx),
data_collection_choice,
- sweep_ai_debug_info: sweep_ai::debug_info(cx),
rejected_predictions: Vec::new(),
reject_predictions_debounce_task: None,
reject_predictions_tx: reject_tx,
@@ -517,7 +512,7 @@ impl Zeta {
}
pub fn has_sweep_api_token(&self) -> bool {
- self.sweep_api_token.is_some()
+ self.sweep_ai.api_token.is_some()
}
#[cfg(feature = "eval-support")]
@@ -643,7 +638,9 @@ impl Zeta {
}
}
project::Event::DiagnosticsUpdated { .. } => {
- self.refresh_prediction_from_diagnostics(project, cx);
+ if cx.has_flag::<Zeta2FeatureFlag>() {
+ self.refresh_prediction_from_diagnostics(project, cx);
+ }
}
_ => (),
}
@@ -1183,249 +1180,77 @@ impl Zeta {
position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
- match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 => {
- request_prediction_with_zeta1(self, project, active_buffer, position, cx)
- }
- ZetaEditPredictionModel::Zeta2 => {
- self.request_prediction_with_zeta2(project, active_buffer, position, cx)
- }
- ZetaEditPredictionModel::Sweep => {
- self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
- }
- }
+ self.request_prediction_internal(
+ project.clone(),
+ active_buffer.clone(),
+ position,
+ cx.has_flag::<Zeta2FeatureFlag>(),
+ cx,
+ )
}
- fn request_prediction_with_sweep(
+ fn request_prediction_internal(
&mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
+ project: Entity<Project>,
+ active_buffer: Entity<Buffer>,
position: language::Anchor,
allow_jump: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
- let snapshot = active_buffer.read(cx).snapshot();
- let debug_info = self.sweep_ai_debug_info.clone();
- let Some(api_token) = self.sweep_api_token.clone() else {
- return Task::ready(Ok(None));
- };
- let full_path: Arc<Path> = snapshot
- .file()
- .map(|file| file.full_path(cx))
- .unwrap_or_else(|| "untitled".into())
- .into();
-
- let project_file = project::File::from_dyn(snapshot.file());
- let repo_name = project_file
- .map(|file| file.worktree.read(cx).root_name_str())
- .unwrap_or("untitled")
- .into();
- let offset = position.to_offset(&snapshot);
+ const DIAGNOSTIC_LINES_RANGE: u32 = 20;
- let project_state = self.get_or_init_zeta_project(project, cx);
- let events = project_state.events(cx);
+ self.get_or_init_zeta_project(&project, cx);
+ let zeta_project = self.projects.get(&project.entity_id()).unwrap();
+ let events = zeta_project.events(cx);
let has_events = !events.is_empty();
- let recent_buffers = project_state.recent_paths.iter().cloned();
- let http_client = cx.http_client();
-
- let recent_buffer_snapshots = recent_buffers
- .filter_map(|project_path| {
- let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
- if active_buffer == &buffer {
- None
- } else {
- Some(buffer.read(cx).snapshot())
- }
- })
- .take(3)
- .collect::<Vec<_>>();
-
- const DIAGNOSTIC_LINES_RANGE: u32 = 20;
+ let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
let diagnostic_search_range =
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
- let buffer_snapshotted_at = Instant::now();
-
- let result = cx.background_spawn({
- let snapshot = snapshot.clone();
- let diagnostic_search_range = diagnostic_search_range.clone();
- async move {
- let text = snapshot.text();
-
- let mut recent_changes = String::new();
- for event in &events {
- sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap();
- }
-
- let mut file_chunks = recent_buffer_snapshots
- .into_iter()
- .map(|snapshot| {
- let end_point = Point::new(30, 0).min(snapshot.max_point());
- sweep_ai::FileChunk {
- content: snapshot.text_for_range(Point::zero()..end_point).collect(),
- file_path: snapshot
- .file()
- .map(|f| f.path().as_unix_str())
- .unwrap_or("untitled")
- .to_string(),
- start_line: 0,
- end_line: end_point.row as usize,
- timestamp: snapshot.file().and_then(|file| {
- Some(
- file.disk_state()
- .mtime()?
- .to_seconds_and_nanos_for_persistence()?
- .0,
- )
- }),
- }
- })
- .collect::<Vec<_>>();
-
- let diagnostic_entries =
- snapshot.diagnostics_in_range(diagnostic_search_range, false);
- let mut diagnostic_content = String::new();
- let mut diagnostic_count = 0;
-
- for entry in diagnostic_entries {
- let start_point: Point = entry.range.start;
-
- let severity = match entry.diagnostic.severity {
- DiagnosticSeverity::ERROR => "error",
- DiagnosticSeverity::WARNING => "warning",
- DiagnosticSeverity::INFORMATION => "info",
- DiagnosticSeverity::HINT => "hint",
- _ => continue,
- };
-
- diagnostic_count += 1;
-
- writeln!(
- &mut diagnostic_content,
- "{} at line {}: {}",
- severity,
- start_point.row + 1,
- entry.diagnostic.message
- )?;
- }
-
- if !diagnostic_content.is_empty() {
- file_chunks.push(sweep_ai::FileChunk {
- file_path: format!("Diagnostics for {}", full_path.display()),
- start_line: 0,
- end_line: diagnostic_count,
- content: diagnostic_content,
- timestamp: None,
- });
- }
-
- let request_body = sweep_ai::AutocompleteRequest {
- debug_info,
- repo_name,
- file_path: full_path.clone(),
- file_contents: text.clone(),
- original_file_contents: text,
- cursor_position: offset,
- recent_changes: recent_changes.clone(),
- changes_above_cursor: true,
- multiple_suggestions: false,
- branch: None,
- file_chunks,
- retrieval_chunks: vec![],
- recent_user_actions: vec![],
- // TODO
- privacy_mode_enabled: false,
- };
- let mut buf: Vec<u8> = Vec::new();
- let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
- serde_json::to_writer(writer, &request_body)?;
- let body: AsyncBody = buf.into();
-
- let inputs = EditPredictionInputs {
- events,
- included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(0),
- text: request_body.file_contents.into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
- cursor_path: full_path.clone(),
- };
-
- const SWEEP_API_URL: &str =
- "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
-
- let request = http_client::Request::builder()
- .uri(SWEEP_API_URL)
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_token))
- .header("Connection", "keep-alive")
- .header("Content-Encoding", "br")
- .method(Method::POST)
- .body(body)?;
-
- let mut response = http_client.send(request).await?;
-
- let mut body: Vec<u8> = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
-
- let response_received_at = Instant::now();
- if !response.status().is_success() {
- anyhow::bail!(
- "Request failed with status: {:?}\nBody: {}",
- response.status(),
- String::from_utf8_lossy(&body),
- );
- };
-
- let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
-
- let old_text = snapshot
- .text_for_range(response.start_index..response.end_index)
- .collect::<String>();
- let edits = language::text_diff(&old_text, &response.completion)
- .into_iter()
- .map(|(range, text)| {
- (
- snapshot.anchor_after(response.start_index + range.start)
- ..snapshot.anchor_before(response.start_index + range.end),
- text,
- )
- })
- .collect::<Vec<_>>();
-
- anyhow::Ok((
- response.autocomplete_id,
- edits,
- snapshot,
- response_received_at,
- inputs,
- ))
- }
- });
-
- let buffer = active_buffer.clone();
- let project = project.clone();
- let active_buffer = active_buffer.clone();
+ let task = match self.edit_prediction_model {
+ ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
+ self,
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ cx,
+ ),
+ ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ cx,
+ ),
+ ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ &zeta_project.recent_paths,
+ diagnostic_search_range.clone(),
+ cx,
+ ),
+ };
cx.spawn(async move |this, cx| {
- let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
+ let prediction = task
+ .await?
+ .filter(|prediction| !prediction.edits.is_empty());
- if edits.is_empty() {
+ if prediction.is_none() && allow_jump {
+ let cursor_point = position.to_point(&snapshot);
if has_events
- && allow_jump
&& let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer,
+ active_buffer.clone(),
&snapshot,
diagnostic_search_range,
cursor_point,
@@ -1436,9 +1261,9 @@ impl Zeta {
{
return this
.update(cx, |this, cx| {
- this.request_prediction_with_sweep(
- &project,
- &jump_buffer,
+ this.request_prediction_internal(
+ project,
+ jump_buffer,
jump_position,
false,
cx,
@@ -1450,19 +1275,7 @@ impl Zeta {
return anyhow::Ok(None);
}
- anyhow::Ok(
- EditPrediction::new(
- EditPredictionId(id.into()),
- &buffer,
- &old_snapshot,
- edits.into(),
- buffer_snapshotted_at,
- response_received_at,
- inputs,
- cx,
- )
- .await,
- )
+ Ok(prediction)
})
}
@@ -1549,7 +1362,9 @@ impl Zeta {
&mut self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
+ active_snapshot: BufferSnapshot,
position: language::Anchor,
+ events: Vec<Arc<Event>>,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let project_state = self.projects.get(&project.entity_id());
@@ -1561,7 +1376,6 @@ impl Zeta {
.map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
});
let options = self.options.clone();
- let active_snapshot = active_buffer.read(cx).snapshot();
let buffer_snapshotted_at = Instant::now();
let Some(excerpt_path) = active_snapshot
.file()
@@ -1579,10 +1393,6 @@ impl Zeta {
.collect::<Vec<_>>();
let debug_tx = self.debug_tx.clone();
- let events = project_state
- .map(|state| state.events(cx))
- .unwrap_or_default();
-
let diagnostics = active_snapshot.diagnostic_sets().clone();
let file = active_buffer.read(cx).file();
@@ -32,19 +32,17 @@ pub(crate) fn request_prediction_with_zeta1(
zeta: &mut Zeta,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
+ snapshot: BufferSnapshot,
position: language::Anchor,
+ events: Vec<Arc<Event>>,
cx: &mut Context<Zeta>,
) -> Task<Result<Option<EditPrediction>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
- let snapshot = buffer.read(cx).snapshot();
let client = zeta.client.clone();
let llm_token = zeta.llm_token.clone();
let app_version = AppVersion::global(cx);
- let zeta_project = zeta.get_or_init_zeta_project(project, cx);
- let events = Arc::new(zeta_project.events(cx));
-
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
let can_collect_file = zeta.can_collect_file(project, file, cx);
let git_info = if can_collect_file {
@@ -42,43 +42,48 @@ actions!(
pub fn init(cx: &mut App) {
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
- workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
- let project = workspace.project();
- workspace.split_item(
- SplitDirection::Right,
- Box::new(cx.new(|cx| {
- Zeta2Inspector::new(
- &project,
- workspace.client(),
- workspace.user_store(),
- window,
- cx,
- )
- })),
- window,
- cx,
- );
- });
- })
- .detach();
-
- cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
- workspace.register_action(move |workspace, _: &OpenZeta2ContextView, window, cx| {
- let project = workspace.project();
- workspace.split_item(
- SplitDirection::Right,
- Box::new(cx.new(|cx| {
- Zeta2ContextView::new(
- project.clone(),
- workspace.client(),
- workspace.user_store(),
- window,
- cx,
- )
- })),
- window,
- cx,
- );
+ workspace.register_action_renderer(|div, _, _, cx| {
+ let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
+ div.when(has_flag, |div| {
+ div.on_action(
+ cx.listener(move |workspace, _: &OpenZeta2Inspector, window, cx| {
+ let project = workspace.project();
+ workspace.split_item(
+ SplitDirection::Right,
+ Box::new(cx.new(|cx| {
+ Zeta2Inspector::new(
+ &project,
+ workspace.client(),
+ workspace.user_store(),
+ window,
+ cx,
+ )
+ })),
+ window,
+ cx,
+ )
+ }),
+ )
+ .on_action(cx.listener(
+ move |workspace, _: &OpenZeta2ContextView, window, cx| {
+ let project = workspace.project();
+ workspace.split_item(
+ SplitDirection::Right,
+ Box::new(cx.new(|cx| {
+ Zeta2ContextView::new(
+ project.clone(),
+ workspace.client(),
+ workspace.user_store(),
+ window,
+ cx,
+ )
+ })),
+ window,
+ cx,
+ );
+ },
+ ))
+ })
});
})
.detach();