1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::{Dismissable, KEY_VALUE_STORE};
11use db::smol::stream::StreamExt as _;
12use edit_prediction::DataCollectionState;
13use futures::channel::mpsc;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16pub use rate_completion_modal::*;
17
18use anyhow::{Context as _, Result, anyhow};
19use arrayvec::ArrayVec;
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
23 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
24 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
25 ZED_VERSION_HEADER_NAME,
26};
27use collections::{HashMap, HashSet, VecDeque};
28use futures::AsyncReadExt;
29use gpui::{
30 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SharedString, Subscription,
31 Task, actions,
32};
33use http_client::{AsyncBody, HttpClient, Method, Request, Response};
34use input_excerpt::excerpt_for_cursor_position;
35use language::{
36 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
37};
38use language_model::{LlmApiToken, RefreshLlmTokenListener};
39use project::{Project, ProjectPath};
40use release_channel::AppVersion;
41use semver::Version;
42use settings::WorktreeId;
43use std::collections::hash_map;
44use std::mem;
45use std::str::FromStr;
46use std::{
47 cmp,
48 fmt::Write,
49 future::Future,
50 ops::Range,
51 path::Path,
52 rc::Rc,
53 sync::Arc,
54 time::{Duration, Instant},
55};
56use telemetry_events::EditPredictionRating;
57use thiserror::Error;
58use util::ResultExt;
59use util::rel_path::RelPath;
60use uuid::Uuid;
61use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
62use worktree::Worktree;
63
/// Marker for the cursor position; stripped from the model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; the model output may contain at most one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in the model output (exactly one is expected).
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region in the model output (exactly one is expected).
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer closer together than this are merged
/// into a single change event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1000);
/// Key-value store key holding the user's data collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets for the prompt sections.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the rendered event history.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
77
// Declares the `ClearHistory` action under the `edit_prediction` action namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
85
86#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
87pub struct EditPredictionId(Uuid);
88
89impl From<EditPredictionId> for gpui::ElementId {
90 fn from(value: EditPredictionId) -> Self {
91 gpui::ElementId::Uuid(value.0)
92 }
93}
94
95impl std::fmt::Display for EditPredictionId {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 write!(f, "{}", self.0)
98 }
99}
100
101struct ZedPredictUpsell;
102
103impl Dismissable for ZedPredictUpsell {
104 const KEY: &'static str = "dismissed-edit-predict-upsell";
105
106 fn dismissed() -> bool {
107 // To make this backwards compatible with older versions of Zed, we
108 // check if the user has seen the previous Edit Prediction Onboarding
109 // before, by checking the data collection choice which was written to
110 // the database once the user clicked on "Accept and Enable"
111 if KEY_VALUE_STORE
112 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
113 .log_err()
114 .is_some_and(|s| s.is_some())
115 {
116 return true;
117 }
118
119 KEY_VALUE_STORE
120 .read_kvp(Self::KEY)
121 .log_err()
122 .is_some_and(|s| s.is_some())
123 }
124}
125
126pub fn should_show_upsell_modal() -> bool {
127 !ZedPredictUpsell::dismissed()
128}
129
/// App-global slot holding the single `Zeta` entity; installed by `Zeta::register`
/// and read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
134
135#[derive(Clone)]
136pub struct EditPrediction {
137 id: EditPredictionId,
138 path: Arc<Path>,
139 excerpt_range: Range<usize>,
140 cursor_offset: usize,
141 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
142 snapshot: BufferSnapshot,
143 edit_preview: EditPreview,
144 input_outline: Arc<str>,
145 input_events: Arc<str>,
146 input_excerpt: Arc<str>,
147 output_excerpt: Arc<str>,
148 buffer_snapshotted_at: Instant,
149 response_received_at: Instant,
150}
151
152impl EditPrediction {
153 fn latency(&self) -> Duration {
154 self.response_received_at
155 .duration_since(self.buffer_snapshotted_at)
156 }
157
158 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
159 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
160 }
161}
162
163impl std::fmt::Debug for EditPrediction {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("EditPrediction")
166 .field("id", &self.id)
167 .field("path", &self.path)
168 .field("edits", &self.edits)
169 .finish_non_exhaustive()
170 }
171}
172
/// Edit prediction engine: tracks buffer edits per project, requests predictions
/// from the server and reports which predictions were shown, accepted or discarded.
pub struct Zeta {
    /// Per-project state keyed by project entity id; an entry is removed when its
    /// project is released.
    projects: HashMap<EntityId, ZetaProject>,
    client: Arc<Client>,
    /// Recently shown predictions, most recent first (capped in `completion_shown`).
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: DataCollectionChoice,
    /// Rejections waiting to be reported to the server.
    discarded_completions: Vec<EditPredictionRejection>,
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the refresh listener emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// License watchers keyed by worktree; consulted before collecting file data.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Pending debounce before flushing `discarded_completions`; replaced on every
    /// discard.
    discard_completions_debounce_task: Option<Task<()>>,
    /// Each message asks the background loop started in `Zeta::new` to flush
    /// `discarded_completions`.
    discard_completions_tx: mpsc::UnboundedSender<()>,
}

/// Per-project tracking state.
struct ZetaProject {
    /// Recent buffer change events, at most `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
194
195impl Zeta {
196 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
197 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
198 }
199
200 pub fn register(
201 worktree: Option<Entity<Worktree>>,
202 client: Arc<Client>,
203 user_store: Entity<UserStore>,
204 cx: &mut App,
205 ) -> Entity<Self> {
206 let this = Self::global(cx).unwrap_or_else(|| {
207 let entity = cx.new(|cx| Self::new(client, user_store, cx));
208 cx.set_global(ZetaGlobal(entity.clone()));
209 entity
210 });
211
212 this.update(cx, move |this, cx| {
213 if let Some(worktree) = worktree {
214 let worktree_id = worktree.read(cx).id();
215 this.license_detection_watchers
216 .entry(worktree_id)
217 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
218 }
219 });
220
221 this
222 }
223
224 pub fn clear_history(&mut self) {
225 for zeta_project in self.projects.values_mut() {
226 zeta_project.events.clear();
227 }
228 }
229
230 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
231 self.user_store.read(cx).edit_prediction_usage()
232 }
233
234 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
235 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
236 let data_collection_choice = Self::load_data_collection_choice();
237 let (reject_tx, mut reject_rx) = mpsc::unbounded();
238 cx.spawn(async move |this, cx| {
239 while let Some(()) = reject_rx.next().await {
240 this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
241 .await
242 .log_err();
243 }
244 anyhow::Ok(())
245 })
246 .detach();
247
248 Self {
249 projects: HashMap::default(),
250 client,
251 shown_completions: VecDeque::new(),
252 rated_completions: HashSet::default(),
253 discarded_completions: Vec::new(),
254 discard_completions_debounce_task: None,
255 discard_completions_tx: reject_tx,
256 data_collection_choice,
257 llm_token: LlmApiToken::default(),
258 _llm_token_subscription: cx.subscribe(
259 &refresh_llm_token_listener,
260 |this, _listener, _event, cx| {
261 let client = this.client.clone();
262 let llm_token = this.llm_token.clone();
263 cx.spawn(async move |_this, _cx| {
264 llm_token.refresh(&client).await?;
265 anyhow::Ok(())
266 })
267 .detach_and_log_err(cx);
268 },
269 ),
270 update_required: false,
271 license_detection_watchers: HashMap::default(),
272 user_store,
273 }
274 }
275
276 fn get_or_init_zeta_project(
277 &mut self,
278 project: &Entity<Project>,
279 cx: &mut Context<Self>,
280 ) -> &mut ZetaProject {
281 let project_id = project.entity_id();
282 match self.projects.entry(project_id) {
283 hash_map::Entry::Occupied(entry) => entry.into_mut(),
284 hash_map::Entry::Vacant(entry) => {
285 cx.observe_release(project, move |this, _, _cx| {
286 this.projects.remove(&project_id);
287 })
288 .detach();
289 entry.insert(ZetaProject {
290 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
291 registered_buffers: HashMap::default(),
292 })
293 }
294 }
295 }
296
297 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
298 let events = &mut zeta_project.events;
299
300 if let Some(Event::BufferChange {
301 new_snapshot: last_new_snapshot,
302 timestamp: last_timestamp,
303 ..
304 }) = events.back_mut()
305 {
306 // Coalesce edits for the same buffer when they happen one after the other.
307 let Event::BufferChange {
308 old_snapshot,
309 new_snapshot,
310 timestamp,
311 } = &event;
312
313 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
314 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
315 && old_snapshot.version == last_new_snapshot.version
316 {
317 *last_new_snapshot = new_snapshot.clone();
318 *last_timestamp = *timestamp;
319 return;
320 }
321 }
322
323 if events.len() >= MAX_EVENT_COUNT {
324 // These are halved instead of popping to improve prompt caching.
325 events.drain(..MAX_EVENT_COUNT / 2);
326 }
327
328 events.push_back(event);
329 }
330
331 pub fn register_buffer(
332 &mut self,
333 buffer: &Entity<Buffer>,
334 project: &Entity<Project>,
335 cx: &mut Context<Self>,
336 ) {
337 let zeta_project = self.get_or_init_zeta_project(project, cx);
338 Self::register_buffer_impl(zeta_project, buffer, project, cx);
339 }
340
341 fn register_buffer_impl<'a>(
342 zeta_project: &'a mut ZetaProject,
343 buffer: &Entity<Buffer>,
344 project: &Entity<Project>,
345 cx: &mut Context<Self>,
346 ) -> &'a mut RegisteredBuffer {
347 let buffer_id = buffer.entity_id();
348 match zeta_project.registered_buffers.entry(buffer_id) {
349 hash_map::Entry::Occupied(entry) => entry.into_mut(),
350 hash_map::Entry::Vacant(entry) => {
351 let snapshot = buffer.read(cx).snapshot();
352 let project_entity_id = project.entity_id();
353 entry.insert(RegisteredBuffer {
354 snapshot,
355 _subscriptions: [
356 cx.subscribe(buffer, {
357 let project = project.downgrade();
358 move |this, buffer, event, cx| {
359 if let language::BufferEvent::Edited = event
360 && let Some(project) = project.upgrade()
361 {
362 this.report_changes_for_buffer(&buffer, &project, cx);
363 }
364 }
365 }),
366 cx.observe_release(buffer, move |this, _buffer, _cx| {
367 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
368 else {
369 return;
370 };
371 zeta_project.registered_buffers.remove(&buffer_id);
372 }),
373 ],
374 })
375 }
376 }
377 }
378
379 fn request_completion_impl<F, R>(
380 &mut self,
381 project: &Entity<Project>,
382 buffer: &Entity<Buffer>,
383 cursor: language::Anchor,
384 cx: &mut Context<Self>,
385 perform_predict_edits: F,
386 ) -> Task<Result<Option<EditPrediction>>>
387 where
388 F: FnOnce(PerformPredictEditsParams) -> R + 'static,
389 R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
390 + Send
391 + 'static,
392 {
393 let buffer = buffer.clone();
394 let buffer_snapshotted_at = Instant::now();
395 let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
396 let zeta = cx.entity();
397 let client = self.client.clone();
398 let llm_token = self.llm_token.clone();
399 let app_version = AppVersion::global(cx);
400
401 let zeta_project = self.get_or_init_zeta_project(project, cx);
402 let mut events = Vec::with_capacity(zeta_project.events.len());
403 events.extend(zeta_project.events.iter().cloned());
404 let events = Arc::new(events);
405
406 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
407 let can_collect_file = self.can_collect_file(file, cx);
408 let git_info = if can_collect_file {
409 git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
410 } else {
411 None
412 };
413 (git_info, can_collect_file)
414 } else {
415 (None, false)
416 };
417
418 let full_path: Arc<Path> = snapshot
419 .file()
420 .map(|f| Arc::from(f.full_path(cx).as_path()))
421 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
422 let full_path_str = full_path.to_string_lossy().into_owned();
423 let cursor_point = cursor.to_point(&snapshot);
424 let cursor_offset = cursor_point.to_offset(&snapshot);
425 let prompt_for_events = {
426 let events = events.clone();
427 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
428 };
429 let gather_task = gather_context(
430 full_path_str,
431 &snapshot,
432 cursor_point,
433 prompt_for_events,
434 cx,
435 );
436
437 cx.spawn(async move |this, cx| {
438 let GatherContextOutput {
439 mut body,
440 editable_range,
441 included_events_count,
442 } = gather_task.await?;
443 let done_gathering_context_at = Instant::now();
444
445 let included_events = &events[events.len() - included_events_count..events.len()];
446 body.can_collect_data = can_collect_file
447 && this
448 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
449 .unwrap_or(false);
450 if body.can_collect_data {
451 body.git_info = git_info;
452 }
453
454 log::debug!(
455 "Events:\n{}\nExcerpt:\n{:?}",
456 body.input_events,
457 body.input_excerpt
458 );
459
460 let input_outline = body.outline.clone().unwrap_or_default();
461 let input_events = body.input_events.clone();
462 let input_excerpt = body.input_excerpt.clone();
463
464 let response = perform_predict_edits(PerformPredictEditsParams {
465 client,
466 llm_token,
467 app_version,
468 body,
469 })
470 .await;
471 let (response, usage) = match response {
472 Ok(response) => response,
473 Err(err) => {
474 if err.is::<ZedUpdateRequiredError>() {
475 cx.update(|cx| {
476 zeta.update(cx, |zeta, _cx| {
477 zeta.update_required = true;
478 });
479
480 let error_message: SharedString = err.to_string().into();
481 show_app_notification(
482 NotificationId::unique::<ZedUpdateRequiredError>(),
483 cx,
484 move |cx| {
485 cx.new(|cx| {
486 ErrorMessagePrompt::new(error_message.clone(), cx)
487 .with_link_button(
488 "Update Zed",
489 "https://zed.dev/releases",
490 )
491 })
492 },
493 );
494 })
495 .ok();
496 }
497
498 return Err(err);
499 }
500 };
501
502 let received_response_at = Instant::now();
503 log::debug!("completion response: {}", &response.output_excerpt);
504
505 if let Some(usage) = usage {
506 this.update(cx, |this, cx| {
507 this.user_store.update(cx, |user_store, cx| {
508 user_store.update_edit_prediction_usage(usage, cx);
509 });
510 })
511 .ok();
512 }
513
514 let edit_prediction = Self::process_completion_response(
515 response,
516 buffer,
517 &snapshot,
518 editable_range,
519 cursor_offset,
520 full_path,
521 input_outline,
522 input_events,
523 input_excerpt,
524 buffer_snapshotted_at,
525 cx,
526 )
527 .await;
528
529 let finished_at = Instant::now();
530
531 // record latency for ~1% of requests
532 if rand::random::<u8>() <= 2 {
533 telemetry::event!(
534 "Edit Prediction Request",
535 context_latency = done_gathering_context_at
536 .duration_since(buffer_snapshotted_at)
537 .as_millis(),
538 request_latency = received_response_at
539 .duration_since(done_gathering_context_at)
540 .as_millis(),
541 process_latency = finished_at.duration_since(received_response_at).as_millis()
542 );
543 }
544
545 edit_prediction
546 })
547 }
548
549 #[cfg(any(test, feature = "test-support"))]
550 pub fn fake_completion(
551 &mut self,
552 project: &Entity<Project>,
553 buffer: &Entity<Buffer>,
554 position: language::Anchor,
555 response: PredictEditsResponse,
556 cx: &mut Context<Self>,
557 ) -> Task<Result<Option<EditPrediction>>> {
558 self.request_completion_impl(project, buffer, position, cx, |_params| {
559 std::future::ready(Ok((response, None)))
560 })
561 }
562
563 pub fn request_completion(
564 &mut self,
565 project: &Entity<Project>,
566 buffer: &Entity<Buffer>,
567 position: language::Anchor,
568 cx: &mut Context<Self>,
569 ) -> Task<Result<Option<EditPrediction>>> {
570 self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
571 }
572
573 pub fn perform_predict_edits(
574 params: PerformPredictEditsParams,
575 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
576 async move {
577 let PerformPredictEditsParams {
578 client,
579 llm_token,
580 app_version,
581 body,
582 ..
583 } = params;
584
585 let http_client = client.http_client();
586 let mut token = llm_token.acquire(&client).await?;
587 let mut did_retry = false;
588
589 loop {
590 let request_builder = http_client::Request::builder().method(Method::POST);
591 let request_builder =
592 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
593 request_builder.uri(predict_edits_url)
594 } else {
595 request_builder.uri(
596 http_client
597 .build_zed_llm_url("/predict_edits/v2", &[])?
598 .as_ref(),
599 )
600 };
601 let request = request_builder
602 .header("Content-Type", "application/json")
603 .header("Authorization", format!("Bearer {}", token))
604 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
605 .body(serde_json::to_string(&body)?.into())?;
606
607 let mut response = http_client.send(request).await?;
608
609 if let Some(minimum_required_version) = response
610 .headers()
611 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
612 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
613 {
614 anyhow::ensure!(
615 app_version >= minimum_required_version,
616 ZedUpdateRequiredError {
617 minimum_version: minimum_required_version
618 }
619 );
620 }
621
622 if response.status().is_success() {
623 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
624
625 let mut body = String::new();
626 response.body_mut().read_to_string(&mut body).await?;
627 return Ok((serde_json::from_str(&body)?, usage));
628 } else if !did_retry
629 && response
630 .headers()
631 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
632 .is_some()
633 {
634 did_retry = true;
635 token = llm_token.refresh(&client).await?;
636 } else {
637 let mut body = String::new();
638 response.body_mut().read_to_string(&mut body).await?;
639 anyhow::bail!(
640 "error predicting edits.\nStatus: {:?}\nBody: {}",
641 response.status(),
642 body
643 );
644 }
645 }
646 }
647 }
648
649 fn accept_edit_prediction(
650 &mut self,
651 request_id: EditPredictionId,
652 cx: &mut Context<Self>,
653 ) -> Task<Result<()>> {
654 let client = self.client.clone();
655 let llm_token = self.llm_token.clone();
656 let app_version = AppVersion::global(cx);
657 cx.spawn(async move |this, cx| {
658 let http_client = client.http_client();
659 let mut response = llm_token_retry(&llm_token, &client, |token| {
660 let request_builder = http_client::Request::builder().method(Method::POST);
661 let request_builder =
662 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
663 request_builder.uri(accept_prediction_url)
664 } else {
665 request_builder.uri(
666 http_client
667 .build_zed_llm_url("/predict_edits/accept", &[])?
668 .as_ref(),
669 )
670 };
671 Ok(request_builder
672 .header("Content-Type", "application/json")
673 .header("Authorization", format!("Bearer {}", token))
674 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
675 .body(
676 serde_json::to_string(&AcceptEditPredictionBody {
677 request_id: request_id.0.to_string(),
678 })?
679 .into(),
680 )?)
681 })
682 .await?;
683
684 if let Some(minimum_required_version) = response
685 .headers()
686 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
687 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
688 && app_version < minimum_required_version
689 {
690 return Err(anyhow!(ZedUpdateRequiredError {
691 minimum_version: minimum_required_version
692 }));
693 }
694
695 if response.status().is_success() {
696 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
697 this.update(cx, |this, cx| {
698 this.user_store.update(cx, |user_store, cx| {
699 user_store.update_edit_prediction_usage(usage, cx);
700 });
701 })?;
702 }
703
704 Ok(())
705 } else {
706 let mut body = String::new();
707 response.body_mut().read_to_string(&mut body).await?;
708 Err(anyhow!(
709 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
710 response.status(),
711 body
712 ))
713 }
714 })
715 }
716
717 fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
718 let client = self.client.clone();
719 let llm_token = self.llm_token.clone();
720 let app_version = AppVersion::global(cx);
721 let last_rejection = self.discarded_completions.last().cloned();
722 let body = serde_json::to_string(&RejectEditPredictionsBody {
723 rejections: self.discarded_completions.clone(),
724 })
725 .ok();
726
727 let Some(last_rejection) = last_rejection else {
728 return Task::ready(anyhow::Ok(()));
729 };
730
731 cx.spawn(async move |this, cx| {
732 let http_client = client.http_client();
733 let mut response = llm_token_retry(&llm_token, &client, |token| {
734 let request_builder = http_client::Request::builder().method(Method::POST);
735 let request_builder = request_builder.uri(
736 http_client
737 .build_zed_llm_url("/predict_edits/reject", &[])?
738 .as_ref(),
739 );
740 Ok(request_builder
741 .header("Content-Type", "application/json")
742 .header("Authorization", format!("Bearer {}", token))
743 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
744 .body(
745 body.as_ref()
746 .context("failed to serialize body")?
747 .clone()
748 .into(),
749 )?)
750 })
751 .await?;
752
753 if let Some(minimum_required_version) = response
754 .headers()
755 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
756 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
757 && app_version < minimum_required_version
758 {
759 return Err(anyhow!(ZedUpdateRequiredError {
760 minimum_version: minimum_required_version
761 }));
762 }
763
764 if response.status().is_success() {
765 this.update(cx, |this, _| {
766 if let Some(ix) = this
767 .discarded_completions
768 .iter()
769 .position(|rejection| rejection.request_id == last_rejection.request_id)
770 {
771 this.discarded_completions.drain(..ix + 1);
772 }
773 })
774 } else {
775 let mut body = String::new();
776 response.body_mut().read_to_string(&mut body).await?;
777 Err(anyhow!(
778 "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
779 response.status(),
780 body
781 ))
782 }
783 })
784 }
785
786 fn process_completion_response(
787 prediction_response: PredictEditsResponse,
788 buffer: Entity<Buffer>,
789 snapshot: &BufferSnapshot,
790 editable_range: Range<usize>,
791 cursor_offset: usize,
792 path: Arc<Path>,
793 input_outline: String,
794 input_events: String,
795 input_excerpt: String,
796 buffer_snapshotted_at: Instant,
797 cx: &AsyncApp,
798 ) -> Task<Result<Option<EditPrediction>>> {
799 let snapshot = snapshot.clone();
800 let request_id = prediction_response.request_id;
801 let output_excerpt = prediction_response.output_excerpt;
802 cx.spawn(async move |cx| {
803 let output_excerpt: Arc<str> = output_excerpt.into();
804
805 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
806 .background_spawn({
807 let output_excerpt = output_excerpt.clone();
808 let editable_range = editable_range.clone();
809 let snapshot = snapshot.clone();
810 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
811 })
812 .await?
813 .into();
814
815 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
816 let edits = edits.clone();
817 move |buffer, cx| {
818 let new_snapshot = buffer.snapshot();
819 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
820 edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
821 .into();
822 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
823 }
824 })?
825 else {
826 return anyhow::Ok(None);
827 };
828
829 let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;
830
831 let edit_preview = edit_preview.await;
832
833 Ok(Some(EditPrediction {
834 id: EditPredictionId(request_id),
835 path,
836 excerpt_range: editable_range,
837 cursor_offset,
838 edits,
839 edit_preview,
840 snapshot,
841 input_outline: input_outline.into(),
842 input_events: input_events.into(),
843 input_excerpt: input_excerpt.into(),
844 output_excerpt,
845 buffer_snapshotted_at,
846 response_received_at: Instant::now(),
847 }))
848 })
849 }
850
851 fn parse_edits(
852 output_excerpt: Arc<str>,
853 editable_range: Range<usize>,
854 snapshot: &BufferSnapshot,
855 ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
856 let content = output_excerpt.replace(CURSOR_MARKER, "");
857
858 let start_markers = content
859 .match_indices(EDITABLE_REGION_START_MARKER)
860 .collect::<Vec<_>>();
861 anyhow::ensure!(
862 start_markers.len() == 1,
863 "expected exactly one start marker, found {}",
864 start_markers.len()
865 );
866
867 let end_markers = content
868 .match_indices(EDITABLE_REGION_END_MARKER)
869 .collect::<Vec<_>>();
870 anyhow::ensure!(
871 end_markers.len() == 1,
872 "expected exactly one end marker, found {}",
873 end_markers.len()
874 );
875
876 let sof_markers = content
877 .match_indices(START_OF_FILE_MARKER)
878 .collect::<Vec<_>>();
879 anyhow::ensure!(
880 sof_markers.len() <= 1,
881 "expected at most one start-of-file marker, found {}",
882 sof_markers.len()
883 );
884
885 let codefence_start = start_markers[0].0;
886 let content = &content[codefence_start..];
887
888 let newline_ix = content.find('\n').context("could not find newline")?;
889 let content = &content[newline_ix + 1..];
890
891 let codefence_end = content
892 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
893 .context("could not find end marker")?;
894 let new_text = &content[..codefence_end];
895
896 let old_text = snapshot
897 .text_for_range(editable_range.clone())
898 .collect::<String>();
899
900 Ok(Self::compute_edits(
901 old_text,
902 new_text,
903 editable_range.start,
904 snapshot,
905 ))
906 }
907
908 pub fn compute_edits(
909 old_text: String,
910 new_text: &str,
911 offset: usize,
912 snapshot: &BufferSnapshot,
913 ) -> Vec<(Range<Anchor>, Arc<str>)> {
914 text_diff(&old_text, new_text)
915 .into_iter()
916 .map(|(mut old_range, new_text)| {
917 old_range.start += offset;
918 old_range.end += offset;
919
920 let prefix_len = common_prefix(
921 snapshot.chars_for_range(old_range.clone()),
922 new_text.chars(),
923 );
924 old_range.start += prefix_len;
925
926 let suffix_len = common_prefix(
927 snapshot.reversed_chars_for_range(old_range.clone()),
928 new_text[prefix_len..].chars().rev(),
929 );
930 old_range.end = old_range.end.saturating_sub(suffix_len);
931
932 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
933 let range = if old_range.is_empty() {
934 let anchor = snapshot.anchor_after(old_range.start);
935 anchor..anchor
936 } else {
937 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
938 };
939 (range, new_text)
940 })
941 .collect()
942 }
943
944 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
945 self.rated_completions.contains(&completion_id)
946 }
947
948 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
949 self.shown_completions.push_front(completion.clone());
950 if self.shown_completions.len() > 50 {
951 let completion = self.shown_completions.pop_back().unwrap();
952 self.rated_completions.remove(&completion.id);
953 }
954 cx.notify();
955 }
956
957 pub fn rate_completion(
958 &mut self,
959 completion: &EditPrediction,
960 rating: EditPredictionRating,
961 feedback: String,
962 cx: &mut Context<Self>,
963 ) {
964 self.rated_completions.insert(completion.id);
965 telemetry::event!(
966 "Edit Prediction Rated",
967 rating,
968 input_events = completion.input_events,
969 input_excerpt = completion.input_excerpt,
970 input_outline = completion.input_outline,
971 output_excerpt = completion.output_excerpt,
972 feedback
973 );
974 self.client.telemetry().flush_events().detach();
975 cx.notify();
976 }
977
978 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
979 self.shown_completions.iter()
980 }
981
982 pub fn shown_completions_len(&self) -> usize {
983 self.shown_completions.len()
984 }
985
986 fn report_changes_for_buffer(
987 &mut self,
988 buffer: &Entity<Buffer>,
989 project: &Entity<Project>,
990 cx: &mut Context<Self>,
991 ) -> BufferSnapshot {
992 let zeta_project = self.get_or_init_zeta_project(project, cx);
993 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
994
995 let new_snapshot = buffer.read(cx).snapshot();
996 if new_snapshot.version != registered_buffer.snapshot.version {
997 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
998 Self::push_event(
999 zeta_project,
1000 Event::BufferChange {
1001 old_snapshot,
1002 new_snapshot: new_snapshot.clone(),
1003 timestamp: Instant::now(),
1004 },
1005 );
1006 }
1007
1008 new_snapshot
1009 }
1010
1011 fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1012 self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
1013 }
1014
1015 fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
1016 if !self.data_collection_choice.is_enabled() {
1017 return false;
1018 }
1019 let mut last_checked_file = None;
1020 for event in events {
1021 match event {
1022 Event::BufferChange {
1023 old_snapshot,
1024 new_snapshot,
1025 ..
1026 } => {
1027 if let Some(old_file) = old_snapshot.file()
1028 && let Some(new_file) = new_snapshot.file()
1029 {
1030 if let Some(last_checked_file) = last_checked_file
1031 && Arc::ptr_eq(last_checked_file, old_file)
1032 && Arc::ptr_eq(last_checked_file, new_file)
1033 {
1034 continue;
1035 }
1036 if !self.can_collect_file(old_file, cx) {
1037 return false;
1038 }
1039 if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
1040 {
1041 return false;
1042 }
1043 last_checked_file = Some(new_file);
1044 } else {
1045 return false;
1046 }
1047 }
1048 }
1049 }
1050 true
1051 }
1052
1053 fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1054 if !file.is_local() || file.is_private() {
1055 return false;
1056 }
1057 self.license_detection_watchers
1058 .get(&file.worktree_id(cx))
1059 .is_some_and(|watcher| watcher.is_project_open_source())
1060 }
1061
1062 fn load_data_collection_choice() -> DataCollectionChoice {
1063 let choice = KEY_VALUE_STORE
1064 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1065 .log_err()
1066 .flatten();
1067
1068 match choice.as_deref() {
1069 Some("true") => DataCollectionChoice::Enabled,
1070 Some("false") => DataCollectionChoice::Disabled,
1071 Some(_) => {
1072 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1073 DataCollectionChoice::NotAnswered
1074 }
1075 None => DataCollectionChoice::NotAnswered,
1076 }
1077 }
1078
1079 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1080 self.data_collection_choice = self.data_collection_choice.toggle();
1081 let new_choice = self.data_collection_choice;
1082 db::write_and_log(cx, move || {
1083 KEY_VALUE_STORE.write_kvp(
1084 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1085 new_choice.is_enabled().to_string(),
1086 )
1087 });
1088 }
1089
1090 fn discard_completion(
1091 &mut self,
1092 completion_id: EditPredictionId,
1093 was_shown: bool,
1094 cx: &mut Context<Self>,
1095 ) {
1096 self.discarded_completions.push(EditPredictionRejection {
1097 request_id: completion_id.to_string(),
1098 was_shown,
1099 });
1100
1101 let reached_request_limit =
1102 self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
1103 let discard_completions_tx = self.discard_completions_tx.clone();
1104 self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
1105 const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
1106 if !reached_request_limit {
1107 cx.background_executor()
1108 .timer(DISCARD_COMPLETIONS_DEBOUNCE)
1109 .await;
1110 }
1111 discard_completions_tx.unbounded_send(()).log_err();
1112 }));
1113 }
1114}
1115
/// Arguments bundled together for issuing a predict-edits request.
///
/// NOTE(review): the function that consumes this struct is outside this chunk; the field
/// descriptions below only restate the field types.
pub struct PerformPredictEditsParams {
    /// Client used to reach the server.
    pub client: Arc<Client>,
    /// LLM API token used to authenticate the request.
    pub llm_token: LlmApiToken,
    /// Version of the running app (presumably sent alongside the request — confirm at the use site).
    pub app_version: Version,
    /// Request payload.
    pub body: PredictEditsBody,
}
1122
/// Error indicating the running Zed version is too old for edit predictions.
///
/// Its display message (from the `#[error]` attribute) asks the user to update to
/// `minimum_version` or later.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version that may continue using edit predictions.
    minimum_version: Version,
}
1130
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two character
/// streams.
///
/// Comparison stops at the first mismatching pair or when either stream ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut prefix_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        prefix_bytes += left.len_utf8();
    }
    prefix_bytes
}
1137
1138fn git_info_for_file(
1139 project: &Entity<Project>,
1140 project_path: &ProjectPath,
1141 cx: &App,
1142) -> Option<PredictEditsGitInfo> {
1143 let git_store = project.read(cx).git_store().read(cx);
1144 if let Some((repository, _repo_path)) =
1145 git_store.repository_and_path_for_project_path(project_path, cx)
1146 {
1147 let repository = repository.read(cx);
1148 let head_sha = repository
1149 .head_commit
1150 .as_ref()
1151 .map(|head_commit| head_commit.sha.to_string());
1152 let remote_origin_url = repository.remote_origin_url.clone();
1153 let remote_upstream_url = repository.remote_upstream_url.clone();
1154 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1155 return None;
1156 }
1157 Some(PredictEditsGitInfo {
1158 head_sha,
1159 remote_origin_url,
1160 remote_upstream_url,
1161 })
1162 } else {
1163 None
1164 }
1165}
1166
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// Request body holding the rendered event history and the excerpt around the cursor.
    pub body: PredictEditsBody,
    /// The excerpt's editable region, as offsets into the buffer snapshot.
    pub editable_range: Range<usize>,
    /// Number of events the `prompt_for_events` callback reported as included.
    pub included_events_count: usize,
}
1172
1173pub fn gather_context(
1174 full_path_str: String,
1175 snapshot: &BufferSnapshot,
1176 cursor_point: language::Point,
1177 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1178 cx: &App,
1179) -> Task<Result<GatherContextOutput>> {
1180 cx.background_spawn({
1181 let snapshot = snapshot.clone();
1182 async move {
1183 let input_excerpt = excerpt_for_cursor_position(
1184 cursor_point,
1185 &full_path_str,
1186 &snapshot,
1187 MAX_REWRITE_TOKENS,
1188 MAX_CONTEXT_TOKENS,
1189 );
1190 let (input_events, included_events_count) = prompt_for_events();
1191 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1192
1193 let body = PredictEditsBody {
1194 input_events,
1195 input_excerpt: input_excerpt.prompt,
1196 can_collect_data: false,
1197 diagnostic_groups: None,
1198 git_info: None,
1199 outline: None,
1200 speculated_output: None,
1201 };
1202
1203 Ok(GatherContextOutput {
1204 body,
1205 editable_range,
1206 included_events_count,
1207 })
1208 }
1209 })
1210}
1211
1212fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1213 let mut result = String::new();
1214 for (ix, event) in events.iter().rev().enumerate() {
1215 let event_string = event.to_prompt();
1216 let event_tokens = guess_token_count(event_string.len());
1217 if event_tokens > remaining_tokens {
1218 return (result, ix);
1219 }
1220
1221 if !result.is_empty() {
1222 result.insert_str(0, "\n\n");
1223 }
1224 result.insert_str(0, &event_string);
1225 remaining_tokens -= event_tokens;
1226 }
1227 return (result, events.len());
1228}
1229
/// Bookkeeping for a buffer registered so that its edits are tracked.
struct RegisteredBuffer {
    /// Snapshot of the buffer (presumably the last observed state, used as the "old" side of
    /// the next change event — verify against the registration code).
    snapshot: BufferSnapshot,
    /// Subscriptions held only to keep them alive while the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
}
1234
/// A recorded edit used to build the edit-history portion of the prompt.
#[derive(Clone)]
pub enum Event {
    /// A change to a buffer, captured as before/after snapshots. The snapshots' files may
    /// differ, in which case the change also represents a rename.
    BufferChange {
        /// Buffer state before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded (not read by the code in this chunk).
        timestamp: Instant,
    },
}
1243
1244impl Event {
1245 fn to_prompt(&self) -> String {
1246 match self {
1247 Event::BufferChange {
1248 old_snapshot,
1249 new_snapshot,
1250 ..
1251 } => {
1252 let mut prompt = String::new();
1253
1254 let old_path = old_snapshot
1255 .file()
1256 .map(|f| f.path().as_ref())
1257 .unwrap_or(RelPath::unix("untitled").unwrap());
1258 let new_path = new_snapshot
1259 .file()
1260 .map(|f| f.path().as_ref())
1261 .unwrap_or(RelPath::unix("untitled").unwrap());
1262 if old_path != new_path {
1263 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1264 }
1265
1266 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1267 if !diff.is_empty() {
1268 write!(
1269 prompt,
1270 "User edited {:?}:\n```diff\n{}\n```",
1271 new_path, diff
1272 )
1273 .unwrap();
1274 }
1275
1276 prompt
1277 }
1278 }
1279 }
1280}
1281
/// The prediction currently offered by [`ZetaEditPredictionProvider`], together with whether
/// the user has seen or accepted it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
    /// Set once the editor reports the prediction as displayed (`did_show`).
    was_shown: bool,
    /// Set when the user accepts the prediction. Accepted predictions are not reported as
    /// discarded when they are later replaced or dropped.
    was_accepted: bool,
}
1289
1290impl CurrentEditPrediction {
1291 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1292 if self.buffer_id != old_completion.buffer_id {
1293 return true;
1294 }
1295
1296 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1297 return true;
1298 };
1299 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1300 return false;
1301 };
1302
1303 if old_edits.len() == 1 && new_edits.len() == 1 {
1304 let (old_range, old_text) = &old_edits[0];
1305 let (new_range, new_text) = &new_edits[0];
1306 new_range == old_range && new_text.starts_with(old_text.as_ref())
1307 } else {
1308 true
1309 }
1310 }
1311}
1312
/// An in-flight prediction request spawned by `ZetaEditPredictionProvider::refresh`.
struct PendingCompletion {
    /// Id taken from the provider's incrementing `next_pending_completion_id` counter.
    id: usize,
    /// The spawned request task (dropping it presumably cancels the request — gpui `Task`
    /// semantics).
    task: Task<()>,
}
1317
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has made either choice.
    pub fn is_answered(self) -> bool {
        matches!(self, Self::Enabled | Self::Disabled)
    }

    /// Returns the opposite choice. An unanswered choice toggles to [`Self::Enabled`].
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to [`DataCollectionChoice::Enabled`] and `false` to
    /// [`DataCollectionChoice::Disabled`].
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1358
1359async fn llm_token_retry(
1360 llm_token: &LlmApiToken,
1361 client: &Arc<Client>,
1362 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1363) -> Result<Response<AsyncBody>> {
1364 let mut did_retry = false;
1365 let http_client = client.http_client();
1366 let mut token = llm_token.acquire(client).await?;
1367 loop {
1368 let request = build_request(token.clone())?;
1369 let response = http_client.send(request).await?;
1370
1371 if !did_retry
1372 && !response.status().is_success()
1373 && response
1374 .headers()
1375 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1376 .is_some()
1377 {
1378 did_retry = true;
1379 token = llm_token.refresh(client).await?;
1380 continue;
1381 }
1382
1383 return Ok(response);
1384 }
1385}
1386
/// Edit prediction provider backed by a shared [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Buffer the provider was created for, if any. Used to report data-collection state.
    singleton_buffer: Option<Entity<Buffer>>,
    /// Requests still awaiting a result, oldest first. At most two are tracked.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Superseded requests, kept running so their results can be reported as discarded.
    canceled_completions: HashMap<usize, Task<()>>,
    /// Id to assign to the next request.
    next_pending_completion_id: usize,
    /// Prediction currently offered to the editor, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was issued. Used to throttle new requests.
    last_request_timestamp: Instant,
    project: Entity<Project>,
}
1397
1398impl ZetaEditPredictionProvider {
1399 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1400
1401 pub fn new(
1402 zeta: Entity<Zeta>,
1403 project: Entity<Project>,
1404 singleton_buffer: Option<Entity<Buffer>>,
1405 cx: &mut Context<Self>,
1406 ) -> Self {
1407 cx.on_release(|this, cx| {
1408 this.take_current_edit_prediction(cx);
1409 })
1410 .detach();
1411
1412 Self {
1413 zeta,
1414 singleton_buffer,
1415 pending_completions: ArrayVec::new(),
1416 canceled_completions: HashMap::default(),
1417 next_pending_completion_id: 0,
1418 current_completion: None,
1419 last_request_timestamp: Instant::now(),
1420 project,
1421 }
1422 }
1423
1424 fn take_current_edit_prediction(&mut self, cx: &mut App) {
1425 if let Some(completion) = self.current_completion.take() {
1426 if !completion.was_accepted {
1427 self.zeta.update(cx, |zeta, cx| {
1428 zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
1429 });
1430 }
1431 }
1432 }
1433}
1434
1435impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1436 fn name() -> &'static str {
1437 "zed-predict"
1438 }
1439
1440 fn display_name() -> &'static str {
1441 "Zed's Edit Predictions"
1442 }
1443
1444 fn show_completions_in_menu() -> bool {
1445 true
1446 }
1447
1448 fn show_tab_accept_marker() -> bool {
1449 true
1450 }
1451
1452 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1453 if let Some(buffer) = &self.singleton_buffer
1454 && let Some(file) = buffer.read(cx).file()
1455 {
1456 let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
1457 if self.zeta.read(cx).data_collection_choice.is_enabled() {
1458 DataCollectionState::Enabled {
1459 is_project_open_source,
1460 }
1461 } else {
1462 DataCollectionState::Disabled {
1463 is_project_open_source,
1464 }
1465 }
1466 } else {
1467 return DataCollectionState::Disabled {
1468 is_project_open_source: false,
1469 };
1470 }
1471 }
1472
1473 fn toggle_data_collection(&mut self, cx: &mut App) {
1474 self.zeta
1475 .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
1476 }
1477
1478 fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1479 self.zeta.read(cx).usage(cx)
1480 }
1481
1482 fn is_enabled(
1483 &self,
1484 _buffer: &Entity<Buffer>,
1485 _cursor_position: language::Anchor,
1486 _cx: &App,
1487 ) -> bool {
1488 true
1489 }
1490 fn is_refreshing(&self, _cx: &App) -> bool {
1491 !self.pending_completions.is_empty()
1492 }
1493
1494 fn refresh(
1495 &mut self,
1496 buffer: Entity<Buffer>,
1497 position: language::Anchor,
1498 _debounce: bool,
1499 cx: &mut Context<Self>,
1500 ) {
1501 if self.zeta.read(cx).update_required {
1502 return;
1503 }
1504
1505 if self
1506 .zeta
1507 .read(cx)
1508 .user_store
1509 .read_with(cx, |user_store, _cx| {
1510 user_store.account_too_young() || user_store.has_overdue_invoices()
1511 })
1512 {
1513 return;
1514 }
1515
1516 if let Some(current_completion) = self.current_completion.as_ref() {
1517 let snapshot = buffer.read(cx).snapshot();
1518 if current_completion
1519 .completion
1520 .interpolate(&snapshot)
1521 .is_some()
1522 {
1523 return;
1524 }
1525 }
1526
1527 let pending_completion_id = self.next_pending_completion_id;
1528 self.next_pending_completion_id += 1;
1529 let last_request_timestamp = self.last_request_timestamp;
1530
1531 let project = self.project.clone();
1532 let task = cx.spawn(async move |this, cx| {
1533 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1534 .checked_duration_since(Instant::now())
1535 {
1536 cx.background_executor().timer(timeout).await;
1537 }
1538
1539 let completion_request = this.update(cx, |this, cx| {
1540 this.last_request_timestamp = Instant::now();
1541 this.zeta.update(cx, |zeta, cx| {
1542 zeta.request_completion(&project, &buffer, position, cx)
1543 })
1544 });
1545
1546 let completion = match completion_request {
1547 Ok(completion_request) => {
1548 let completion_request = completion_request.await;
1549 completion_request.map(|c| {
1550 c.map(|completion| CurrentEditPrediction {
1551 buffer_id: buffer.entity_id(),
1552 completion,
1553 was_shown: false,
1554 was_accepted: false,
1555 })
1556 })
1557 }
1558 Err(error) => Err(error),
1559 };
1560
1561 let discarded = this
1562 .update(cx, |this, cx| {
1563 if this
1564 .pending_completions
1565 .first()
1566 .is_some_and(|completion| completion.id == pending_completion_id)
1567 {
1568 this.pending_completions.remove(0);
1569 } else {
1570 if let Some(discarded) = this.pending_completions.drain(..).next() {
1571 this.canceled_completions
1572 .insert(discarded.id, discarded.task);
1573 }
1574 }
1575
1576 let canceled = this.canceled_completions.remove(&pending_completion_id);
1577
1578 if canceled.is_some()
1579 && let Ok(Some(new_completion)) = &completion
1580 {
1581 this.zeta.update(cx, |zeta, cx| {
1582 zeta.discard_completion(new_completion.completion.id, false, cx);
1583 });
1584 return true;
1585 }
1586
1587 cx.notify();
1588 false
1589 })
1590 .ok()
1591 .unwrap_or(true);
1592
1593 if discarded {
1594 return;
1595 }
1596
1597 let Some(new_completion) = completion
1598 .context("edit prediction failed")
1599 .log_err()
1600 .flatten()
1601 else {
1602 return;
1603 };
1604
1605 this.update(cx, |this, cx| {
1606 if let Some(old_completion) = this.current_completion.as_ref() {
1607 let snapshot = buffer.read(cx).snapshot();
1608 if new_completion.should_replace_completion(old_completion, &snapshot) {
1609 this.zeta.update(cx, |zeta, cx| {
1610 zeta.completion_shown(&new_completion.completion, cx);
1611 });
1612 this.take_current_edit_prediction(cx);
1613 this.current_completion = Some(new_completion);
1614 }
1615 } else {
1616 this.zeta.update(cx, |zeta, cx| {
1617 zeta.completion_shown(&new_completion.completion, cx);
1618 });
1619 this.current_completion = Some(new_completion);
1620 }
1621
1622 cx.notify();
1623 })
1624 .ok();
1625 });
1626
1627 // We always maintain at most two pending completions. When we already
1628 // have two, we replace the newest one.
1629 if self.pending_completions.len() <= 1 {
1630 self.pending_completions.push(PendingCompletion {
1631 id: pending_completion_id,
1632 task,
1633 });
1634 } else if self.pending_completions.len() == 2 {
1635 if let Some(discarded) = self.pending_completions.pop() {
1636 self.canceled_completions
1637 .insert(discarded.id, discarded.task);
1638 }
1639 self.pending_completions.push(PendingCompletion {
1640 id: pending_completion_id,
1641 task,
1642 });
1643 }
1644 }
1645
1646 fn cycle(
1647 &mut self,
1648 _buffer: Entity<Buffer>,
1649 _cursor_position: language::Anchor,
1650 _direction: edit_prediction::Direction,
1651 _cx: &mut Context<Self>,
1652 ) {
1653 // Right now we don't support cycling.
1654 }
1655
1656 fn accept(&mut self, cx: &mut Context<Self>) {
1657 let completion = self.current_completion.as_mut();
1658 if let Some(completion) = completion {
1659 completion.was_accepted = true;
1660 self.zeta
1661 .update(cx, |zeta, cx| {
1662 zeta.accept_edit_prediction(completion.completion.id, cx)
1663 })
1664 .detach();
1665 }
1666 self.pending_completions.clear();
1667 }
1668
1669 fn discard(&mut self, cx: &mut Context<Self>) {
1670 self.pending_completions.clear();
1671 self.take_current_edit_prediction(cx);
1672 }
1673
1674 fn did_show(&mut self, _cx: &mut Context<Self>) {
1675 if let Some(current_completion) = self.current_completion.as_mut() {
1676 current_completion.was_shown = true;
1677 }
1678 }
1679
1680 fn suggest(
1681 &mut self,
1682 buffer: &Entity<Buffer>,
1683 cursor_position: language::Anchor,
1684 cx: &mut Context<Self>,
1685 ) -> Option<edit_prediction::EditPrediction> {
1686 let CurrentEditPrediction {
1687 buffer_id,
1688 completion,
1689 ..
1690 } = self.current_completion.as_mut()?;
1691
1692 // Invalidate previous completion if it was generated for a different buffer.
1693 if *buffer_id != buffer.entity_id() {
1694 self.take_current_edit_prediction(cx);
1695 return None;
1696 }
1697
1698 let buffer = buffer.read(cx);
1699 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1700 self.take_current_edit_prediction(cx);
1701 return None;
1702 };
1703
1704 let cursor_row = cursor_position.to_point(buffer).row;
1705 let (closest_edit_ix, (closest_edit_range, _)) =
1706 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1707 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1708 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1709 cmp::min(distance_from_start, distance_from_end)
1710 })?;
1711
1712 let mut edit_start_ix = closest_edit_ix;
1713 for (range, _) in edits[..edit_start_ix].iter().rev() {
1714 let distance_from_closest_edit =
1715 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1716 if distance_from_closest_edit <= 1 {
1717 edit_start_ix -= 1;
1718 } else {
1719 break;
1720 }
1721 }
1722
1723 let mut edit_end_ix = closest_edit_ix + 1;
1724 for (range, _) in &edits[edit_end_ix..] {
1725 let distance_from_closest_edit =
1726 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1727 if distance_from_closest_edit <= 1 {
1728 edit_end_ix += 1;
1729 } else {
1730 break;
1731 }
1732 }
1733
1734 Some(edit_prediction::EditPrediction::Local {
1735 id: Some(completion.id.to_string().into()),
1736 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1737 edit_preview: Some(completion.edit_preview.clone()),
1738 })
1739 }
1740}
1741
/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low to err on the side of underestimating limits.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1749
1750#[cfg(test)]
1751mod tests {
1752 use client::test::FakeServer;
1753 use clock::{FakeSystemClock, ReplicaId};
1754 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1755 use gpui::TestAppContext;
1756 use http_client::FakeHttpClient;
1757 use indoc::indoc;
1758 use language::Point;
1759 use parking_lot::Mutex;
1760 use serde_json::json;
1761 use settings::SettingsStore;
1762 use util::{path, rel_path::rel_path};
1763
1764 use super::*;
1765
1766 const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1767
1768 #[gpui::test]
1769 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1770 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1771 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1772 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1773 });
1774
1775 let edit_preview = cx
1776 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1777 .await;
1778
1779 let completion = EditPrediction {
1780 edits,
1781 edit_preview,
1782 path: Path::new("").into(),
1783 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1784 id: EditPredictionId(Uuid::new_v4()),
1785 excerpt_range: 0..0,
1786 cursor_offset: 0,
1787 input_outline: "".into(),
1788 input_events: "".into(),
1789 input_excerpt: "".into(),
1790 output_excerpt: "".into(),
1791 buffer_snapshotted_at: Instant::now(),
1792 response_received_at: Instant::now(),
1793 };
1794
1795 cx.update(|cx| {
1796 assert_eq!(
1797 from_completion_edits(
1798 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1799 &buffer,
1800 cx
1801 ),
1802 vec![(2..5, "REM".into()), (9..11, "".into())]
1803 );
1804
1805 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1806 assert_eq!(
1807 from_completion_edits(
1808 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1809 &buffer,
1810 cx
1811 ),
1812 vec![(2..2, "REM".into()), (6..8, "".into())]
1813 );
1814
1815 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1816 assert_eq!(
1817 from_completion_edits(
1818 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1819 &buffer,
1820 cx
1821 ),
1822 vec![(2..5, "REM".into()), (9..11, "".into())]
1823 );
1824
1825 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1826 assert_eq!(
1827 from_completion_edits(
1828 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1829 &buffer,
1830 cx
1831 ),
1832 vec![(3..3, "EM".into()), (7..9, "".into())]
1833 );
1834
1835 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1836 assert_eq!(
1837 from_completion_edits(
1838 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1839 &buffer,
1840 cx
1841 ),
1842 vec![(4..4, "M".into()), (8..10, "".into())]
1843 );
1844
1845 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1846 assert_eq!(
1847 from_completion_edits(
1848 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1849 &buffer,
1850 cx
1851 ),
1852 vec![(9..11, "".into())]
1853 );
1854
1855 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1856 assert_eq!(
1857 from_completion_edits(
1858 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1859 &buffer,
1860 cx
1861 ),
1862 vec![(4..4, "M".into()), (8..10, "".into())]
1863 );
1864
1865 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1866 assert_eq!(
1867 from_completion_edits(
1868 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1869 &buffer,
1870 cx
1871 ),
1872 vec![(4..4, "M".into())]
1873 );
1874
1875 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1876 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1877 })
1878 }
1879
1880 #[gpui::test]
1881 async fn test_clean_up_diff(cx: &mut TestAppContext) {
1882 init_test(cx);
1883
1884 assert_eq!(
1885 apply_edit_prediction(
1886 indoc! {"
1887 fn main() {
1888 let word_1 = \"lorem\";
1889 let range = word.len()..word.len();
1890 }
1891 "},
1892 indoc! {"
1893 <|editable_region_start|>
1894 fn main() {
1895 let word_1 = \"lorem\";
1896 let range = word_1.len()..word_1.len();
1897 }
1898
1899 <|editable_region_end|>
1900 "},
1901 cx,
1902 )
1903 .await,
1904 indoc! {"
1905 fn main() {
1906 let word_1 = \"lorem\";
1907 let range = word_1.len()..word_1.len();
1908 }
1909 "},
1910 );
1911
1912 assert_eq!(
1913 apply_edit_prediction(
1914 indoc! {"
1915 fn main() {
1916 let story = \"the quick\"
1917 }
1918 "},
1919 indoc! {"
1920 <|editable_region_start|>
1921 fn main() {
1922 let story = \"the quick brown fox jumps over the lazy dog\";
1923 }
1924
1925 <|editable_region_end|>
1926 "},
1927 cx,
1928 )
1929 .await,
1930 indoc! {"
1931 fn main() {
1932 let story = \"the quick brown fox jumps over the lazy dog\";
1933 }
1934 "},
1935 );
1936 }
1937
1938 #[gpui::test]
1939 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1940 init_test(cx);
1941
1942 let buffer_content = "lorem\n";
1943 let completion_response = indoc! {"
1944 ```animals.js
1945 <|start_of_file|>
1946 <|editable_region_start|>
1947 lorem
1948 ipsum
1949 <|editable_region_end|>
1950 ```"};
1951
1952 assert_eq!(
1953 apply_edit_prediction(buffer_content, completion_response, cx).await,
1954 "lorem\nipsum"
1955 );
1956 }
1957
1958 #[gpui::test]
1959 async fn test_can_collect_data(cx: &mut TestAppContext) {
1960 init_test(cx);
1961
1962 let fs = project::FakeFs::new(cx.executor());
1963 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1964 .await;
1965
1966 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1967 let buffer = project
1968 .update(cx, |project, cx| {
1969 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1970 })
1971 .await
1972 .unwrap();
1973
1974 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1975 zeta.update(cx, |zeta, _cx| {
1976 zeta.data_collection_choice = DataCollectionChoice::Enabled
1977 });
1978
1979 run_edit_prediction(&buffer, &project, &zeta, cx).await;
1980 assert_eq!(
1981 captured_request.lock().clone().unwrap().can_collect_data,
1982 true
1983 );
1984
1985 zeta.update(cx, |zeta, _cx| {
1986 zeta.data_collection_choice = DataCollectionChoice::Disabled
1987 });
1988
1989 run_edit_prediction(&buffer, &project, &zeta, cx).await;
1990 assert_eq!(
1991 captured_request.lock().clone().unwrap().can_collect_data,
1992 false
1993 );
1994 }
1995
1996 #[gpui::test]
1997 async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1998 init_test(cx);
1999
2000 let fs = project::FakeFs::new(cx.executor());
2001 let project = Project::test(fs.clone(), [], cx).await;
2002
2003 let buffer = cx.new(|_cx| {
2004 Buffer::remote(
2005 language::BufferId::new(1).unwrap(),
2006 ReplicaId::new(1),
2007 language::Capability::ReadWrite,
2008 "fn main() {\n println!(\"Hello\");\n}",
2009 )
2010 });
2011
2012 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2013 zeta.update(cx, |zeta, _cx| {
2014 zeta.data_collection_choice = DataCollectionChoice::Enabled
2015 });
2016
2017 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2018 assert_eq!(
2019 captured_request.lock().clone().unwrap().can_collect_data,
2020 false
2021 );
2022 }
2023
2024 #[gpui::test]
2025 async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
2026 init_test(cx);
2027
2028 let fs = project::FakeFs::new(cx.executor());
2029 fs.insert_tree(
2030 path!("/project"),
2031 json!({
2032 "LICENSE": BSD_0_TXT,
2033 ".env": "SECRET_KEY=secret"
2034 }),
2035 )
2036 .await;
2037
2038 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2039 let buffer = project
2040 .update(cx, |project, cx| {
2041 project.open_local_buffer("/project/.env", cx)
2042 })
2043 .await
2044 .unwrap();
2045
2046 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2047 zeta.update(cx, |zeta, _cx| {
2048 zeta.data_collection_choice = DataCollectionChoice::Enabled
2049 });
2050
2051 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2052 assert_eq!(
2053 captured_request.lock().clone().unwrap().can_collect_data,
2054 false
2055 );
2056 }
2057
2058 #[gpui::test]
2059 async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
2060 init_test(cx);
2061
2062 let fs = project::FakeFs::new(cx.executor());
2063 let project = Project::test(fs.clone(), [], cx).await;
2064 let buffer = cx.new(|cx| Buffer::local("", cx));
2065
2066 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2067 zeta.update(cx, |zeta, _cx| {
2068 zeta.data_collection_choice = DataCollectionChoice::Enabled
2069 });
2070
2071 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2072 assert_eq!(
2073 captured_request.lock().clone().unwrap().can_collect_data,
2074 false
2075 );
2076 }
2077
2078 #[gpui::test]
2079 async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
2080 init_test(cx);
2081
2082 let fs = project::FakeFs::new(cx.executor());
2083 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
2084 .await;
2085
2086 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2087 let buffer = project
2088 .update(cx, |project, cx| {
2089 project.open_local_buffer("/project/main.rs", cx)
2090 })
2091 .await
2092 .unwrap();
2093
2094 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2095 zeta.update(cx, |zeta, _cx| {
2096 zeta.data_collection_choice = DataCollectionChoice::Enabled
2097 });
2098
2099 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2100 assert_eq!(
2101 captured_request.lock().clone().unwrap().can_collect_data,
2102 false
2103 );
2104 }
2105
2106 #[gpui::test]
2107 async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
2108 init_test(cx);
2109
2110 let fs = project::FakeFs::new(cx.executor());
2111 fs.insert_tree(
2112 path!("/open_source_worktree"),
2113 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
2114 )
2115 .await;
2116 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
2117 .await;
2118
2119 let project = Project::test(
2120 fs.clone(),
2121 [
2122 path!("/open_source_worktree").as_ref(),
2123 path!("/closed_source_worktree").as_ref(),
2124 ],
2125 cx,
2126 )
2127 .await;
2128 let buffer = project
2129 .update(cx, |project, cx| {
2130 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
2131 })
2132 .await
2133 .unwrap();
2134
2135 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2136 zeta.update(cx, |zeta, _cx| {
2137 zeta.data_collection_choice = DataCollectionChoice::Enabled
2138 });
2139
2140 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2141 assert_eq!(
2142 captured_request.lock().clone().unwrap().can_collect_data,
2143 true
2144 );
2145
2146 let closed_source_file = project
2147 .update(cx, |project, cx| {
2148 let worktree2 = project
2149 .worktree_for_root_name("closed_source_worktree", cx)
2150 .unwrap();
2151 worktree2.update(cx, |worktree2, cx| {
2152 worktree2.load_file(rel_path("main.rs"), cx)
2153 })
2154 })
2155 .await
2156 .unwrap()
2157 .file;
2158
2159 buffer.update(cx, |buffer, cx| {
2160 buffer.file_updated(closed_source_file, cx);
2161 });
2162
2163 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2164 assert_eq!(
2165 captured_request.lock().clone().unwrap().can_collect_data,
2166 false
2167 );
2168 }
2169
2170 #[gpui::test]
2171 async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
2172 init_test(cx);
2173
2174 let fs = project::FakeFs::new(cx.executor());
2175 fs.insert_tree(
2176 path!("/worktree1"),
2177 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
2178 )
2179 .await;
2180 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
2181 .await;
2182
2183 let project = Project::test(
2184 fs.clone(),
2185 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
2186 cx,
2187 )
2188 .await;
2189 let buffer = project
2190 .update(cx, |project, cx| {
2191 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
2192 })
2193 .await
2194 .unwrap();
2195 let private_buffer = project
2196 .update(cx, |project, cx| {
2197 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
2198 })
2199 .await
2200 .unwrap();
2201
2202 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2203 zeta.update(cx, |zeta, _cx| {
2204 zeta.data_collection_choice = DataCollectionChoice::Enabled
2205 });
2206
2207 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2208 assert_eq!(
2209 captured_request.lock().clone().unwrap().can_collect_data,
2210 true
2211 );
2212
2213 // this has a side effect of registering the buffer to watch for edits
2214 run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
2215 assert_eq!(
2216 captured_request.lock().clone().unwrap().can_collect_data,
2217 false
2218 );
2219
2220 private_buffer.update(cx, |private_buffer, cx| {
2221 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
2222 });
2223
2224 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2225 assert_eq!(
2226 captured_request.lock().clone().unwrap().can_collect_data,
2227 false
2228 );
2229
2230 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
2231 // included
2232 buffer.update(cx, |buffer, cx| {
2233 buffer.edit(
2234 [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
2235 None,
2236 cx,
2237 );
2238 });
2239
2240 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2241 assert_eq!(
2242 captured_request.lock().clone().unwrap().can_collect_data,
2243 true
2244 );
2245 }
2246
2247 fn init_test(cx: &mut TestAppContext) {
2248 cx.update(|cx| {
2249 let settings_store = SettingsStore::test(cx);
2250 cx.set_global(settings_store);
2251 });
2252 }
2253
2254 async fn apply_edit_prediction(
2255 buffer_content: &str,
2256 completion_response: &str,
2257 cx: &mut TestAppContext,
2258 ) -> String {
2259 let fs = project::FakeFs::new(cx.executor());
2260 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2261 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2262 let (zeta, _, response) = make_test_zeta(&project, cx).await;
2263 *response.lock() = completion_response.to_string();
2264 let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
2265 buffer.update(cx, |buffer, cx| {
2266 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2267 });
2268 buffer.read_with(cx, |buffer, _| buffer.text())
2269 }
2270
2271 async fn run_edit_prediction(
2272 buffer: &Entity<Buffer>,
2273 project: &Entity<Project>,
2274 zeta: &Entity<Zeta>,
2275 cx: &mut TestAppContext,
2276 ) -> EditPrediction {
2277 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2278 zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
2279 cx.background_executor.run_until_parked();
2280 let completion_task = zeta.update(cx, |zeta, cx| {
2281 zeta.request_completion(&project, buffer, cursor, cx)
2282 });
2283 completion_task.await.unwrap().unwrap()
2284 }
2285
    /// Builds a `Zeta` entity for tests, wired to a fake HTTP client.
    ///
    /// The fake HTTP client answers:
    /// - `POST /client/llm_tokens` with a fixed `"the-llm-token"` LLM token;
    /// - `POST /predict_edits/v2` by recording the deserialized request body and replying
    ///   with a `PredictEditsResponse` whose `output_excerpt` is the current contents of the
    ///   returned response string;
    /// - any other method/path with `404 Not Found`.
    ///
    /// Returns `(zeta, captured_request, completion_response)`:
    /// - `captured_request` holds the most recent `PredictEditsBody` posted to
    ///   `/predict_edits/v2`. It is `None` until the first such request.
    /// - `completion_response` is the excerpt the fake server returns. Callers may overwrite
    ///   it before requesting a prediction. It initially holds a small `main.rs` excerpt
    ///   whose editable region contains `hello world`.
    async fn make_test_zeta(
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        Arc<Mutex<Option<PredictEditsBody>>>,
        Arc<Mutex<String>>,
    ) {
        // Default model output, served until a test replaces `completion_response`.
        let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
        };
        let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
        let completion_response: Arc<Mutex<String>> =
            Arc::new(Mutex::new(default_response.to_string()));
        let http_client = FakeHttpClient::create({
            // Clones moved into the handler; the originals are returned to the caller.
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            move |req| {
                // Fresh clones per request so each `async move` future owns its own handles.
                let captured_request = captured_request.clone();
                let completion_response = completion_response.clone();
                async move {
                    match (req.method(), req.uri().path()) {
                        (&Method::POST, "/client/llm_tokens") => {
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&CreateLlmTokenResponse {
                                        token: LlmToken("the-llm-token".to_string()),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        (&Method::POST, "/predict_edits/v2") => {
                            // Record the request so tests can assert on its fields
                            // (e.g. `can_collect_data`).
                            let mut request_body = String::new();
                            req.into_body().read_to_string(&mut request_body).await?;
                            *captured_request.lock() =
                                Some(serde_json::from_str(&request_body).unwrap());
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&PredictEditsResponse {
                                        request_id: Uuid::new_v4().to_string(),
                                        output_excerpt: completion_response.lock().clone(),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        _ => Ok(http_client::Response::builder()
                            .status(404)
                            .body("Not Found".into())
                            .unwrap()),
                    }
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Binding to `_server` (rather than `_`) keeps the fake server alive for the rest of
        // this function. It is still dropped when `make_test_zeta` returns.
        // NOTE(review): confirm callers do not depend on the server afterwards.
        // `42` is presumably the fake user id.
        let _server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new(|cx| {
            let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);

            // Attach a license-detection watcher to every worktree up front.
            // The worktrees are collected first, presumably so the read borrow of `project`
            // ends before `cx` is passed mutably to `LicenseDetectionWatcher::new`.
            let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
            for worktree in worktrees {
                let worktree_id = worktree.read(cx).id();
                zeta.license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
            }

            zeta
        });

        (zeta, captured_request, completion_response)
    }
2373
2374 fn to_completion_edits(
2375 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2376 buffer: &Entity<Buffer>,
2377 cx: &App,
2378 ) -> Vec<(Range<Anchor>, Arc<str>)> {
2379 let buffer = buffer.read(cx);
2380 iterator
2381 .into_iter()
2382 .map(|(range, text)| {
2383 (
2384 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2385 text,
2386 )
2387 })
2388 .collect()
2389 }
2390
2391 fn from_completion_edits(
2392 editor_edits: &[(Range<Anchor>, Arc<str>)],
2393 buffer: &Entity<Buffer>,
2394 cx: &App,
2395 ) -> Vec<(Range<usize>, Arc<str>)> {
2396 let buffer = buffer.read(cx);
2397 editor_edits
2398 .iter()
2399 .map(|(range, text)| {
2400 (
2401 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2402 text.clone(),
2403 )
2404 })
2405 .collect()
2406 }
2407
    /// Initializes test logging via `zlog::init_test`.
    ///
    /// Marked `#[ctor::ctor]` so it runs automatically when the test binary loads, rather
    /// than being called from each test.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2412}