1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::{Dismissable, KEY_VALUE_STORE};
11use db::smol::stream::StreamExt as _;
12use edit_prediction::DataCollectionState;
13use futures::channel::mpsc;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16pub use rate_completion_modal::*;
17
18use anyhow::{Context as _, Result, anyhow};
19use arrayvec::ArrayVec;
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
23 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
24 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
25 ZED_VERSION_HEADER_NAME,
26};
27use collections::{HashMap, HashSet, VecDeque};
28use futures::AsyncReadExt;
29use gpui::{
30 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
31 SharedString, Subscription, Task, actions,
32};
33use http_client::{AsyncBody, HttpClient, Method, Request, Response};
34use input_excerpt::excerpt_for_cursor_position;
35use language::{
36 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
37};
38use language_model::{LlmApiToken, RefreshLlmTokenListener};
39use project::{Project, ProjectPath};
40use release_channel::AppVersion;
41use settings::WorktreeId;
42use std::collections::hash_map;
43use std::mem;
44use std::str::FromStr;
45use std::{
46 cmp,
47 fmt::Write,
48 future::Future,
49 ops::Range,
50 path::Path,
51 rc::Rc,
52 sync::Arc,
53 time::{Duration, Instant},
54};
55use telemetry_events::EditPredictionRating;
56use thiserror::Error;
57use util::ResultExt;
58use util::rel_path::RelPath;
59use uuid::Uuid;
60use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
61use worktree::Worktree;
62
/// Cursor-position marker; `Zeta::parse_edits` strips every occurrence from the model output.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; the model output may contain at most one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region; the model output must contain exactly one.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region; the model output must contain exactly one.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer within this interval may be merged into one event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the persisted data-collection choice (`"true"` / `"false"`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): `MAX_CONTEXT_TOKENS` and `MAX_REWRITE_TOKENS` are not referenced in this part of
// the file; presumably they are token budgets for the excerpt/context code — confirm.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget passed to `prompt_for_events_impl` when rendering the event history.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track per project; when reached, the oldest half is dropped.
const MAX_EVENT_COUNT: usize = 16;
76
// Declares the actions in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
84
85#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
86pub struct EditPredictionId(Uuid);
87
88impl From<EditPredictionId> for gpui::ElementId {
89 fn from(value: EditPredictionId) -> Self {
90 gpui::ElementId::Uuid(value.0)
91 }
92}
93
94impl std::fmt::Display for EditPredictionId {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 write!(f, "{}", self.0)
97 }
98}
99
100struct ZedPredictUpsell;
101
102impl Dismissable for ZedPredictUpsell {
103 const KEY: &'static str = "dismissed-edit-predict-upsell";
104
105 fn dismissed() -> bool {
106 // To make this backwards compatible with older versions of Zed, we
107 // check if the user has seen the previous Edit Prediction Onboarding
108 // before, by checking the data collection choice which was written to
109 // the database once the user clicked on "Accept and Enable"
110 if KEY_VALUE_STORE
111 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
112 .log_err()
113 .is_some_and(|s| s.is_some())
114 {
115 return true;
116 }
117
118 KEY_VALUE_STORE
119 .read_kvp(Self::KEY)
120 .log_err()
121 .is_some_and(|s| s.is_some())
122 }
123}
124
125pub fn should_show_upsell_modal() -> bool {
126 !ZedPredictUpsell::dismissed()
127}
128
/// Global handle to the shared `Zeta` entity, installed by `Zeta::register` and read by
/// `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
133
134#[derive(Clone)]
135pub struct EditPrediction {
136 id: EditPredictionId,
137 path: Arc<Path>,
138 excerpt_range: Range<usize>,
139 cursor_offset: usize,
140 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
141 snapshot: BufferSnapshot,
142 edit_preview: EditPreview,
143 input_outline: Arc<str>,
144 input_events: Arc<str>,
145 input_excerpt: Arc<str>,
146 output_excerpt: Arc<str>,
147 buffer_snapshotted_at: Instant,
148 response_received_at: Instant,
149}
150
151impl EditPrediction {
152 fn latency(&self) -> Duration {
153 self.response_received_at
154 .duration_since(self.buffer_snapshotted_at)
155 }
156
157 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
158 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
159 }
160}
161
162impl std::fmt::Debug for EditPrediction {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("EditPrediction")
165 .field("id", &self.id)
166 .field("path", &self.path)
167 .field("edits", &self.edits)
168 .finish_non_exhaustive()
169 }
170}
171
/// Edit prediction ("Zeta") state shared by all projects.
pub struct Zeta {
    /// Per-project state, keyed by project entity id; removed when the project is released.
    projects: HashMap<EntityId, ZetaProject>,
    client: Arc<Client>,
    /// Recently shown predictions, most recent first (capped by `completion_shown`).
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: DataCollectionChoice,
    /// Rejections waiting to be reported to the server by `reject_edit_predictions`.
    discarded_completions: Vec<EditPredictionRejection>,
    /// LLM token, refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// License watchers per worktree, consulted by `is_file_open_source`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Pending debounce before the next rejection flush; replaced by each `discard_completion`.
    discard_completions_debounce_task: Option<Task<()>>,
    /// Signals the loop spawned in `new` to flush `discarded_completions`.
    discard_completions_tx: mpsc::UnboundedSender<()>,
}

/// Per-project edit prediction state.
struct ZetaProject {
    /// Recent buffer changes, oldest first (bounded by `MAX_EVENT_COUNT`).
    events: VecDeque<Event>,
    /// Buffers being tracked for changes, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
193
194impl Zeta {
195 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
196 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
197 }
198
199 pub fn register(
200 worktree: Option<Entity<Worktree>>,
201 client: Arc<Client>,
202 user_store: Entity<UserStore>,
203 cx: &mut App,
204 ) -> Entity<Self> {
205 let this = Self::global(cx).unwrap_or_else(|| {
206 let entity = cx.new(|cx| Self::new(client, user_store, cx));
207 cx.set_global(ZetaGlobal(entity.clone()));
208 entity
209 });
210
211 this.update(cx, move |this, cx| {
212 if let Some(worktree) = worktree {
213 let worktree_id = worktree.read(cx).id();
214 this.license_detection_watchers
215 .entry(worktree_id)
216 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
217 }
218 });
219
220 this
221 }
222
223 pub fn clear_history(&mut self) {
224 for zeta_project in self.projects.values_mut() {
225 zeta_project.events.clear();
226 }
227 }
228
229 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
230 self.user_store.read(cx).edit_prediction_usage()
231 }
232
    /// Creates the `Zeta` state.
    ///
    /// Loads the persisted data-collection choice. Spawns a background loop that flushes
    /// queued rejections (`reject_edit_predictions`) each time `discard_completions_tx` is
    /// signalled. Subscribes to `RefreshLlmTokenListener` so the LLM token is refreshed when
    /// requested.
    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        // Each `()` received triggers one flush. Errors from an individual flush are only
        // logged. The loop stops when the channel closes or when `update` fails (presumably
        // once this entity has been released).
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            shown_completions: VecDeque::new(),
            rated_completions: HashSet::default(),
            discarded_completions: Vec::new(),
            discard_completions_debounce_task: None,
            discard_completions_tx: reject_tx,
            data_collection_choice,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                // Refresh the token in the background on every listener event; failures are
                // logged.
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            license_detection_watchers: HashMap::default(),
            user_store,
        }
    }
274
275 fn get_or_init_zeta_project(
276 &mut self,
277 project: &Entity<Project>,
278 cx: &mut Context<Self>,
279 ) -> &mut ZetaProject {
280 let project_id = project.entity_id();
281 match self.projects.entry(project_id) {
282 hash_map::Entry::Occupied(entry) => entry.into_mut(),
283 hash_map::Entry::Vacant(entry) => {
284 cx.observe_release(project, move |this, _, _cx| {
285 this.projects.remove(&project_id);
286 })
287 .detach();
288 entry.insert(ZetaProject {
289 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
290 registered_buffers: HashMap::default(),
291 })
292 }
293 }
294 }
295
296 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
297 let events = &mut zeta_project.events;
298
299 if let Some(Event::BufferChange {
300 new_snapshot: last_new_snapshot,
301 timestamp: last_timestamp,
302 ..
303 }) = events.back_mut()
304 {
305 // Coalesce edits for the same buffer when they happen one after the other.
306 let Event::BufferChange {
307 old_snapshot,
308 new_snapshot,
309 timestamp,
310 } = &event;
311
312 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
313 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
314 && old_snapshot.version == last_new_snapshot.version
315 {
316 *last_new_snapshot = new_snapshot.clone();
317 *last_timestamp = *timestamp;
318 return;
319 }
320 }
321
322 if events.len() >= MAX_EVENT_COUNT {
323 // These are halved instead of popping to improve prompt caching.
324 events.drain(..MAX_EVENT_COUNT / 2);
325 }
326
327 events.push_back(event);
328 }
329
330 pub fn register_buffer(
331 &mut self,
332 buffer: &Entity<Buffer>,
333 project: &Entity<Project>,
334 cx: &mut Context<Self>,
335 ) {
336 let zeta_project = self.get_or_init_zeta_project(project, cx);
337 Self::register_buffer_impl(zeta_project, buffer, project, cx);
338 }
339
    /// Returns the tracking entry for `buffer` in `zeta_project`, creating it on first use.
    ///
    /// A new entry stores the buffer's current snapshot and holds two subscriptions:
    /// `Edited` events call `report_changes_for_buffer` (only while the project can still be
    /// upgraded), and releasing the buffer removes its entry from the project's state.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record an event on every edit. The project is held through a weak
                        // handle, so this subscription does not keep it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Stop tracking the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
377
    /// Builds an edit prediction request for `cursor` in `buffer` and resolves it.
    ///
    /// Before spawning, this records any pending buffer change as an event, copies the
    /// project's event history, decides whether the file (and its git info) may be collected,
    /// and starts gathering prompt context. The returned task then:
    /// - sends the request through `perform_predict_edits` (injected so tests can fake the
    ///   network);
    /// - stores any returned usage in the user store;
    /// - converts the response into an `EditPrediction`.
    ///
    /// Resolves to `Ok(None)` when `process_completion_response` does (the predicted edits no
    /// longer apply to the buffer).
    ///
    /// # Errors
    ///
    /// Propagates failures from context gathering, the request and response processing.
    /// On a `ZedUpdateRequiredError`, it first sets `update_required` and shows an
    /// "Update Zed" notification, then returns the error.
    fn request_completion_impl<F, R>(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        let buffer_snapshotted_at = Instant::now();
        // Record any not-yet-reported edit as an event and use the resulting snapshot.
        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
        let zeta = cx.entity();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let mut events = Vec::with_capacity(zeta_project.events.len());
        events.extend(zeta_project.events.iter().cloned());
        let events = Arc::new(events);

        // Git info is only looked up for files whose data may be collected.
        let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
            let can_collect_file = self.can_collect_file(file, cx);
            let git_info = if can_collect_file {
                git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
            } else {
                None
            };
            (git_info, can_collect_file)
        } else {
            (None, false)
        };

        // Buffers without a file are reported under the placeholder path "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().into_owned();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let prompt_for_events = {
            let events = events.clone();
            move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
        };
        let gather_task = gather_context(
            full_path_str,
            &snapshot,
            cursor_point,
            prompt_for_events,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                mut body,
                editable_range,
                included_events_count,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            // The prompt contains only the most recent `included_events_count` events.
            // NOTE(review): this slice panics if `included_events_count > events.len()`;
            // presumably `gather_context` never reports more events than it was given — confirm.
            let included_events = &events[events.len() - included_events_count..events.len()];
            // Data may be collected only if the file and the files of every included event are
            // collectable; git info is attached only in that case.
            body.can_collect_data = can_collect_file
                && this
                    .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                    .unwrap_or(false);
            if body.can_collect_data {
                body.git_info = git_info;
            }

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // An outdated client is surfaced to the user with a link to update.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests (3 of 256 possible values, about 1.2%)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
547
548 #[cfg(any(test, feature = "test-support"))]
549 pub fn fake_completion(
550 &mut self,
551 project: &Entity<Project>,
552 buffer: &Entity<Buffer>,
553 position: language::Anchor,
554 response: PredictEditsResponse,
555 cx: &mut Context<Self>,
556 ) -> Task<Result<Option<EditPrediction>>> {
557 self.request_completion_impl(project, buffer, position, cx, |_params| {
558 std::future::ready(Ok((response, None)))
559 })
560 }
561
562 pub fn request_completion(
563 &mut self,
564 project: &Entity<Project>,
565 buffer: &Entity<Buffer>,
566 position: language::Anchor,
567 cx: &mut Context<Self>,
568 ) -> Task<Result<Option<EditPrediction>>> {
569 self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
570 }
571
    /// Sends `params.body` to the predict-edits endpoint and parses the JSON response.
    ///
    /// The endpoint is `/predict_edits/v2` on the Zed LLM service. The `ZED_PREDICT_EDITS_URL`
    /// environment variable overrides it. If a non-success response carries
    /// `EXPIRED_LLM_TOKEN_HEADER_NAME`, the token is refreshed and the request is retried once.
    /// Usage is parsed from the response headers and is `None` when absent or invalid.
    ///
    /// # Errors
    ///
    /// Returns `ZedUpdateRequiredError` when the server's minimum required version is newer than
    /// `app_version`. Returns an error with the status and body for other non-success
    /// responses. Token, URL, serialization, network and JSON parsing errors are propagated.
    pub fn perform_predict_edits(
        params: PerformPredictEditsParams,
    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
        async move {
            let PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
                ..
            } = params;

            let http_client = client.http_client();
            let mut token = llm_token.acquire(&client).await?;
            let mut did_retry = false;

            loop {
                let request_builder = http_client::Request::builder().method(Method::POST);
                let request_builder =
                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                        request_builder.uri(predict_edits_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/v2", &[])?
                                .as_ref(),
                        )
                    };
                let request = request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(serde_json::to_string(&body)?.into())?;

                let mut response = http_client.send(request).await?;

                // The version check runs before the status check, so it applies to any response.
                if let Some(minimum_required_version) = response
                    .headers()
                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                {
                    anyhow::ensure!(
                        app_version >= minimum_required_version,
                        ZedUpdateRequiredError {
                            minimum_version: minimum_required_version
                        }
                    );
                }

                if response.status().is_success() {
                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Ok((serde_json::from_str(&body)?, usage));
                } else if !did_retry
                    && response
                        .headers()
                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                        .is_some()
                {
                    // Expired token: refresh it and retry exactly once.
                    did_retry = true;
                    token = llm_token.refresh(&client).await?;
                } else {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "error predicting edits.\nStatus: {:?}\nBody: {}",
                        response.status(),
                        body
                    );
                }
            }
        }
    }
647
    /// Reports to the server that the prediction `request_id` was accepted.
    ///
    /// Posts to `/predict_edits/accept` on the Zed LLM service; the
    /// `ZED_ACCEPT_PREDICTION_URL` environment variable overrides the URL. The request goes
    /// through `llm_token_retry`. On success, any usage in the response headers is stored in
    /// the user store.
    ///
    /// # Errors
    ///
    /// Returns `ZedUpdateRequiredError` when the server's minimum required version is newer than
    /// the running one. Returns an error with the status and body for non-success responses.
    /// Also propagates request-building, network and entity-update errors.
    fn accept_edit_prediction(
        &mut self,
        request_id: EditPredictionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let http_client = client.http_client();
            let mut response = llm_token_retry(&llm_token, &client, |token| {
                let request_builder = http_client::Request::builder().method(Method::POST);
                let request_builder =
                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                        request_builder.uri(accept_prediction_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/accept", &[])?
                                .as_ref(),
                        )
                    };
                Ok(request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(
                        serde_json::to_string(&AcceptEditPredictionBody {
                            request_id: request_id.0.to_string(),
                        })?
                        .into(),
                    )?)
            })
            .await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                && app_version < minimum_required_version
            {
                return Err(anyhow!(ZedUpdateRequiredError {
                    minimum_version: minimum_required_version
                }));
            }

            if response.status().is_success() {
                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })?;
                }

                Ok(())
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                Err(anyhow!(
                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                ))
            }
        })
    }
715
716 fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
717 let client = self.client.clone();
718 let llm_token = self.llm_token.clone();
719 let app_version = AppVersion::global(cx);
720 let last_rejection = self.discarded_completions.last().cloned();
721 let body = serde_json::to_string(&RejectEditPredictionsBody {
722 rejections: self.discarded_completions.clone(),
723 })
724 .ok();
725
726 let Some(last_rejection) = last_rejection else {
727 return Task::ready(anyhow::Ok(()));
728 };
729
730 cx.spawn(async move |this, cx| {
731 let http_client = client.http_client();
732 let mut response = llm_token_retry(&llm_token, &client, |token| {
733 let request_builder = http_client::Request::builder().method(Method::POST);
734 let request_builder = request_builder.uri(
735 http_client
736 .build_zed_llm_url("/predict_edits/reject", &[])?
737 .as_ref(),
738 );
739 Ok(request_builder
740 .header("Content-Type", "application/json")
741 .header("Authorization", format!("Bearer {}", token))
742 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
743 .body(
744 body.as_ref()
745 .context("failed to serialize body")?
746 .clone()
747 .into(),
748 )?)
749 })
750 .await?;
751
752 if let Some(minimum_required_version) = response
753 .headers()
754 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
755 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
756 && app_version < minimum_required_version
757 {
758 return Err(anyhow!(ZedUpdateRequiredError {
759 minimum_version: minimum_required_version
760 }));
761 }
762
763 if response.status().is_success() {
764 this.update(cx, |this, _| {
765 if let Some(ix) = this
766 .discarded_completions
767 .iter()
768 .position(|rejection| rejection.request_id == last_rejection.request_id)
769 {
770 this.discarded_completions.drain(..ix + 1);
771 }
772 })
773 } else {
774 let mut body = String::new();
775 response.body_mut().read_to_string(&mut body).await?;
776 Err(anyhow!(
777 "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
778 response.status(),
779 body
780 ))
781 }
782 })
783 }
784
    /// Converts a raw prediction response into an `EditPrediction`.
    ///
    /// The edits are parsed from the output excerpt on a background thread against the
    /// request-time `snapshot`. They are then interpolated onto the buffer's current snapshot,
    /// and an edit preview is computed for the result.
    ///
    /// Resolves to `Ok(None)` when the edits cannot be interpolated onto the current buffer
    /// contents (`interpolate_edits` returned `None`).
    ///
    /// # Errors
    ///
    /// Fails when the output excerpt cannot be parsed (`parse_edits`) or when `read_with` on
    /// the buffer fails. Also fails when the response's request id is not a valid UUID.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Diffing may be expensive, so it runs off the main thread.
            let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // The buffer may have changed since the request was made; re-express the edits
            // against its current snapshot, or give up if they no longer apply.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                move |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
                            .into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                // NOTE(review): stamped after parsing and the edit preview, so `latency()` also
                // includes this processing time.
                response_received_at: Instant::now(),
            }))
        })
    }
849
850 fn parse_edits(
851 output_excerpt: Arc<str>,
852 editable_range: Range<usize>,
853 snapshot: &BufferSnapshot,
854 ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
855 let content = output_excerpt.replace(CURSOR_MARKER, "");
856
857 let start_markers = content
858 .match_indices(EDITABLE_REGION_START_MARKER)
859 .collect::<Vec<_>>();
860 anyhow::ensure!(
861 start_markers.len() == 1,
862 "expected exactly one start marker, found {}",
863 start_markers.len()
864 );
865
866 let end_markers = content
867 .match_indices(EDITABLE_REGION_END_MARKER)
868 .collect::<Vec<_>>();
869 anyhow::ensure!(
870 end_markers.len() == 1,
871 "expected exactly one end marker, found {}",
872 end_markers.len()
873 );
874
875 let sof_markers = content
876 .match_indices(START_OF_FILE_MARKER)
877 .collect::<Vec<_>>();
878 anyhow::ensure!(
879 sof_markers.len() <= 1,
880 "expected at most one start-of-file marker, found {}",
881 sof_markers.len()
882 );
883
884 let codefence_start = start_markers[0].0;
885 let content = &content[codefence_start..];
886
887 let newline_ix = content.find('\n').context("could not find newline")?;
888 let content = &content[newline_ix + 1..];
889
890 let codefence_end = content
891 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
892 .context("could not find end marker")?;
893 let new_text = &content[..codefence_end];
894
895 let old_text = snapshot
896 .text_for_range(editable_range.clone())
897 .collect::<String>();
898
899 Ok(Self::compute_edits(
900 old_text,
901 new_text,
902 editable_range.start,
903 snapshot,
904 ))
905 }
906
    /// Diffs `old_text` against `new_text` and returns the changes as anchored buffer edits.
    ///
    /// `old_text` is expected to be the text of `snapshot` starting at byte `offset`, which is
    /// how `parse_edits` calls it. Each diff range is shifted by `offset` into buffer
    /// coordinates. Each hunk is then narrowed by trimming the characters its old and new text
    /// share at the start and at the end. Insertions (empty ranges) use a single
    /// `anchor_after` anchor. Replacements span `anchor_after(start)..anchor_before(end)`.
    pub fn compute_edits(
        old_text: String,
        new_text: &str,
        offset: usize,
        snapshot: &BufferSnapshot,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        text_diff(&old_text, new_text)
            .into_iter()
            .map(|(mut old_range, new_text)| {
                // Translate the hunk from `old_text` coordinates into buffer offsets.
                old_range.start += offset;
                old_range.end += offset;

                // Trim the shared leading text (byte length via `common_prefix`).
                let prefix_len = common_prefix(
                    snapshot.chars_for_range(old_range.clone()),
                    new_text.chars(),
                );
                old_range.start += prefix_len;

                // Trim the shared trailing text. It is measured on what remains after the prefix,
                // so prefix and suffix never overlap.
                let suffix_len = common_prefix(
                    snapshot.reversed_chars_for_range(old_range.clone()),
                    new_text[prefix_len..].chars().rev(),
                );
                old_range.end = old_range.end.saturating_sub(suffix_len);

                let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
                let range = if old_range.is_empty() {
                    let anchor = snapshot.anchor_after(old_range.start);
                    anchor..anchor
                } else {
                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
                };
                (range, new_text)
            })
            .collect()
    }
942
943 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
944 self.rated_completions.contains(&completion_id)
945 }
946
947 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
948 self.shown_completions.push_front(completion.clone());
949 if self.shown_completions.len() > 50 {
950 let completion = self.shown_completions.pop_back().unwrap();
951 self.rated_completions.remove(&completion.id);
952 }
953 cx.notify();
954 }
955
    /// Marks `completion` as rated and reports the rating as telemetry.
    ///
    /// Emits an "Edit Prediction Rated" event carrying the `rating`, the free-form `feedback`,
    /// and the prediction's inputs and output. Pending telemetry events are then flushed.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Send ratings right away instead of waiting for the next batch.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
976
977 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
978 self.shown_completions.iter()
979 }
980
981 pub fn shown_completions_len(&self) -> usize {
982 self.shown_completions.len()
983 }
984
985 fn report_changes_for_buffer(
986 &mut self,
987 buffer: &Entity<Buffer>,
988 project: &Entity<Project>,
989 cx: &mut Context<Self>,
990 ) -> BufferSnapshot {
991 let zeta_project = self.get_or_init_zeta_project(project, cx);
992 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
993
994 let new_snapshot = buffer.read(cx).snapshot();
995 if new_snapshot.version != registered_buffer.snapshot.version {
996 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
997 Self::push_event(
998 zeta_project,
999 Event::BufferChange {
1000 old_snapshot,
1001 new_snapshot: new_snapshot.clone(),
1002 timestamp: Instant::now(),
1003 },
1004 );
1005 }
1006
1007 new_snapshot
1008 }
1009
1010 fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1011 self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
1012 }
1013
1014 fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
1015 if !self.data_collection_choice.is_enabled() {
1016 return false;
1017 }
1018 let mut last_checked_file = None;
1019 for event in events {
1020 match event {
1021 Event::BufferChange {
1022 old_snapshot,
1023 new_snapshot,
1024 ..
1025 } => {
1026 if let Some(old_file) = old_snapshot.file()
1027 && let Some(new_file) = new_snapshot.file()
1028 {
1029 if let Some(last_checked_file) = last_checked_file
1030 && Arc::ptr_eq(last_checked_file, old_file)
1031 && Arc::ptr_eq(last_checked_file, new_file)
1032 {
1033 continue;
1034 }
1035 if !self.can_collect_file(old_file, cx) {
1036 return false;
1037 }
1038 if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
1039 {
1040 return false;
1041 }
1042 last_checked_file = Some(new_file);
1043 } else {
1044 return false;
1045 }
1046 }
1047 }
1048 }
1049 true
1050 }
1051
1052 fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1053 if !file.is_local() || file.is_private() {
1054 return false;
1055 }
1056 self.license_detection_watchers
1057 .get(&file.worktree_id(cx))
1058 .is_some_and(|watcher| watcher.is_project_open_source())
1059 }
1060
1061 fn load_data_collection_choice() -> DataCollectionChoice {
1062 let choice = KEY_VALUE_STORE
1063 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1064 .log_err()
1065 .flatten();
1066
1067 match choice.as_deref() {
1068 Some("true") => DataCollectionChoice::Enabled,
1069 Some("false") => DataCollectionChoice::Disabled,
1070 Some(_) => {
1071 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1072 DataCollectionChoice::NotAnswered
1073 }
1074 None => DataCollectionChoice::NotAnswered,
1075 }
1076 }
1077
1078 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1079 self.data_collection_choice = self.data_collection_choice.toggle();
1080 let new_choice = self.data_collection_choice;
1081 db::write_and_log(cx, move || {
1082 KEY_VALUE_STORE.write_kvp(
1083 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1084 new_choice.is_enabled().to_string(),
1085 )
1086 });
1087 }
1088
1089 fn discard_completion(
1090 &mut self,
1091 completion_id: EditPredictionId,
1092 was_shown: bool,
1093 cx: &mut Context<Self>,
1094 ) {
1095 self.discarded_completions.push(EditPredictionRejection {
1096 request_id: completion_id.to_string(),
1097 was_shown,
1098 });
1099
1100 let reached_request_limit =
1101 self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
1102 let discard_completions_tx = self.discard_completions_tx.clone();
1103 self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
1104 const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
1105 if !reached_request_limit {
1106 cx.background_executor()
1107 .timer(DISCARD_COMPLETIONS_DEBOUNCE)
1108 .await;
1109 }
1110 discard_completions_tx.unbounded_send(()).log_err();
1111 }));
1112 }
1113}
1114
/// Everything needed to issue a single predict-edits request.
pub struct PerformPredictEditsParams {
    /// Client used to reach the prediction service.
    pub client: Arc<Client>,
    /// LLM API token used to authenticate the request.
    pub llm_token: LlmApiToken,
    /// Version of the running Zed build (presumably reported to the server; confirm at the
    /// call site).
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1121
/// Error indicating that edit predictions require a newer Zed version than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version that can keep using edit predictions.
    minimum_version: SemanticVersion,
}
1129
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two character streams.
///
/// Comparison stops at the first mismatching pair or when either stream ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut prefix_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        prefix_len += left.len_utf8();
    }
    prefix_len
}
1136
1137fn git_info_for_file(
1138 project: &Entity<Project>,
1139 project_path: &ProjectPath,
1140 cx: &App,
1141) -> Option<PredictEditsGitInfo> {
1142 let git_store = project.read(cx).git_store().read(cx);
1143 if let Some((repository, _repo_path)) =
1144 git_store.repository_and_path_for_project_path(project_path, cx)
1145 {
1146 let repository = repository.read(cx);
1147 let head_sha = repository
1148 .head_commit
1149 .as_ref()
1150 .map(|head_commit| head_commit.sha.to_string());
1151 let remote_origin_url = repository.remote_origin_url.clone();
1152 let remote_upstream_url = repository.remote_upstream_url.clone();
1153 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1154 return None;
1155 }
1156 Some(PredictEditsGitInfo {
1157 head_sha,
1158 remote_origin_url,
1159 remote_upstream_url,
1160 })
1161 } else {
1162 None
1163 }
1164}
1165
/// Result of `gather_context`: a partially filled request body plus excerpt bookkeeping.
pub struct GatherContextOutput {
    /// Request body; `gather_context` fills in only the input events and the excerpt.
    pub body: PredictEditsBody,
    /// The excerpt's editable region, as byte offsets into the buffer snapshot.
    pub editable_range: Range<usize>,
    /// Number of events included in `body.input_events`, as reported by the events callback.
    pub included_events_count: usize,
}
1171
1172pub fn gather_context(
1173 full_path_str: String,
1174 snapshot: &BufferSnapshot,
1175 cursor_point: language::Point,
1176 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1177 cx: &App,
1178) -> Task<Result<GatherContextOutput>> {
1179 cx.background_spawn({
1180 let snapshot = snapshot.clone();
1181 async move {
1182 let input_excerpt = excerpt_for_cursor_position(
1183 cursor_point,
1184 &full_path_str,
1185 &snapshot,
1186 MAX_REWRITE_TOKENS,
1187 MAX_CONTEXT_TOKENS,
1188 );
1189 let (input_events, included_events_count) = prompt_for_events();
1190 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1191
1192 let body = PredictEditsBody {
1193 input_events,
1194 input_excerpt: input_excerpt.prompt,
1195 can_collect_data: false,
1196 diagnostic_groups: None,
1197 git_info: None,
1198 outline: None,
1199 speculated_output: None,
1200 };
1201
1202 Ok(GatherContextOutput {
1203 body,
1204 editable_range,
1205 included_events_count,
1206 })
1207 }
1208 })
1209}
1210
1211fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1212 let mut result = String::new();
1213 for (ix, event) in events.iter().rev().enumerate() {
1214 let event_string = event.to_prompt();
1215 let event_tokens = guess_token_count(event_string.len());
1216 if event_tokens > remaining_tokens {
1217 return (result, ix);
1218 }
1219
1220 if !result.is_empty() {
1221 result.insert_str(0, "\n\n");
1222 }
1223 result.insert_str(0, &event_string);
1224 remaining_tokens -= event_tokens;
1225 }
1226 return (result, events.len());
1227}
1228
/// A buffer Zeta has registered in order to watch it for edits.
struct RegisteredBuffer {
    /// Snapshot of the buffer as last recorded (presumably the "old" side of the next
    /// `Event::BufferChange`; confirm in `register_buffer`).
    snapshot: BufferSnapshot,
    /// Subscriptions held for as long as the buffer stays registered.
    _subscriptions: [gpui::Subscription; 2],
}
1233
1234#[derive(Clone)]
1235pub enum Event {
1236 BufferChange {
1237 old_snapshot: BufferSnapshot,
1238 new_snapshot: BufferSnapshot,
1239 timestamp: Instant,
1240 },
1241}
1242
1243impl Event {
1244 fn to_prompt(&self) -> String {
1245 match self {
1246 Event::BufferChange {
1247 old_snapshot,
1248 new_snapshot,
1249 ..
1250 } => {
1251 let mut prompt = String::new();
1252
1253 let old_path = old_snapshot
1254 .file()
1255 .map(|f| f.path().as_ref())
1256 .unwrap_or(RelPath::unix("untitled").unwrap());
1257 let new_path = new_snapshot
1258 .file()
1259 .map(|f| f.path().as_ref())
1260 .unwrap_or(RelPath::unix("untitled").unwrap());
1261 if old_path != new_path {
1262 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1263 }
1264
1265 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1266 if !diff.is_empty() {
1267 write!(
1268 prompt,
1269 "User edited {:?}:\n```diff\n{}\n```",
1270 new_path, diff
1271 )
1272 .unwrap();
1273 }
1274
1275 prompt
1276 }
1277 }
1278 }
1279}
1280
1281#[derive(Debug, Clone)]
1282struct CurrentEditPrediction {
1283 buffer_id: EntityId,
1284 completion: EditPrediction,
1285 was_shown: bool,
1286 was_accepted: bool,
1287}
1288
1289impl CurrentEditPrediction {
1290 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1291 if self.buffer_id != old_completion.buffer_id {
1292 return true;
1293 }
1294
1295 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1296 return true;
1297 };
1298 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1299 return false;
1300 };
1301
1302 if old_edits.len() == 1 && new_edits.len() == 1 {
1303 let (old_range, old_text) = &old_edits[0];
1304 let (new_range, new_text) = &new_edits[0];
1305 new_range == old_range && new_text.starts_with(old_text.as_ref())
1306 } else {
1307 true
1308 }
1309 }
1310}
1311
/// A prediction request that is still in flight.
struct PendingCompletion {
    /// Sequence number taken from `next_pending_completion_id` when the request started.
    id: usize,
    /// Task performing the request; it is stored here rather than detached.
    task: Task<()>,
}
1316
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Disabled | Self::NotAnswered => false,
            Self::Enabled => true,
        }
    }

    /// Returns `true` once the user has made either choice.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Disabled | Self::Enabled => true,
        }
    }

    /// Returns the opposite choice; an unanswered choice toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to [`DataCollectionChoice::Enabled`] and `false` to
    /// [`DataCollectionChoice::Disabled`].
    fn from(value: bool) -> Self {
        if value {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
1357
1358async fn llm_token_retry(
1359 llm_token: &LlmApiToken,
1360 client: &Arc<Client>,
1361 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1362) -> Result<Response<AsyncBody>> {
1363 let mut did_retry = false;
1364 let http_client = client.http_client();
1365 let mut token = llm_token.acquire(client).await?;
1366 loop {
1367 let request = build_request(token.clone())?;
1368 let response = http_client.send(request).await?;
1369
1370 if !did_retry
1371 && !response.status().is_success()
1372 && response
1373 .headers()
1374 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1375 .is_some()
1376 {
1377 did_retry = true;
1378 token = llm_token.refresh(client).await?;
1379 continue;
1380 }
1381
1382 return Ok(response);
1383 }
1384}
1385
1386pub struct ZetaEditPredictionProvider {
1387 zeta: Entity<Zeta>,
1388 singleton_buffer: Option<Entity<Buffer>>,
1389 pending_completions: ArrayVec<PendingCompletion, 2>,
1390 canceled_completions: HashMap<usize, Task<()>>,
1391 next_pending_completion_id: usize,
1392 current_completion: Option<CurrentEditPrediction>,
1393 last_request_timestamp: Instant,
1394 project: Entity<Project>,
1395}
1396
1397impl ZetaEditPredictionProvider {
1398 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1399
1400 pub fn new(
1401 zeta: Entity<Zeta>,
1402 project: Entity<Project>,
1403 singleton_buffer: Option<Entity<Buffer>>,
1404 cx: &mut Context<Self>,
1405 ) -> Self {
1406 cx.on_release(|this, cx| {
1407 this.take_current_edit_prediction(cx);
1408 })
1409 .detach();
1410
1411 Self {
1412 zeta,
1413 singleton_buffer,
1414 pending_completions: ArrayVec::new(),
1415 canceled_completions: HashMap::default(),
1416 next_pending_completion_id: 0,
1417 current_completion: None,
1418 last_request_timestamp: Instant::now(),
1419 project,
1420 }
1421 }
1422
1423 fn take_current_edit_prediction(&mut self, cx: &mut App) {
1424 if let Some(completion) = self.current_completion.take() {
1425 if !completion.was_accepted {
1426 self.zeta.update(cx, |zeta, cx| {
1427 zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
1428 });
1429 }
1430 }
1431 }
1432}
1433
/// Connects Zeta to the editor's generic edit-prediction interface.
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled and whether the singleton buffer's file
    /// belongs to an open-source project.
    ///
    /// Without a singleton buffer backed by a file, collection is reported as disabled and the
    /// project as not open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        if let Some(buffer) = &self.singleton_buffer
            && let Some(file) = buffer.read(cx).file()
        {
            let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
            if self.zeta.read(cx).data_collection_choice.is_enabled() {
                DataCollectionState::Enabled {
                    is_project_open_source,
                }
            } else {
                DataCollectionState::Disabled {
                    is_project_open_source,
                }
            }
        } else {
            return DataCollectionState::Disabled {
                is_project_open_source: false,
            };
        }
    }

    /// Flips the user's data-collection choice and persists it through Zeta.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.zeta
            .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
    }

    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// This provider does not restrict which buffers or positions get predictions.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// A refresh is in progress while any request is still pending.
    fn is_refreshing(&self, _cx: &App) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Starts a new prediction request for `buffer` at `position`.
    ///
    /// Nothing is requested when Zeta requires an update, when the account is too young or has
    /// overdue invoices, or when the current prediction still applies to the buffer. The
    /// `_debounce` hint is ignored; requests are throttled by `THROTTLE_TIMEOUT` instead.
    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        // NOTE(review): presumably set once the server demands a newer Zed version.
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction while it can still be interpolated onto the latest
        // buffer contents.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let last_request_timestamp = self.last_request_timestamp;

        let project = self.project.clone();
        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait until THROTTLE_TIMEOUT has passed since the previous request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(&project, &buffer, position, cx)
                })
            });

            // Wrap a successful prediction with this provider's bookkeeping. `Err` means the
            // provider was released or the request failed.
            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                            was_shown: false,
                            was_accepted: false,
                        })
                    })
                }
                Err(error) => Err(error),
            };

            // `true` when this result must not become the current prediction. This includes
            // the case where the provider entity is gone (`.ok().unwrap_or(true)`).
            let discarded = this
                .update(cx, |this, cx| {
                    // Remove this request from the queue. If it is not the oldest pending
                    // request, the oldest one is moved to `canceled_completions` (its task is
                    // kept) and the rest of the queue is dropped with the drain.
                    if this
                        .pending_completions
                        .first()
                        .is_some_and(|completion| completion.id == pending_completion_id)
                    {
                        this.pending_completions.remove(0);
                    } else {
                        if let Some(discarded) = this.pending_completions.drain(..).next() {
                            this.canceled_completions
                                .insert(discarded.id, discarded.task);
                        }
                    }

                    // A request canceled while in flight reports its prediction as discarded
                    // (never shown) and is otherwise ignored.
                    let canceled = this.canceled_completions.remove(&pending_completion_id);

                    if canceled.is_some()
                        && let Ok(Some(new_completion)) = &completion
                    {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.discard_completion(new_completion.completion.id, false, cx);
                        });
                        return true;
                    }

                    cx.notify();
                    false
                })
                .ok()
                .unwrap_or(true);

            if discarded {
                return;
            }

            // Failures are logged; an empty prediction (`None`) is silently ignored.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                return;
            };

            // Install the new prediction, either replacing the current one (which is reported
            // as discarded unless accepted) or filling an empty slot.
            this.update(cx, |this, cx| {
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.take_current_edit_prediction(cx);
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                task,
            });
        } else if self.pending_completions.len() == 2 {
            // The replaced request's task is parked in `canceled_completions` instead of being
            // dropped; when it finishes it discards its own result.
            if let Some(discarded) = self.pending_completions.pop() {
                self.canceled_completions
                    .insert(discarded.id, discarded.task);
            }
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Marks the current prediction as accepted, reports the acceptance to Zeta (the returned
    /// task is detached), and drops all pending requests.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion = self.current_completion.as_mut();
        if let Some(completion) = completion {
            completion.was_accepted = true;
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion.completion.id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops pending requests and clears the current prediction, reporting it as discarded
    /// unless it was accepted.
    fn discard(&mut self, cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.take_current_edit_prediction(cx);
    }

    /// Records that the current prediction was displayed.
    fn did_show(&mut self, _cx: &mut Context<Self>) {
        if let Some(current_completion) = self.current_completion.as_mut() {
            current_completion.was_shown = true;
        }
    }

    /// Returns the group of predicted edits nearest to the cursor row.
    ///
    /// The group is the edit closest to the cursor plus any neighbors chained to it with at most
    /// one row between consecutive edits. The prediction is cleared (and `None` returned) when
    /// it belongs to another buffer or no longer applies to the buffer's contents.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.take_current_edit_prediction(cx);
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.take_current_edit_prediction(cx);
            return None;
        };

        // Pick the edit whose start or end row is closest to the cursor row.
        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards over preceding edits that end within one row of the closest edit.
        // NOTE(review): these unsigned row subtractions assume `edits` are sorted and do not
        // overlap; confirm `interpolate` guarantees that.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards over following edits that start within one row of the closest edit.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction::Local {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1740
/// Assumed number of UTF-8 bytes per model token when budgeting prompt input.
///
/// The value is deliberately small, so token counts are over-estimated and limits are
/// respected with some margin.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens a string of `bytes` bytes occupies, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    // For unsigned integers, Euclidean division is the same as plain truncating division.
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1748
1749#[cfg(test)]
1750mod tests {
1751 use client::test::FakeServer;
1752 use clock::{FakeSystemClock, ReplicaId};
1753 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1754 use gpui::TestAppContext;
1755 use http_client::FakeHttpClient;
1756 use indoc::indoc;
1757 use language::Point;
1758 use parking_lot::Mutex;
1759 use serde_json::json;
1760 use settings::SettingsStore;
1761 use util::{path, rel_path::rel_path};
1762
1763 use super::*;
1764
1765 const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1766
1767 #[gpui::test]
1768 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1769 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1770 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1771 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1772 });
1773
1774 let edit_preview = cx
1775 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1776 .await;
1777
1778 let completion = EditPrediction {
1779 edits,
1780 edit_preview,
1781 path: Path::new("").into(),
1782 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1783 id: EditPredictionId(Uuid::new_v4()),
1784 excerpt_range: 0..0,
1785 cursor_offset: 0,
1786 input_outline: "".into(),
1787 input_events: "".into(),
1788 input_excerpt: "".into(),
1789 output_excerpt: "".into(),
1790 buffer_snapshotted_at: Instant::now(),
1791 response_received_at: Instant::now(),
1792 };
1793
1794 cx.update(|cx| {
1795 assert_eq!(
1796 from_completion_edits(
1797 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1798 &buffer,
1799 cx
1800 ),
1801 vec![(2..5, "REM".into()), (9..11, "".into())]
1802 );
1803
1804 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1805 assert_eq!(
1806 from_completion_edits(
1807 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1808 &buffer,
1809 cx
1810 ),
1811 vec![(2..2, "REM".into()), (6..8, "".into())]
1812 );
1813
1814 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1815 assert_eq!(
1816 from_completion_edits(
1817 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1818 &buffer,
1819 cx
1820 ),
1821 vec![(2..5, "REM".into()), (9..11, "".into())]
1822 );
1823
1824 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1825 assert_eq!(
1826 from_completion_edits(
1827 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1828 &buffer,
1829 cx
1830 ),
1831 vec![(3..3, "EM".into()), (7..9, "".into())]
1832 );
1833
1834 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1835 assert_eq!(
1836 from_completion_edits(
1837 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1838 &buffer,
1839 cx
1840 ),
1841 vec![(4..4, "M".into()), (8..10, "".into())]
1842 );
1843
1844 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1845 assert_eq!(
1846 from_completion_edits(
1847 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1848 &buffer,
1849 cx
1850 ),
1851 vec![(9..11, "".into())]
1852 );
1853
1854 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1855 assert_eq!(
1856 from_completion_edits(
1857 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1858 &buffer,
1859 cx
1860 ),
1861 vec![(4..4, "M".into()), (8..10, "".into())]
1862 );
1863
1864 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1865 assert_eq!(
1866 from_completion_edits(
1867 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1868 &buffer,
1869 cx
1870 ),
1871 vec![(4..4, "M".into())]
1872 );
1873
1874 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1875 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1876 })
1877 }
1878
1879 #[gpui::test]
1880 async fn test_clean_up_diff(cx: &mut TestAppContext) {
1881 init_test(cx);
1882
1883 assert_eq!(
1884 apply_edit_prediction(
1885 indoc! {"
1886 fn main() {
1887 let word_1 = \"lorem\";
1888 let range = word.len()..word.len();
1889 }
1890 "},
1891 indoc! {"
1892 <|editable_region_start|>
1893 fn main() {
1894 let word_1 = \"lorem\";
1895 let range = word_1.len()..word_1.len();
1896 }
1897
1898 <|editable_region_end|>
1899 "},
1900 cx,
1901 )
1902 .await,
1903 indoc! {"
1904 fn main() {
1905 let word_1 = \"lorem\";
1906 let range = word_1.len()..word_1.len();
1907 }
1908 "},
1909 );
1910
1911 assert_eq!(
1912 apply_edit_prediction(
1913 indoc! {"
1914 fn main() {
1915 let story = \"the quick\"
1916 }
1917 "},
1918 indoc! {"
1919 <|editable_region_start|>
1920 fn main() {
1921 let story = \"the quick brown fox jumps over the lazy dog\";
1922 }
1923
1924 <|editable_region_end|>
1925 "},
1926 cx,
1927 )
1928 .await,
1929 indoc! {"
1930 fn main() {
1931 let story = \"the quick brown fox jumps over the lazy dog\";
1932 }
1933 "},
1934 );
1935 }
1936
1937 #[gpui::test]
1938 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1939 init_test(cx);
1940
1941 let buffer_content = "lorem\n";
1942 let completion_response = indoc! {"
1943 ```animals.js
1944 <|start_of_file|>
1945 <|editable_region_start|>
1946 lorem
1947 ipsum
1948 <|editable_region_end|>
1949 ```"};
1950
1951 assert_eq!(
1952 apply_edit_prediction(buffer_content, completion_response, cx).await,
1953 "lorem\nipsum"
1954 );
1955 }
1956
1957 #[gpui::test]
1958 async fn test_can_collect_data(cx: &mut TestAppContext) {
1959 init_test(cx);
1960
1961 let fs = project::FakeFs::new(cx.executor());
1962 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1963 .await;
1964
1965 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1966 let buffer = project
1967 .update(cx, |project, cx| {
1968 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1969 })
1970 .await
1971 .unwrap();
1972
1973 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1974 zeta.update(cx, |zeta, _cx| {
1975 zeta.data_collection_choice = DataCollectionChoice::Enabled
1976 });
1977
1978 run_edit_prediction(&buffer, &project, &zeta, cx).await;
1979 assert_eq!(
1980 captured_request.lock().clone().unwrap().can_collect_data,
1981 true
1982 );
1983
1984 zeta.update(cx, |zeta, _cx| {
1985 zeta.data_collection_choice = DataCollectionChoice::Disabled
1986 });
1987
1988 run_edit_prediction(&buffer, &project, &zeta, cx).await;
1989 assert_eq!(
1990 captured_request.lock().clone().unwrap().can_collect_data,
1991 false
1992 );
1993 }
1994
1995 #[gpui::test]
1996 async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1997 init_test(cx);
1998
1999 let fs = project::FakeFs::new(cx.executor());
2000 let project = Project::test(fs.clone(), [], cx).await;
2001
2002 let buffer = cx.new(|_cx| {
2003 Buffer::remote(
2004 language::BufferId::new(1).unwrap(),
2005 ReplicaId::new(1),
2006 language::Capability::ReadWrite,
2007 "fn main() {\n println!(\"Hello\");\n}",
2008 )
2009 });
2010
2011 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2012 zeta.update(cx, |zeta, _cx| {
2013 zeta.data_collection_choice = DataCollectionChoice::Enabled
2014 });
2015
2016 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2017 assert_eq!(
2018 captured_request.lock().clone().unwrap().can_collect_data,
2019 false
2020 );
2021 }
2022
2023 #[gpui::test]
2024 async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
2025 init_test(cx);
2026
2027 let fs = project::FakeFs::new(cx.executor());
2028 fs.insert_tree(
2029 path!("/project"),
2030 json!({
2031 "LICENSE": BSD_0_TXT,
2032 ".env": "SECRET_KEY=secret"
2033 }),
2034 )
2035 .await;
2036
2037 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2038 let buffer = project
2039 .update(cx, |project, cx| {
2040 project.open_local_buffer("/project/.env", cx)
2041 })
2042 .await
2043 .unwrap();
2044
2045 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2046 zeta.update(cx, |zeta, _cx| {
2047 zeta.data_collection_choice = DataCollectionChoice::Enabled
2048 });
2049
2050 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2051 assert_eq!(
2052 captured_request.lock().clone().unwrap().can_collect_data,
2053 false
2054 );
2055 }
2056
2057 #[gpui::test]
2058 async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
2059 init_test(cx);
2060
2061 let fs = project::FakeFs::new(cx.executor());
2062 let project = Project::test(fs.clone(), [], cx).await;
2063 let buffer = cx.new(|cx| Buffer::local("", cx));
2064
2065 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2066 zeta.update(cx, |zeta, _cx| {
2067 zeta.data_collection_choice = DataCollectionChoice::Enabled
2068 });
2069
2070 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2071 assert_eq!(
2072 captured_request.lock().clone().unwrap().can_collect_data,
2073 false
2074 );
2075 }
2076
2077 #[gpui::test]
2078 async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
2079 init_test(cx);
2080
2081 let fs = project::FakeFs::new(cx.executor());
2082 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
2083 .await;
2084
2085 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2086 let buffer = project
2087 .update(cx, |project, cx| {
2088 project.open_local_buffer("/project/main.rs", cx)
2089 })
2090 .await
2091 .unwrap();
2092
2093 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2094 zeta.update(cx, |zeta, _cx| {
2095 zeta.data_collection_choice = DataCollectionChoice::Enabled
2096 });
2097
2098 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2099 assert_eq!(
2100 captured_request.lock().clone().unwrap().can_collect_data,
2101 false
2102 );
2103 }
2104
2105 #[gpui::test]
2106 async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
2107 init_test(cx);
2108
2109 let fs = project::FakeFs::new(cx.executor());
2110 fs.insert_tree(
2111 path!("/open_source_worktree"),
2112 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
2113 )
2114 .await;
2115 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
2116 .await;
2117
2118 let project = Project::test(
2119 fs.clone(),
2120 [
2121 path!("/open_source_worktree").as_ref(),
2122 path!("/closed_source_worktree").as_ref(),
2123 ],
2124 cx,
2125 )
2126 .await;
2127 let buffer = project
2128 .update(cx, |project, cx| {
2129 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
2130 })
2131 .await
2132 .unwrap();
2133
2134 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2135 zeta.update(cx, |zeta, _cx| {
2136 zeta.data_collection_choice = DataCollectionChoice::Enabled
2137 });
2138
2139 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2140 assert_eq!(
2141 captured_request.lock().clone().unwrap().can_collect_data,
2142 true
2143 );
2144
2145 let closed_source_file = project
2146 .update(cx, |project, cx| {
2147 let worktree2 = project
2148 .worktree_for_root_name("closed_source_worktree", cx)
2149 .unwrap();
2150 worktree2.update(cx, |worktree2, cx| {
2151 worktree2.load_file(rel_path("main.rs"), cx)
2152 })
2153 })
2154 .await
2155 .unwrap()
2156 .file;
2157
2158 buffer.update(cx, |buffer, cx| {
2159 buffer.file_updated(closed_source_file, cx);
2160 });
2161
2162 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2163 assert_eq!(
2164 captured_request.lock().clone().unwrap().can_collect_data,
2165 false
2166 );
2167 }
2168
2169 #[gpui::test]
2170 async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
2171 init_test(cx);
2172
2173 let fs = project::FakeFs::new(cx.executor());
2174 fs.insert_tree(
2175 path!("/worktree1"),
2176 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
2177 )
2178 .await;
2179 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
2180 .await;
2181
2182 let project = Project::test(
2183 fs.clone(),
2184 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
2185 cx,
2186 )
2187 .await;
2188 let buffer = project
2189 .update(cx, |project, cx| {
2190 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
2191 })
2192 .await
2193 .unwrap();
2194 let private_buffer = project
2195 .update(cx, |project, cx| {
2196 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
2197 })
2198 .await
2199 .unwrap();
2200
2201 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2202 zeta.update(cx, |zeta, _cx| {
2203 zeta.data_collection_choice = DataCollectionChoice::Enabled
2204 });
2205
2206 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2207 assert_eq!(
2208 captured_request.lock().clone().unwrap().can_collect_data,
2209 true
2210 );
2211
2212 // this has a side effect of registering the buffer to watch for edits
2213 run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
2214 assert_eq!(
2215 captured_request.lock().clone().unwrap().can_collect_data,
2216 false
2217 );
2218
2219 private_buffer.update(cx, |private_buffer, cx| {
2220 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
2221 });
2222
2223 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2224 assert_eq!(
2225 captured_request.lock().clone().unwrap().can_collect_data,
2226 false
2227 );
2228
2229 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
2230 // included
2231 buffer.update(cx, |buffer, cx| {
2232 buffer.edit(
2233 [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
2234 None,
2235 cx,
2236 );
2237 });
2238
2239 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2240 assert_eq!(
2241 captured_request.lock().clone().unwrap().can_collect_data,
2242 true
2243 );
2244 }
2245
2246 fn init_test(cx: &mut TestAppContext) {
2247 cx.update(|cx| {
2248 let settings_store = SettingsStore::test(cx);
2249 cx.set_global(settings_store);
2250 });
2251 }
2252
2253 async fn apply_edit_prediction(
2254 buffer_content: &str,
2255 completion_response: &str,
2256 cx: &mut TestAppContext,
2257 ) -> String {
2258 let fs = project::FakeFs::new(cx.executor());
2259 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2260 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2261 let (zeta, _, response) = make_test_zeta(&project, cx).await;
2262 *response.lock() = completion_response.to_string();
2263 let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
2264 buffer.update(cx, |buffer, cx| {
2265 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2266 });
2267 buffer.read_with(cx, |buffer, _| buffer.text())
2268 }
2269
2270 async fn run_edit_prediction(
2271 buffer: &Entity<Buffer>,
2272 project: &Entity<Project>,
2273 zeta: &Entity<Zeta>,
2274 cx: &mut TestAppContext,
2275 ) -> EditPrediction {
2276 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2277 zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
2278 cx.background_executor.run_until_parked();
2279 let completion_task = zeta.update(cx, |zeta, cx| {
2280 zeta.request_completion(&project, buffer, cursor, cx)
2281 });
2282 completion_task.await.unwrap().unwrap()
2283 }
2284
2285 async fn make_test_zeta(
2286 project: &Entity<Project>,
2287 cx: &mut TestAppContext,
2288 ) -> (
2289 Entity<Zeta>,
2290 Arc<Mutex<Option<PredictEditsBody>>>,
2291 Arc<Mutex<String>>,
2292 ) {
2293 let default_response = indoc! {"
2294 ```main.rs
2295 <|start_of_file|>
2296 <|editable_region_start|>
2297 hello world
2298 <|editable_region_end|>
2299 ```"
2300 };
2301 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2302 let completion_response: Arc<Mutex<String>> =
2303 Arc::new(Mutex::new(default_response.to_string()));
2304 let http_client = FakeHttpClient::create({
2305 let captured_request = captured_request.clone();
2306 let completion_response = completion_response.clone();
2307 move |req| {
2308 let captured_request = captured_request.clone();
2309 let completion_response = completion_response.clone();
2310 async move {
2311 match (req.method(), req.uri().path()) {
2312 (&Method::POST, "/client/llm_tokens") => {
2313 Ok(http_client::Response::builder()
2314 .status(200)
2315 .body(
2316 serde_json::to_string(&CreateLlmTokenResponse {
2317 token: LlmToken("the-llm-token".to_string()),
2318 })
2319 .unwrap()
2320 .into(),
2321 )
2322 .unwrap())
2323 }
2324 (&Method::POST, "/predict_edits/v2") => {
2325 let mut request_body = String::new();
2326 req.into_body().read_to_string(&mut request_body).await?;
2327 *captured_request.lock() =
2328 Some(serde_json::from_str(&request_body).unwrap());
2329 Ok(http_client::Response::builder()
2330 .status(200)
2331 .body(
2332 serde_json::to_string(&PredictEditsResponse {
2333 request_id: Uuid::new_v4().to_string(),
2334 output_excerpt: completion_response.lock().clone(),
2335 })
2336 .unwrap()
2337 .into(),
2338 )
2339 .unwrap())
2340 }
2341 _ => Ok(http_client::Response::builder()
2342 .status(404)
2343 .body("Not Found".into())
2344 .unwrap()),
2345 }
2346 }
2347 }
2348 });
2349
2350 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2351 cx.update(|cx| {
2352 RefreshLlmTokenListener::register(client.clone(), cx);
2353 });
2354 let _server = FakeServer::for_client(42, &client, cx).await;
2355
2356 let zeta = cx.new(|cx| {
2357 let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
2358
2359 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2360 for worktree in worktrees {
2361 let worktree_id = worktree.read(cx).id();
2362 zeta.license_detection_watchers
2363 .entry(worktree_id)
2364 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2365 }
2366
2367 zeta
2368 });
2369
2370 (zeta, captured_request, completion_response)
2371 }
2372
2373 fn to_completion_edits(
2374 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2375 buffer: &Entity<Buffer>,
2376 cx: &App,
2377 ) -> Vec<(Range<Anchor>, Arc<str>)> {
2378 let buffer = buffer.read(cx);
2379 iterator
2380 .into_iter()
2381 .map(|(range, text)| {
2382 (
2383 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2384 text,
2385 )
2386 })
2387 .collect()
2388 }
2389
2390 fn from_completion_edits(
2391 editor_edits: &[(Range<Anchor>, Arc<str>)],
2392 buffer: &Entity<Buffer>,
2393 cx: &App,
2394 ) -> Vec<(Range<usize>, Arc<str>)> {
2395 let buffer = buffer.read(cx);
2396 editor_edits
2397 .iter()
2398 .map(|(range, text)| {
2399 (
2400 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2401 text.clone(),
2402 )
2403 })
2404 .collect()
2405 }
2406
    /// Sets up test logging once per test binary.
    ///
    /// `#[ctor::ctor]` runs this function when the binary is loaded, before the test
    /// harness starts. Logging is therefore configured without each test having to
    /// call it.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2411}