1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::{Dismissable, KEY_VALUE_STORE};
11use db::smol::stream::StreamExt as _;
12use edit_prediction::DataCollectionState;
13use futures::channel::mpsc;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16pub use rate_completion_modal::*;
17
18use anyhow::{Context as _, Result, anyhow};
19use arrayvec::ArrayVec;
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
23 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
24 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
25 ZED_VERSION_HEADER_NAME,
26};
27use collections::{HashMap, HashSet, VecDeque};
28use futures::AsyncReadExt;
29use gpui::{
30 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
31 SharedString, Subscription, Task, actions,
32};
33use http_client::{AsyncBody, HttpClient, Method, Request, Response};
34use input_excerpt::excerpt_for_cursor_position;
35use language::{
36 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
37};
38use language_model::{LlmApiToken, RefreshLlmTokenListener};
39use project::{Project, ProjectPath};
40use release_channel::AppVersion;
41use settings::WorktreeId;
42use std::collections::hash_map;
43use std::mem;
44use std::str::FromStr;
45use std::{
46 cmp,
47 fmt::Write,
48 future::Future,
49 ops::Range,
50 path::Path,
51 rc::Rc,
52 sync::Arc,
53 time::{Duration, Instant},
54};
55use telemetry_events::EditPredictionRating;
56use thiserror::Error;
57use util::ResultExt;
58use util::rel_path::RelPath;
59use uuid::Uuid;
60use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
61use worktree::Worktree;
62
/// Marker for the user's cursor; it is stripped from the model output before
/// the edits are parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; a valid model output contains it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in the model output; must appear exactly once.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region in the model output; must appear exactly once.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// A buffer change that directly continues the previous event and arrives
/// within this interval of it is merged into that event instead of being
/// recorded separately.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key under which the data collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): the two limits below are not referenced in this part of the
// file; presumably they bound the excerpt built around the cursor — confirm
// against `input_excerpt`.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Limit handed to `prompt_for_events_impl` when the events prompt is built.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
///
/// When the history reaches this size, the oldest half is dropped before the
/// next event is recorded.
const MAX_EVENT_COUNT: usize = 16;
76
// Declares the `ClearHistory` action under `edit_prediction`.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
84
/// Identifier of a single edit prediction.
///
/// Wraps the UUID parsed from the `request_id` of the prediction response.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
87
88impl From<EditPredictionId> for gpui::ElementId {
89 fn from(value: EditPredictionId) -> Self {
90 gpui::ElementId::Uuid(value.0)
91 }
92}
93
94impl std::fmt::Display for EditPredictionId {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 write!(f, "{}", self.0)
97 }
98}
99
/// Marker type whose `Dismissable` implementation tracks whether the edit
/// prediction upsell has been dismissed.
struct ZedPredictUpsell;
101
102impl Dismissable for ZedPredictUpsell {
103 const KEY: &'static str = "dismissed-edit-predict-upsell";
104
105 fn dismissed() -> bool {
106 // To make this backwards compatible with older versions of Zed, we
107 // check if the user has seen the previous Edit Prediction Onboarding
108 // before, by checking the data collection choice which was written to
109 // the database once the user clicked on "Accept and Enable"
110 if KEY_VALUE_STORE
111 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
112 .log_err()
113 .is_some_and(|s| s.is_some())
114 {
115 return true;
116 }
117
118 KEY_VALUE_STORE
119 .read_kvp(Self::KEY)
120 .log_err()
121 .is_some_and(|s| s.is_some())
122 }
123}
124
125pub fn should_show_upsell_modal() -> bool {
126 !ZedPredictUpsell::dismissed()
127}
128
/// Global wrapper holding the shared `Zeta` entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
133
/// A single edit prediction together with the inputs that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier parsed from the response's request id.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Editable range of the buffer that the prediction was computed for.
    excerpt_range: Range<usize>,
    /// Cursor offset in the snapshot taken when the request was issued.
    cursor_offset: usize,
    /// Predicted edits as anchor ranges paired with their replacement text.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Buffer snapshot that `edits` apply to.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
    /// Outline sent with the request (empty when none was sent).
    input_outline: Arc<str>,
    /// Events prompt sent with the request.
    input_events: Arc<str>,
    /// Excerpt sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// When the buffer snapshot for the request was taken.
    buffer_snapshotted_at: Instant,
    /// When the prediction finished being assembled from the response.
    response_received_at: Instant,
}
150
151impl EditPrediction {
152 fn latency(&self) -> Duration {
153 self.response_received_at
154 .duration_since(self.buffer_snapshotted_at)
155 }
156
157 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
158 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
159 }
160}
161
162impl std::fmt::Debug for EditPrediction {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("EditPrediction")
165 .field("id", &self.id)
166 .field("path", &self.path)
167 .field("edits", &self.edits)
168 .finish_non_exhaustive()
169 }
170}
171
/// Shared state for edit predictions: per-project edit history, prediction
/// bookkeeping, and the handles used to talk to the prediction service.
pub struct Zeta {
    /// Per-project state, keyed by the project's entity id and removed when the
    /// project is released.
    projects: HashMap<EntityId, ZetaProject>,
    client: Arc<Client>,
    /// Recently shown predictions, newest first.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    /// The user's data collection choice, loaded from the key-value store.
    data_collection_choice: DataCollectionChoice,
    /// Rejections that have not yet been successfully reported.
    discarded_completions: Vec<EditPredictionRejection>,
    llm_token: LlmApiToken,
    /// Keeps the subscription alive that refreshes `llm_token` when the refresh
    /// listener emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license watcher per worktree passed to `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Most recent debounce task; replaced on every discarded prediction.
    discard_completions_debounce_task: Option<Task<()>>,
    /// Signals the background loop that reports rejections.
    discard_completions_tx: mpsc::UnboundedSender<()>,
}
188
/// Edit-prediction state tracked for a single project.
struct ZetaProject {
    /// Recent buffer-change events, oldest first.
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
193
194impl Zeta {
195 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
196 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
197 }
198
199 pub fn register(
200 worktree: Option<Entity<Worktree>>,
201 client: Arc<Client>,
202 user_store: Entity<UserStore>,
203 cx: &mut App,
204 ) -> Entity<Self> {
205 let this = Self::global(cx).unwrap_or_else(|| {
206 let entity = cx.new(|cx| Self::new(client, user_store, cx));
207 cx.set_global(ZetaGlobal(entity.clone()));
208 entity
209 });
210
211 this.update(cx, move |this, cx| {
212 if let Some(worktree) = worktree {
213 let worktree_id = worktree.read(cx).id();
214 this.license_detection_watchers
215 .entry(worktree_id)
216 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
217 }
218 });
219
220 this
221 }
222
223 pub fn clear_history(&mut self) {
224 for zeta_project in self.projects.values_mut() {
225 zeta_project.events.clear();
226 }
227 }
228
229 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
230 self.user_store.read(cx).edit_prediction_usage()
231 }
232
233 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
234 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
235 let data_collection_choice = Self::load_data_collection_choice();
236 let (reject_tx, mut reject_rx) = mpsc::unbounded();
237 cx.spawn(async move |this, cx| {
238 while let Some(()) = reject_rx.next().await {
239 this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
240 .await
241 .log_err();
242 }
243 anyhow::Ok(())
244 })
245 .detach();
246
247 Self {
248 projects: HashMap::default(),
249 client,
250 shown_completions: VecDeque::new(),
251 rated_completions: HashSet::default(),
252 discarded_completions: Vec::new(),
253 discard_completions_debounce_task: None,
254 discard_completions_tx: reject_tx,
255 data_collection_choice,
256 llm_token: LlmApiToken::default(),
257 _llm_token_subscription: cx.subscribe(
258 &refresh_llm_token_listener,
259 |this, _listener, _event, cx| {
260 let client = this.client.clone();
261 let llm_token = this.llm_token.clone();
262 cx.spawn(async move |_this, _cx| {
263 llm_token.refresh(&client).await?;
264 anyhow::Ok(())
265 })
266 .detach_and_log_err(cx);
267 },
268 ),
269 update_required: false,
270 license_detection_watchers: HashMap::default(),
271 user_store,
272 }
273 }
274
275 fn get_or_init_zeta_project(
276 &mut self,
277 project: &Entity<Project>,
278 cx: &mut Context<Self>,
279 ) -> &mut ZetaProject {
280 let project_id = project.entity_id();
281 match self.projects.entry(project_id) {
282 hash_map::Entry::Occupied(entry) => entry.into_mut(),
283 hash_map::Entry::Vacant(entry) => {
284 cx.observe_release(project, move |this, _, _cx| {
285 this.projects.remove(&project_id);
286 })
287 .detach();
288 entry.insert(ZetaProject {
289 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
290 registered_buffers: HashMap::default(),
291 })
292 }
293 }
294 }
295
296 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
297 let events = &mut zeta_project.events;
298
299 if let Some(Event::BufferChange {
300 new_snapshot: last_new_snapshot,
301 timestamp: last_timestamp,
302 ..
303 }) = events.back_mut()
304 {
305 // Coalesce edits for the same buffer when they happen one after the other.
306 let Event::BufferChange {
307 old_snapshot,
308 new_snapshot,
309 timestamp,
310 } = &event;
311
312 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
313 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
314 && old_snapshot.version == last_new_snapshot.version
315 {
316 *last_new_snapshot = new_snapshot.clone();
317 *last_timestamp = *timestamp;
318 return;
319 }
320 }
321
322 if events.len() >= MAX_EVENT_COUNT {
323 // These are halved instead of popping to improve prompt caching.
324 events.drain(..MAX_EVENT_COUNT / 2);
325 }
326
327 events.push_back(event);
328 }
329
330 pub fn register_buffer(
331 &mut self,
332 buffer: &Entity<Buffer>,
333 project: &Entity<Project>,
334 cx: &mut Context<Self>,
335 ) {
336 let zeta_project = self.get_or_init_zeta_project(project, cx);
337 Self::register_buffer_impl(zeta_project, buffer, project, cx);
338 }
339
340 fn register_buffer_impl<'a>(
341 zeta_project: &'a mut ZetaProject,
342 buffer: &Entity<Buffer>,
343 project: &Entity<Project>,
344 cx: &mut Context<Self>,
345 ) -> &'a mut RegisteredBuffer {
346 let buffer_id = buffer.entity_id();
347 match zeta_project.registered_buffers.entry(buffer_id) {
348 hash_map::Entry::Occupied(entry) => entry.into_mut(),
349 hash_map::Entry::Vacant(entry) => {
350 let snapshot = buffer.read(cx).snapshot();
351 let project_entity_id = project.entity_id();
352 entry.insert(RegisteredBuffer {
353 snapshot,
354 _subscriptions: [
355 cx.subscribe(buffer, {
356 let project = project.downgrade();
357 move |this, buffer, event, cx| {
358 if let language::BufferEvent::Edited = event
359 && let Some(project) = project.upgrade()
360 {
361 this.report_changes_for_buffer(&buffer, &project, cx);
362 }
363 }
364 }),
365 cx.observe_release(buffer, move |this, _buffer, _cx| {
366 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
367 else {
368 return;
369 };
370 zeta_project.registered_buffers.remove(&buffer_id);
371 }),
372 ],
373 })
374 }
375 }
376 }
377
/// Requests an edit prediction for `buffer` at `cursor`.
///
/// Snapshots the buffer (recording any pending change as an event), gathers
/// the request context, sends it through `perform_predict_edits`, and turns
/// the response into an `EditPrediction`.
///
/// `perform_predict_edits` is injected so callers can choose how the request
/// is executed (the real HTTP call, or a canned response in tests).
///
/// The returned task resolves to `Ok(None)` when the response cannot be
/// interpolated onto the buffer's current contents, and to an error when
/// gathering context, the request, or response processing fails. A
/// `ZedUpdateRequiredError` additionally sets `update_required` and shows a
/// notification linking to the releases page.
fn request_completion_impl<F, R>(
    &mut self,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    cursor: language::Anchor,
    cx: &mut Context<Self>,
    perform_predict_edits: F,
) -> Task<Result<Option<EditPrediction>>>
where
    F: FnOnce(PerformPredictEditsParams) -> R + 'static,
    R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
        + Send
        + 'static,
{
    let buffer = buffer.clone();
    let buffer_snapshotted_at = Instant::now();
    // Also records any not-yet-reported buffer change as an event.
    let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
    let zeta = cx.entity();
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Copy the project's event history so the spawned task owns its own view.
    let zeta_project = self.get_or_init_zeta_project(project, cx);
    let mut events = Vec::with_capacity(zeta_project.events.len());
    events.extend(zeta_project.events.iter().cloned());
    let events = Arc::new(events);

    // Git info is only looked up for files whose data may be collected.
    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
        let can_collect_file = self.can_collect_file(file, cx);
        let git_info = if can_collect_file {
            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };
        (git_info, can_collect_file)
    } else {
        (None, false)
    };

    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().into_owned();
    let cursor_point = cursor.to_point(&snapshot);
    let cursor_offset = cursor_point.to_offset(&snapshot);
    let prompt_for_events = {
        let events = events.clone();
        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
    };
    let gather_task = gather_context(
        full_path_str,
        &snapshot,
        cursor_point,
        prompt_for_events,
        cx,
    );

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            mut body,
            editable_range,
            included_events_count,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // The last `included_events_count` events are the ones checked for
        // collectability below.
        let included_events = &events[events.len() - included_events_count..events.len()];
        body.can_collect_data = can_collect_file
            && this
                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                .unwrap_or(false);
        if body.can_collect_data {
            body.git_info = git_info;
        }

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        // Keep copies of the inputs; `body` is moved into the request below.
        let input_outline = body.outline.clone().unwrap_or_default();
        let input_events = body.input_events.clone();
        let input_excerpt = body.input_excerpt.clone();

        let response = perform_predict_edits(PerformPredictEditsParams {
            client,
            llm_token,
            app_version,
            body,
        })
        .await;
        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        zeta.update(cx, |zeta, _cx| {
                            zeta.update_required = true;
                        });

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button(
                                            "Update Zed",
                                            "https://zed.dev/releases",
                                        )
                                })
                            },
                        );
                    })
                    .ok();
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        let edit_prediction = Self::process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            cursor_offset,
            full_path,
            input_outline,
            input_events,
            input_excerpt,
            buffer_snapshotted_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // record latency for ~1% of requests
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        edit_prediction
    })
}
547
/// Test helper: runs the regular completion pipeline but answers the request
/// with the supplied `response` (and no usage information) instead of
/// contacting the server.
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
    &mut self,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    position: language::Anchor,
    response: PredictEditsResponse,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
    self.request_completion_impl(project, buffer, position, cx, move |_params| {
        let canned = Ok((response, None));
        std::future::ready(canned)
    })
}
561
562 pub fn request_completion(
563 &mut self,
564 project: &Entity<Project>,
565 buffer: &Entity<Buffer>,
566 position: language::Anchor,
567 cx: &mut Context<Self>,
568 ) -> Task<Result<Option<EditPrediction>>> {
569 self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
570 }
571
/// Sends the prediction request described by `params` and returns the parsed
/// response together with any usage information found in the response headers.
///
/// The request goes to the URL in the `ZED_PREDICT_EDITS_URL` environment
/// variable when it is set, and to the `/predict_edits/v2` LLM endpoint
/// otherwise.
///
/// # Errors
///
/// - `ZedUpdateRequiredError` when the server reports a minimum required
///   version newer than the running app version.
/// - An error containing the status and body for any other unsuccessful
///   response, except that a response carrying the expired-token header is
///   retried once with a refreshed token.
/// - Any failure acquiring the token, building or sending the request, or
///   reading and parsing the response.
pub fn perform_predict_edits(
    params: PerformPredictEditsParams,
) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
    async move {
        let PerformPredictEditsParams {
            client,
            llm_token,
            app_version,
            body,
            ..
        } = params;

        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        // Guards against retrying more than once on an expired token.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder =
                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    request_builder.uri(predict_edits_url)
                } else {
                    request_builder.uri(
                        http_client
                            .build_zed_llm_url("/predict_edits/v2", &[])?
                            .as_ref(),
                    )
                };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", token))
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;

            // The version check runs before the status check, so it applies to
            // successful and failed responses alike. A missing or unparsable
            // header skips the check.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; a parse failure yields `None`.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Ok((serde_json::from_str(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh it and go around the loop once more.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "error predicting edits.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
}
647
/// Reports to the server that the prediction `request_id` was accepted.
///
/// The request goes to the URL in the `ZED_ACCEPT_PREDICTION_URL` environment
/// variable when it is set, and to the `/predict_edits/accept` LLM endpoint
/// otherwise. On success, usage information from the response headers (if
/// present) is stored in the user store.
///
/// The task fails with `ZedUpdateRequiredError` when the server requires a
/// newer app version, and with an error containing the status and body for
/// any other unsuccessful response.
fn accept_edit_prediction(
    &mut self,
    request_id: EditPredictionId,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);
    cx.spawn(async move |this, cx| {
        let http_client = client.http_client();
        // The closure builds the request for whichever token `llm_token_retry`
        // supplies.
        let mut response = llm_token_retry(&llm_token, &client, |token| {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder =
                if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                    request_builder.uri(accept_prediction_url)
                } else {
                    request_builder.uri(
                        http_client
                            .build_zed_llm_url("/predict_edits/accept", &[])?
                            .as_ref(),
                    )
                };
            Ok(request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", token))
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                .body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.0.to_string(),
                    })?
                    .into(),
                )?)
        })
        .await?;

        // A missing or unparsable minimum-version header skips this check.
        if let Some(minimum_required_version) = response
            .headers()
            .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
            .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            && app_version < minimum_required_version
        {
            return Err(anyhow!(ZedUpdateRequiredError {
                minimum_version: minimum_required_version
            }));
        }

        if response.status().is_success() {
            if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })?;
            }

            Ok(())
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Err(anyhow!(
                "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
                response.status(),
                body
            ))
        }
    })
}
715
716 fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
717 let client = self.client.clone();
718 let llm_token = self.llm_token.clone();
719 let app_version = AppVersion::global(cx);
720 let last_rejection = self.discarded_completions.last().cloned();
721 let body = serde_json::to_string(&RejectEditPredictionsBody {
722 rejections: self.discarded_completions.clone(),
723 })
724 .ok();
725
726 let Some(last_rejection) = last_rejection else {
727 return Task::ready(anyhow::Ok(()));
728 };
729
730 cx.spawn(async move |this, cx| {
731 let http_client = client.http_client();
732 let mut response = llm_token_retry(&llm_token, &client, |token| {
733 let request_builder = http_client::Request::builder().method(Method::POST);
734 let request_builder = request_builder.uri(
735 http_client
736 .build_zed_llm_url("/predict_edits/reject", &[])?
737 .as_ref(),
738 );
739 Ok(request_builder
740 .header("Content-Type", "application/json")
741 .header("Authorization", format!("Bearer {}", token))
742 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
743 .body(
744 body.as_ref()
745 .context("failed to serialize body")?
746 .clone()
747 .into(),
748 )?)
749 })
750 .await?;
751
752 if let Some(minimum_required_version) = response
753 .headers()
754 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
755 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
756 && app_version < minimum_required_version
757 {
758 return Err(anyhow!(ZedUpdateRequiredError {
759 minimum_version: minimum_required_version
760 }));
761 }
762
763 if response.status().is_success() {
764 this.update(cx, |this, _| {
765 if let Some(ix) = this
766 .discarded_completions
767 .iter()
768 .position(|rejection| rejection.request_id == last_rejection.request_id)
769 {
770 this.discarded_completions.drain(..ix + 1);
771 }
772 })
773 } else {
774 let mut body = String::new();
775 response.body_mut().read_to_string(&mut body).await?;
776 Err(anyhow!(
777 "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
778 response.status(),
779 body
780 ))
781 }
782 })
783 }
784
/// Turns a prediction response into an `EditPrediction` for `buffer`.
///
/// The model output is parsed into edits against `snapshot` on a background
/// task, then interpolated onto the buffer's current snapshot. The task
/// resolves to `Ok(None)` when interpolation yields nothing, and to an error
/// when the output cannot be parsed or the request id is not a valid UUID.
fn process_completion_response(
    prediction_response: PredictEditsResponse,
    buffer: Entity<Buffer>,
    snapshot: &BufferSnapshot,
    editable_range: Range<usize>,
    cursor_offset: usize,
    path: Arc<Path>,
    input_outline: String,
    input_events: String,
    input_excerpt: String,
    buffer_snapshotted_at: Instant,
    cx: &AsyncApp,
) -> Task<Result<Option<EditPrediction>>> {
    let snapshot = snapshot.clone();
    let request_id = prediction_response.request_id;
    let output_excerpt = prediction_response.output_excerpt;
    cx.spawn(async move |cx| {
        let output_excerpt: Arc<str> = output_excerpt.into();

        // Parsing runs on the background executor.
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
            .background_spawn({
                let output_excerpt = output_excerpt.clone();
                let editable_range = editable_range.clone();
                let snapshot = snapshot.clone();
                async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
            })
            .await?
            .into();

        // The buffer may have changed since the request was issued, so the
        // edits are re-expressed against its current snapshot.
        let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
            let edits = edits.clone();
            move |buffer, cx| {
                let new_snapshot = buffer.snapshot();
                let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                    edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
                        .into();
                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
            }
        })?
        else {
            return anyhow::Ok(None);
        };

        let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;

        let edit_preview = edit_preview.await;

        Ok(Some(EditPrediction {
            id: EditPredictionId(request_id),
            path,
            excerpt_range: editable_range,
            cursor_offset,
            edits,
            edit_preview,
            snapshot,
            input_outline: input_outline.into(),
            input_events: input_events.into(),
            input_excerpt: input_excerpt.into(),
            output_excerpt,
            buffer_snapshotted_at,
            // Taken here, i.e. after parsing, interpolation and preview.
            response_received_at: Instant::now(),
        }))
    })
}
849
850 fn parse_edits(
851 output_excerpt: Arc<str>,
852 editable_range: Range<usize>,
853 snapshot: &BufferSnapshot,
854 ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
855 let content = output_excerpt.replace(CURSOR_MARKER, "");
856
857 let start_markers = content
858 .match_indices(EDITABLE_REGION_START_MARKER)
859 .collect::<Vec<_>>();
860 anyhow::ensure!(
861 start_markers.len() == 1,
862 "expected exactly one start marker, found {}",
863 start_markers.len()
864 );
865
866 let end_markers = content
867 .match_indices(EDITABLE_REGION_END_MARKER)
868 .collect::<Vec<_>>();
869 anyhow::ensure!(
870 end_markers.len() == 1,
871 "expected exactly one end marker, found {}",
872 end_markers.len()
873 );
874
875 let sof_markers = content
876 .match_indices(START_OF_FILE_MARKER)
877 .collect::<Vec<_>>();
878 anyhow::ensure!(
879 sof_markers.len() <= 1,
880 "expected at most one start-of-file marker, found {}",
881 sof_markers.len()
882 );
883
884 let codefence_start = start_markers[0].0;
885 let content = &content[codefence_start..];
886
887 let newline_ix = content.find('\n').context("could not find newline")?;
888 let content = &content[newline_ix + 1..];
889
890 let codefence_end = content
891 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
892 .context("could not find end marker")?;
893 let new_text = &content[..codefence_end];
894
895 let old_text = snapshot
896 .text_for_range(editable_range.clone())
897 .collect::<String>();
898
899 Ok(Self::compute_edits(
900 old_text,
901 new_text,
902 editable_range.start,
903 snapshot,
904 ))
905 }
906
907 pub fn compute_edits(
908 old_text: String,
909 new_text: &str,
910 offset: usize,
911 snapshot: &BufferSnapshot,
912 ) -> Vec<(Range<Anchor>, Arc<str>)> {
913 text_diff(&old_text, new_text)
914 .into_iter()
915 .map(|(mut old_range, new_text)| {
916 old_range.start += offset;
917 old_range.end += offset;
918
919 let prefix_len = common_prefix(
920 snapshot.chars_for_range(old_range.clone()),
921 new_text.chars(),
922 );
923 old_range.start += prefix_len;
924
925 let suffix_len = common_prefix(
926 snapshot.reversed_chars_for_range(old_range.clone()),
927 new_text[prefix_len..].chars().rev(),
928 );
929 old_range.end = old_range.end.saturating_sub(suffix_len);
930
931 let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
932 let range = if old_range.is_empty() {
933 let anchor = snapshot.anchor_after(old_range.start);
934 anchor..anchor
935 } else {
936 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
937 };
938 (range, new_text)
939 })
940 .collect()
941 }
942
943 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
944 self.rated_completions.contains(&completion_id)
945 }
946
947 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
948 self.shown_completions.push_front(completion.clone());
949 if self.shown_completions.len() > 50 {
950 let completion = self.shown_completions.pop_back().unwrap();
951 self.rated_completions.remove(&completion.id);
952 }
953 cx.notify();
954 }
955
/// Marks `completion` as rated and reports the rating through telemetry.
///
/// The event carries the rating, the free-form `feedback`, and the
/// prediction's recorded inputs and output. Queued telemetry events are
/// flushed right away, and observers are notified.
pub fn rate_completion(
    &mut self,
    completion: &EditPrediction,
    rating: EditPredictionRating,
    feedback: String,
    cx: &mut Context<Self>,
) {
    self.rated_completions.insert(completion.id);
    telemetry::event!(
        "Edit Prediction Rated",
        rating,
        input_events = completion.input_events,
        input_excerpt = completion.input_excerpt,
        input_outline = completion.input_outline,
        output_excerpt = completion.output_excerpt,
        feedback
    );
    // Flush in the background rather than waiting for it here.
    self.client.telemetry().flush_events().detach();
    cx.notify();
}
976
977 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
978 self.shown_completions.iter()
979 }
980
981 pub fn shown_completions_len(&self) -> usize {
982 self.shown_completions.len()
983 }
984
985 fn report_changes_for_buffer(
986 &mut self,
987 buffer: &Entity<Buffer>,
988 project: &Entity<Project>,
989 cx: &mut Context<Self>,
990 ) -> BufferSnapshot {
991 let zeta_project = self.get_or_init_zeta_project(project, cx);
992 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
993
994 let new_snapshot = buffer.read(cx).snapshot();
995 if new_snapshot.version != registered_buffer.snapshot.version {
996 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
997 Self::push_event(
998 zeta_project,
999 Event::BufferChange {
1000 old_snapshot,
1001 new_snapshot: new_snapshot.clone(),
1002 timestamp: Instant::now(),
1003 },
1004 );
1005 }
1006
1007 new_snapshot
1008 }
1009
1010 fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1011 self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
1012 }
1013
1014 fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
1015 if !self.data_collection_choice.is_enabled() {
1016 return false;
1017 }
1018 let mut last_checked_file = None;
1019 for event in events {
1020 match event {
1021 Event::BufferChange {
1022 old_snapshot,
1023 new_snapshot,
1024 ..
1025 } => {
1026 if let Some(old_file) = old_snapshot.file()
1027 && let Some(new_file) = new_snapshot.file()
1028 {
1029 if let Some(last_checked_file) = last_checked_file
1030 && Arc::ptr_eq(last_checked_file, old_file)
1031 && Arc::ptr_eq(last_checked_file, new_file)
1032 {
1033 continue;
1034 }
1035 if !self.can_collect_file(old_file, cx) {
1036 return false;
1037 }
1038 if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
1039 {
1040 return false;
1041 }
1042 last_checked_file = Some(new_file);
1043 } else {
1044 return false;
1045 }
1046 }
1047 }
1048 }
1049 true
1050 }
1051
1052 fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1053 if !file.is_local() || file.is_private() {
1054 return false;
1055 }
1056 self.license_detection_watchers
1057 .get(&file.worktree_id(cx))
1058 .is_some_and(|watcher| watcher.is_project_open_source())
1059 }
1060
1061 fn load_data_collection_choice() -> DataCollectionChoice {
1062 let choice = KEY_VALUE_STORE
1063 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1064 .log_err()
1065 .flatten();
1066
1067 match choice.as_deref() {
1068 Some("true") => DataCollectionChoice::Enabled,
1069 Some("false") => DataCollectionChoice::Disabled,
1070 Some(_) => {
1071 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1072 DataCollectionChoice::NotAnswered
1073 }
1074 None => DataCollectionChoice::NotAnswered,
1075 }
1076 }
1077
1078 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1079 self.data_collection_choice = self.data_collection_choice.toggle();
1080 let new_choice = self.data_collection_choice;
1081 db::write_and_log(cx, move || {
1082 KEY_VALUE_STORE.write_kvp(
1083 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1084 new_choice.is_enabled().to_string(),
1085 )
1086 });
1087 }
1088
1089 fn discard_completion(
1090 &mut self,
1091 completion_id: EditPredictionId,
1092 was_shown: bool,
1093 cx: &mut Context<Self>,
1094 ) {
1095 self.discarded_completions.push(EditPredictionRejection {
1096 request_id: completion_id.to_string(),
1097 was_shown,
1098 });
1099
1100 let reached_request_limit =
1101 self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
1102 let discard_completions_tx = self.discard_completions_tx.clone();
1103 self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
1104 const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
1105 if !reached_request_limit {
1106 cx.background_executor()
1107 .timer(DISCARD_COMPLETIONS_DEBOUNCE)
1108 .await;
1109 }
1110 discard_completions_tx.unbounded_send(()).log_err();
1111 }));
1112 }
1113}
1114
/// Values bundled together for performing a predict-edits request.
///
/// NOTE(review): the code that consumes this struct is outside this excerpt; the field notes
/// below are read off the field names and types only.
pub struct PerformPredictEditsParams {
    /// Shared client handle.
    pub client: Arc<Client>,
    /// Handle to the LLM API token.
    pub llm_token: LlmApiToken,
    /// Application version carried alongside the request.
    pub app_version: SemanticVersion,
    /// The request payload.
    pub body: PredictEditsBody,
}
1121
/// Error whose display message tells the user that edit predictions require updating Zed to
/// at least `minimum_version`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest version named in the error message.
    minimum_version: SemanticVersion,
}
1129
/// Returns the UTF-8 byte length of the longest run of equal leading characters produced by
/// `a` and `b`.
///
/// Comparison stops at the first differing pair or when either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1136
1137fn git_info_for_file(
1138 project: &Entity<Project>,
1139 project_path: &ProjectPath,
1140 cx: &App,
1141) -> Option<PredictEditsGitInfo> {
1142 let git_store = project.read(cx).git_store().read(cx);
1143 if let Some((repository, _repo_path)) =
1144 git_store.repository_and_path_for_project_path(project_path, cx)
1145 {
1146 let repository = repository.read(cx);
1147 let head_sha = repository
1148 .head_commit
1149 .as_ref()
1150 .map(|head_commit| head_commit.sha.to_string());
1151 let remote_origin_url = repository.remote_origin_url.clone();
1152 let remote_upstream_url = repository.remote_upstream_url.clone();
1153 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1154 return None;
1155 }
1156 Some(PredictEditsGitInfo {
1157 head_sha,
1158 remote_origin_url,
1159 remote_upstream_url,
1160 })
1161 } else {
1162 None
1163 }
1164}
1165
/// Result of `gather_context`: the request body plus bookkeeping about what went into it.
pub struct GatherContextOutput {
    /// Request payload built from the input excerpt and the events prompt.
    pub body: PredictEditsBody,
    /// The excerpt's editable range, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
    /// Count returned by the events-prompt callback alongside the prompt text.
    pub included_events_count: usize,
}
1171
1172pub fn gather_context(
1173 full_path_str: String,
1174 snapshot: &BufferSnapshot,
1175 cursor_point: language::Point,
1176 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1177 cx: &App,
1178) -> Task<Result<GatherContextOutput>> {
1179 cx.background_spawn({
1180 let snapshot = snapshot.clone();
1181 async move {
1182 let input_excerpt = excerpt_for_cursor_position(
1183 cursor_point,
1184 &full_path_str,
1185 &snapshot,
1186 MAX_REWRITE_TOKENS,
1187 MAX_CONTEXT_TOKENS,
1188 );
1189 let (input_events, included_events_count) = prompt_for_events();
1190 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1191
1192 let body = PredictEditsBody {
1193 input_events,
1194 input_excerpt: input_excerpt.prompt,
1195 can_collect_data: false,
1196 diagnostic_groups: None,
1197 git_info: None,
1198 outline: None,
1199 speculated_output: None,
1200 };
1201
1202 Ok(GatherContextOutput {
1203 body,
1204 editable_range,
1205 included_events_count,
1206 })
1207 }
1208 })
1209}
1210
1211fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1212 let mut result = String::new();
1213 for (ix, event) in events.iter().rev().enumerate() {
1214 let event_string = event.to_prompt();
1215 let event_tokens = guess_token_count(event_string.len());
1216 if event_tokens > remaining_tokens {
1217 return (result, ix);
1218 }
1219
1220 if !result.is_empty() {
1221 result.insert_str(0, "\n\n");
1222 }
1223 result.insert_str(0, &event_string);
1224 remaining_tokens -= event_tokens;
1225 }
1226 return (result, events.len());
1227}
1228
/// State kept for a buffer that has been registered.
struct RegisteredBuffer {
    /// Snapshot stored for the buffer.
    snapshot: BufferSnapshot,
    /// Two subscriptions owned by this entry.
    /// NOTE(review): presumably held only to keep them alive — confirm where they are created.
    _subscriptions: [gpui::Subscription; 2],
}
1233
/// An editing event that can be rendered into the prediction prompt (see `to_prompt`).
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change.
        new_snapshot: BufferSnapshot,
        /// Time value recorded with the event.
        timestamp: Instant,
    },
}
1242
1243impl Event {
1244 fn to_prompt(&self) -> String {
1245 match self {
1246 Event::BufferChange {
1247 old_snapshot,
1248 new_snapshot,
1249 ..
1250 } => {
1251 let mut prompt = String::new();
1252
1253 let old_path = old_snapshot
1254 .file()
1255 .map(|f| f.path().as_ref())
1256 .unwrap_or(RelPath::unix("untitled").unwrap());
1257 let new_path = new_snapshot
1258 .file()
1259 .map(|f| f.path().as_ref())
1260 .unwrap_or(RelPath::unix("untitled").unwrap());
1261 if old_path != new_path {
1262 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1263 }
1264
1265 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1266 if !diff.is_empty() {
1267 write!(
1268 prompt,
1269 "User edited {:?}:\n```diff\n{}\n```",
1270 new_path, diff
1271 )
1272 .unwrap();
1273 }
1274
1275 prompt
1276 }
1277 }
1278 }
1279}
1280
/// The prediction currently held by a provider, together with the buffer it was made for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
    /// Set to `true` by `did_show`; passed along when the prediction is discarded.
    was_shown: bool,
}
1287
1288impl CurrentEditPrediction {
1289 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1290 if self.buffer_id != old_completion.buffer_id {
1291 return true;
1292 }
1293
1294 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1295 return true;
1296 };
1297 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1298 return false;
1299 };
1300
1301 if old_edits.len() == 1 && new_edits.len() == 1 {
1302 let (old_range, old_text) = &old_edits[0];
1303 let (new_range, new_text) = &new_edits[0];
1304 new_range == old_range && new_text.starts_with(old_text.as_ref())
1305 } else {
1306 true
1307 }
1308 }
1309}
1310
/// An in-flight prediction request tracked by the provider.
struct PendingCompletion {
    /// Id assigned from the provider's `next_pending_completion_id` counter.
    id: usize,
    /// The spawned task, owned by this entry.
    /// NOTE(review): presumably dropping the entry cancels the task — confirm against gpui's `Task`.
    _task: Task<()>,
}
1315
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No preference has been recorded.
    NotAnswered,
    /// Collection is allowed.
    Enabled,
    /// Collection is refused.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for `Enabled`.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once a choice has been made, i.e. for `Enabled` and `Disabled`.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; `NotAnswered` toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1356
1357async fn llm_token_retry(
1358 llm_token: &LlmApiToken,
1359 client: &Arc<Client>,
1360 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1361) -> Result<Response<AsyncBody>> {
1362 let mut did_retry = false;
1363 let http_client = client.http_client();
1364 let mut token = llm_token.acquire(client).await?;
1365 loop {
1366 let request = build_request(token.clone())?;
1367 let response = http_client.send(request).await?;
1368
1369 if !did_retry
1370 && !response.status().is_success()
1371 && response
1372 .headers()
1373 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1374 .is_some()
1375 {
1376 did_retry = true;
1377 token = llm_token.refresh(client).await?;
1378 continue;
1379 }
1380
1381 return Ok(response);
1382 }
1383}
1384
/// Edit-prediction provider that requests predictions from a shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    /// The `Zeta` entity that predictions are requested from.
    zeta: Entity<Zeta>,
    /// Optional buffer used by `data_collection_state` to look up the file being edited.
    singleton_buffer: Option<Entity<Buffer>>,
    /// In-flight requests; the capacity of two is relied on by `refresh`.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Counter used to assign ids to pending requests.
    next_pending_completion_id: usize,
    /// The prediction currently held, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was started; used by `refresh` for throttling.
    last_request_timestamp: Instant,
    /// Project passed along with each prediction request.
    project: Entity<Project>,
}
1394
1395impl ZetaEditPredictionProvider {
1396 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1397
1398 pub fn new(
1399 zeta: Entity<Zeta>,
1400 project: Entity<Project>,
1401 singleton_buffer: Option<Entity<Buffer>>,
1402 ) -> Self {
1403 Self {
1404 zeta,
1405 singleton_buffer,
1406 pending_completions: ArrayVec::new(),
1407 next_pending_completion_id: 0,
1408 current_completion: None,
1409 last_request_timestamp: Instant::now(),
1410 project,
1411 }
1412 }
1413}
1414
1415impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
/// Identifier string for this provider.
fn name() -> &'static str {
    "zed-predict"
}

/// Human-readable provider name.
fn display_name() -> &'static str {
    "Zed's Edit Predictions"
}

/// Always `true` for this provider.
fn show_completions_in_menu() -> bool {
    true
}

/// Always `true` for this provider.
fn show_tab_accept_marker() -> bool {
    true
}
1431
1432 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1433 if let Some(buffer) = &self.singleton_buffer
1434 && let Some(file) = buffer.read(cx).file()
1435 {
1436 let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
1437 if self.zeta.read(cx).data_collection_choice.is_enabled() {
1438 DataCollectionState::Enabled {
1439 is_project_open_source,
1440 }
1441 } else {
1442 DataCollectionState::Disabled {
1443 is_project_open_source,
1444 }
1445 }
1446 } else {
1447 return DataCollectionState::Disabled {
1448 is_project_open_source: false,
1449 };
1450 }
1451 }
1452
/// Delegates to `Zeta::toggle_data_collection_choice`.
fn toggle_data_collection(&mut self, cx: &mut App) {
    self.zeta
        .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
}

/// Returns the usage reported by `Zeta::usage`.
fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
    self.zeta.read(cx).usage(cx)
}

/// Always `true`: none of the arguments are consulted.
fn is_enabled(
    &self,
    _buffer: &Entity<Buffer>,
    _cursor_position: language::Anchor,
    _cx: &App,
) -> bool {
    true
}

/// Returns `true` while at least one prediction request is pending.
fn is_refreshing(&self) -> bool {
    !self.pending_completions.is_empty()
}
1473
1474 fn refresh(
1475 &mut self,
1476 buffer: Entity<Buffer>,
1477 position: language::Anchor,
1478 _debounce: bool,
1479 cx: &mut Context<Self>,
1480 ) {
1481 if self.zeta.read(cx).update_required {
1482 return;
1483 }
1484
1485 if self
1486 .zeta
1487 .read(cx)
1488 .user_store
1489 .read_with(cx, |user_store, _cx| {
1490 user_store.account_too_young() || user_store.has_overdue_invoices()
1491 })
1492 {
1493 return;
1494 }
1495
1496 if let Some(current_completion) = self.current_completion.as_ref() {
1497 let snapshot = buffer.read(cx).snapshot();
1498 if current_completion
1499 .completion
1500 .interpolate(&snapshot)
1501 .is_some()
1502 {
1503 return;
1504 }
1505 }
1506
1507 let pending_completion_id = self.next_pending_completion_id;
1508 self.next_pending_completion_id += 1;
1509 let last_request_timestamp = self.last_request_timestamp;
1510
1511 let project = self.project.clone();
1512 let task = cx.spawn(async move |this, cx| {
1513 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1514 .checked_duration_since(Instant::now())
1515 {
1516 cx.background_executor().timer(timeout).await;
1517 }
1518
1519 let completion_request = this.update(cx, |this, cx| {
1520 this.last_request_timestamp = Instant::now();
1521 this.zeta.update(cx, |zeta, cx| {
1522 zeta.request_completion(&project, &buffer, position, cx)
1523 })
1524 });
1525
1526 let completion = match completion_request {
1527 Ok(completion_request) => {
1528 let completion_request = completion_request.await;
1529 completion_request.map(|c| {
1530 c.map(|completion| CurrentEditPrediction {
1531 buffer_id: buffer.entity_id(),
1532 completion,
1533 was_shown: false,
1534 })
1535 })
1536 }
1537 Err(error) => Err(error),
1538 };
1539 let Some(new_completion) = completion
1540 .context("edit prediction failed")
1541 .log_err()
1542 .flatten()
1543 else {
1544 this.update(cx, |this, cx| {
1545 if this.pending_completions[0].id == pending_completion_id {
1546 this.pending_completions.remove(0);
1547 } else {
1548 this.pending_completions.clear();
1549 }
1550
1551 cx.notify();
1552 })
1553 .ok();
1554 return;
1555 };
1556
1557 this.update(cx, |this, cx| {
1558 if this.pending_completions[0].id == pending_completion_id {
1559 this.pending_completions.remove(0);
1560 } else {
1561 this.pending_completions.clear();
1562 }
1563
1564 if let Some(old_completion) = this.current_completion.as_ref() {
1565 let snapshot = buffer.read(cx).snapshot();
1566 if new_completion.should_replace_completion(old_completion, &snapshot) {
1567 this.zeta.update(cx, |zeta, cx| {
1568 zeta.completion_shown(&new_completion.completion, cx);
1569 });
1570 this.current_completion = Some(new_completion);
1571 }
1572 } else {
1573 this.zeta.update(cx, |zeta, cx| {
1574 zeta.completion_shown(&new_completion.completion, cx);
1575 });
1576 this.current_completion = Some(new_completion);
1577 }
1578
1579 cx.notify();
1580 })
1581 .ok();
1582 });
1583
1584 // We always maintain at most two pending completions. When we already
1585 // have two, we replace the newest one.
1586 if self.pending_completions.len() <= 1 {
1587 self.pending_completions.push(PendingCompletion {
1588 id: pending_completion_id,
1589 _task: task,
1590 });
1591 } else if self.pending_completions.len() == 2 {
1592 self.pending_completions.pop();
1593 self.pending_completions.push(PendingCompletion {
1594 id: pending_completion_id,
1595 _task: task,
1596 });
1597 }
1598 }
1599
/// No-op: this provider does not offer alternative predictions to cycle through, so all
/// arguments are ignored.
fn cycle(
    &mut self,
    _buffer: Entity<Buffer>,
    _cursor_position: language::Anchor,
    _direction: edit_prediction::Direction,
    _cx: &mut Context<Self>,
) {
    // Right now we don't support cycling.
}
1609
1610 fn accept(&mut self, cx: &mut Context<Self>) {
1611 let completion_id = self
1612 .current_completion
1613 .as_ref()
1614 .map(|completion| completion.completion.id);
1615 if let Some(completion_id) = completion_id {
1616 self.zeta
1617 .update(cx, |zeta, cx| {
1618 zeta.accept_edit_prediction(completion_id, cx)
1619 })
1620 .detach();
1621 }
1622 self.pending_completions.clear();
1623 }
1624
1625 fn discard(&mut self, cx: &mut Context<Self>) {
1626 self.pending_completions.clear();
1627 if let Some(completion) = self.current_completion.take() {
1628 self.zeta.update(cx, |zeta, cx| {
1629 zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
1630 });
1631 }
1632 }
1633
1634 fn did_show(&mut self, _cx: &mut Context<Self>) {
1635 if let Some(current_completion) = self.current_completion.as_mut() {
1636 current_completion.was_shown = true;
1637 }
1638 }
1639
1640 fn suggest(
1641 &mut self,
1642 buffer: &Entity<Buffer>,
1643 cursor_position: language::Anchor,
1644 cx: &mut Context<Self>,
1645 ) -> Option<edit_prediction::EditPrediction> {
1646 let CurrentEditPrediction {
1647 buffer_id,
1648 completion,
1649 ..
1650 } = self.current_completion.as_mut()?;
1651
1652 // Invalidate previous completion if it was generated for a different buffer.
1653 if *buffer_id != buffer.entity_id() {
1654 self.current_completion.take();
1655 return None;
1656 }
1657
1658 let buffer = buffer.read(cx);
1659 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1660 self.current_completion.take();
1661 return None;
1662 };
1663
1664 let cursor_row = cursor_position.to_point(buffer).row;
1665 let (closest_edit_ix, (closest_edit_range, _)) =
1666 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1667 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1668 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1669 cmp::min(distance_from_start, distance_from_end)
1670 })?;
1671
1672 let mut edit_start_ix = closest_edit_ix;
1673 for (range, _) in edits[..edit_start_ix].iter().rev() {
1674 let distance_from_closest_edit =
1675 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1676 if distance_from_closest_edit <= 1 {
1677 edit_start_ix -= 1;
1678 } else {
1679 break;
1680 }
1681 }
1682
1683 let mut edit_end_ix = closest_edit_ix + 1;
1684 for (range, _) in &edits[edit_end_ix..] {
1685 let distance_from_closest_edit =
1686 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1687 if distance_from_closest_edit <= 1 {
1688 edit_end_ix += 1;
1689 } else {
1690 break;
1691 }
1692 }
1693
1694 Some(edit_prediction::EditPrediction::Local {
1695 id: Some(completion.id.to_string().into()),
1696 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1697 edit_preview: Some(completion.edit_preview.clone()),
1698 })
1699 }
1700}
1701
/// Assumed number of string bytes per model token when budgeting model input.
///
/// The ratio is deliberately small: dividing by a small number over-counts tokens, so budgets
/// derived from it are reached early rather than exceeded.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    let estimated_tokens = bytes / BYTES_PER_TOKEN_GUESS;
    estimated_tokens
}
1709
1710#[cfg(test)]
1711mod tests {
1712 use client::test::FakeServer;
1713 use clock::{FakeSystemClock, ReplicaId};
1714 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1715 use gpui::TestAppContext;
1716 use http_client::FakeHttpClient;
1717 use indoc::indoc;
1718 use language::Point;
1719 use parking_lot::Mutex;
1720 use serde_json::json;
1721 use settings::SettingsStore;
1722 use util::{path, rel_path::rel_path};
1723
1724 use super::*;
1725
1726 const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1727
/// Walks a two-edit prediction through a series of user edits and checks what `interpolate`
/// reports as still outstanding after each one.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // The prediction: replace 2..5 with "REM" and delete 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let completion = EditPrediction {
        edits,
        edit_preview,
        path: Path::new("").into(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId(Uuid::new_v4()),
        excerpt_range: 0..0,
        cursor_offset: 0,
        input_outline: "".into(),
        input_events: "".into(),
        input_excerpt: "".into(),
        output_excerpt: "".into(),
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Untouched buffer: both edits are reported unchanged.
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the text the first edit would replace turns it into a pure insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original pair of edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the first predicted character leaves "EM" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing the second predicted character leaves "M".
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // The first edit is fully typed; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Removing the typed "M" brings the insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion by hand leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // After this edit the prediction no longer interpolates at all.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1839
/// Applies a model response to a buffer and checks the resulting buffer text, for a rename
/// fix-up and for a line completion.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Case 1: the response rewrites `word` to `word_1` in the last statement.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                let word_1 = \"lorem\";
                let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
            let word_1 = \"lorem\";
            let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Case 2: the response completes an unfinished string literal and adds the semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
            let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1897
/// Checks the buffer text after applying a response that appends a line at the end of the
/// buffer.
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    // The response is wrapped in a code fence and includes the start-of-file marker.
    let completion_response = indoc! {"
        ```animals.js
        <|start_of_file|>
        <|editable_region_start|>
        lorem
        ipsum
        <|editable_region_end|>
        ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1917
1918 #[gpui::test]
1919 async fn test_can_collect_data(cx: &mut TestAppContext) {
1920 init_test(cx);
1921
1922 let fs = project::FakeFs::new(cx.executor());
1923 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1924 .await;
1925
1926 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1927 let buffer = project
1928 .update(cx, |project, cx| {
1929 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1930 })
1931 .await
1932 .unwrap();
1933
1934 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1935 zeta.update(cx, |zeta, _cx| {
1936 zeta.data_collection_choice = DataCollectionChoice::Enabled
1937 });
1938
1939 run_edit_prediction(&buffer, &project, &zeta, cx).await;
1940 assert_eq!(
1941 captured_request.lock().clone().unwrap().can_collect_data,
1942 true
1943 );
1944
1945 zeta.update(cx, |zeta, _cx| {
1946 zeta.data_collection_choice = DataCollectionChoice::Disabled
1947 });
1948
1949 run_edit_prediction(&buffer, &project, &zeta, cx).await;
1950 assert_eq!(
1951 captured_request.lock().clone().unwrap().can_collect_data,
1952 false
1953 );
1954 }
1955
1956 #[gpui::test]
1957 async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1958 init_test(cx);
1959
1960 let fs = project::FakeFs::new(cx.executor());
1961 let project = Project::test(fs.clone(), [], cx).await;
1962
1963 let buffer = cx.new(|_cx| {
1964 Buffer::remote(
1965 language::BufferId::new(1).unwrap(),
1966 ReplicaId::new(1),
1967 language::Capability::ReadWrite,
1968 "fn main() {\n println!(\"Hello\");\n}",
1969 )
1970 });
1971
1972 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1973 zeta.update(cx, |zeta, _cx| {
1974 zeta.data_collection_choice = DataCollectionChoice::Enabled
1975 });
1976
1977 run_edit_prediction(&buffer, &project, &zeta, cx).await;
1978 assert_eq!(
1979 captured_request.lock().clone().unwrap().can_collect_data,
1980 false
1981 );
1982 }
1983
1984 #[gpui::test]
1985 async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1986 init_test(cx);
1987
1988 let fs = project::FakeFs::new(cx.executor());
1989 fs.insert_tree(
1990 path!("/project"),
1991 json!({
1992 "LICENSE": BSD_0_TXT,
1993 ".env": "SECRET_KEY=secret"
1994 }),
1995 )
1996 .await;
1997
1998 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1999 let buffer = project
2000 .update(cx, |project, cx| {
2001 project.open_local_buffer("/project/.env", cx)
2002 })
2003 .await
2004 .unwrap();
2005
2006 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2007 zeta.update(cx, |zeta, _cx| {
2008 zeta.data_collection_choice = DataCollectionChoice::Enabled
2009 });
2010
2011 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2012 assert_eq!(
2013 captured_request.lock().clone().unwrap().can_collect_data,
2014 false
2015 );
2016 }
2017
2018 #[gpui::test]
2019 async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
2020 init_test(cx);
2021
2022 let fs = project::FakeFs::new(cx.executor());
2023 let project = Project::test(fs.clone(), [], cx).await;
2024 let buffer = cx.new(|cx| Buffer::local("", cx));
2025
2026 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2027 zeta.update(cx, |zeta, _cx| {
2028 zeta.data_collection_choice = DataCollectionChoice::Enabled
2029 });
2030
2031 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2032 assert_eq!(
2033 captured_request.lock().clone().unwrap().can_collect_data,
2034 false
2035 );
2036 }
2037
2038 #[gpui::test]
2039 async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
2040 init_test(cx);
2041
2042 let fs = project::FakeFs::new(cx.executor());
2043 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
2044 .await;
2045
2046 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2047 let buffer = project
2048 .update(cx, |project, cx| {
2049 project.open_local_buffer("/project/main.rs", cx)
2050 })
2051 .await
2052 .unwrap();
2053
2054 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2055 zeta.update(cx, |zeta, _cx| {
2056 zeta.data_collection_choice = DataCollectionChoice::Enabled
2057 });
2058
2059 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2060 assert_eq!(
2061 captured_request.lock().clone().unwrap().can_collect_data,
2062 false
2063 );
2064 }
2065
2066 #[gpui::test]
2067 async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
2068 init_test(cx);
2069
2070 let fs = project::FakeFs::new(cx.executor());
2071 fs.insert_tree(
2072 path!("/open_source_worktree"),
2073 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
2074 )
2075 .await;
2076 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
2077 .await;
2078
2079 let project = Project::test(
2080 fs.clone(),
2081 [
2082 path!("/open_source_worktree").as_ref(),
2083 path!("/closed_source_worktree").as_ref(),
2084 ],
2085 cx,
2086 )
2087 .await;
2088 let buffer = project
2089 .update(cx, |project, cx| {
2090 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
2091 })
2092 .await
2093 .unwrap();
2094
2095 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2096 zeta.update(cx, |zeta, _cx| {
2097 zeta.data_collection_choice = DataCollectionChoice::Enabled
2098 });
2099
2100 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2101 assert_eq!(
2102 captured_request.lock().clone().unwrap().can_collect_data,
2103 true
2104 );
2105
2106 let closed_source_file = project
2107 .update(cx, |project, cx| {
2108 let worktree2 = project
2109 .worktree_for_root_name("closed_source_worktree", cx)
2110 .unwrap();
2111 worktree2.update(cx, |worktree2, cx| {
2112 worktree2.load_file(rel_path("main.rs"), cx)
2113 })
2114 })
2115 .await
2116 .unwrap()
2117 .file;
2118
2119 buffer.update(cx, |buffer, cx| {
2120 buffer.file_updated(closed_source_file, cx);
2121 });
2122
2123 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2124 assert_eq!(
2125 captured_request.lock().clone().unwrap().can_collect_data,
2126 false
2127 );
2128 }
2129
2130 #[gpui::test]
2131 async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
2132 init_test(cx);
2133
2134 let fs = project::FakeFs::new(cx.executor());
2135 fs.insert_tree(
2136 path!("/worktree1"),
2137 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
2138 )
2139 .await;
2140 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
2141 .await;
2142
2143 let project = Project::test(
2144 fs.clone(),
2145 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
2146 cx,
2147 )
2148 .await;
2149 let buffer = project
2150 .update(cx, |project, cx| {
2151 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
2152 })
2153 .await
2154 .unwrap();
2155 let private_buffer = project
2156 .update(cx, |project, cx| {
2157 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
2158 })
2159 .await
2160 .unwrap();
2161
2162 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2163 zeta.update(cx, |zeta, _cx| {
2164 zeta.data_collection_choice = DataCollectionChoice::Enabled
2165 });
2166
2167 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2168 assert_eq!(
2169 captured_request.lock().clone().unwrap().can_collect_data,
2170 true
2171 );
2172
2173 // this has a side effect of registering the buffer to watch for edits
2174 run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
2175 assert_eq!(
2176 captured_request.lock().clone().unwrap().can_collect_data,
2177 false
2178 );
2179
2180 private_buffer.update(cx, |private_buffer, cx| {
2181 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
2182 });
2183
2184 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2185 assert_eq!(
2186 captured_request.lock().clone().unwrap().can_collect_data,
2187 false
2188 );
2189
2190 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
2191 // included
2192 buffer.update(cx, |buffer, cx| {
2193 buffer.edit(
2194 [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
2195 None,
2196 cx,
2197 );
2198 });
2199
2200 run_edit_prediction(&buffer, &project, &zeta, cx).await;
2201 assert_eq!(
2202 captured_request.lock().clone().unwrap().can_collect_data,
2203 true
2204 );
2205 }
2206
/// Installs a test `SettingsStore` as a global on the test app context.
fn init_test(cx: &mut TestAppContext) {
    cx.update(|cx| {
        // Built in its own statement before `set_global` borrows `cx`.
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
    });
}
2213
2214 async fn apply_edit_prediction(
2215 buffer_content: &str,
2216 completion_response: &str,
2217 cx: &mut TestAppContext,
2218 ) -> String {
2219 let fs = project::FakeFs::new(cx.executor());
2220 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2221 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2222 let (zeta, _, response) = make_test_zeta(&project, cx).await;
2223 *response.lock() = completion_response.to_string();
2224 let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
2225 buffer.update(cx, |buffer, cx| {
2226 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2227 });
2228 buffer.read_with(cx, |buffer, _| buffer.text())
2229 }
2230
2231 async fn run_edit_prediction(
2232 buffer: &Entity<Buffer>,
2233 project: &Entity<Project>,
2234 zeta: &Entity<Zeta>,
2235 cx: &mut TestAppContext,
2236 ) -> EditPrediction {
2237 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2238 zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
2239 cx.background_executor.run_until_parked();
2240 let completion_task = zeta.update(cx, |zeta, cx| {
2241 zeta.request_completion(&project, buffer, cursor, cx)
2242 });
2243 completion_task.await.unwrap().unwrap()
2244 }
2245
2246 async fn make_test_zeta(
2247 project: &Entity<Project>,
2248 cx: &mut TestAppContext,
2249 ) -> (
2250 Entity<Zeta>,
2251 Arc<Mutex<Option<PredictEditsBody>>>,
2252 Arc<Mutex<String>>,
2253 ) {
2254 let default_response = indoc! {"
2255 ```main.rs
2256 <|start_of_file|>
2257 <|editable_region_start|>
2258 hello world
2259 <|editable_region_end|>
2260 ```"
2261 };
2262 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2263 let completion_response: Arc<Mutex<String>> =
2264 Arc::new(Mutex::new(default_response.to_string()));
2265 let http_client = FakeHttpClient::create({
2266 let captured_request = captured_request.clone();
2267 let completion_response = completion_response.clone();
2268 move |req| {
2269 let captured_request = captured_request.clone();
2270 let completion_response = completion_response.clone();
2271 async move {
2272 match (req.method(), req.uri().path()) {
2273 (&Method::POST, "/client/llm_tokens") => {
2274 Ok(http_client::Response::builder()
2275 .status(200)
2276 .body(
2277 serde_json::to_string(&CreateLlmTokenResponse {
2278 token: LlmToken("the-llm-token".to_string()),
2279 })
2280 .unwrap()
2281 .into(),
2282 )
2283 .unwrap())
2284 }
2285 (&Method::POST, "/predict_edits/v2") => {
2286 let mut request_body = String::new();
2287 req.into_body().read_to_string(&mut request_body).await?;
2288 *captured_request.lock() =
2289 Some(serde_json::from_str(&request_body).unwrap());
2290 Ok(http_client::Response::builder()
2291 .status(200)
2292 .body(
2293 serde_json::to_string(&PredictEditsResponse {
2294 request_id: Uuid::new_v4().to_string(),
2295 output_excerpt: completion_response.lock().clone(),
2296 })
2297 .unwrap()
2298 .into(),
2299 )
2300 .unwrap())
2301 }
2302 _ => Ok(http_client::Response::builder()
2303 .status(404)
2304 .body("Not Found".into())
2305 .unwrap()),
2306 }
2307 }
2308 }
2309 });
2310
2311 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2312 cx.update(|cx| {
2313 RefreshLlmTokenListener::register(client.clone(), cx);
2314 });
2315 let _server = FakeServer::for_client(42, &client, cx).await;
2316
2317 let zeta = cx.new(|cx| {
2318 let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
2319
2320 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2321 for worktree in worktrees {
2322 let worktree_id = worktree.read(cx).id();
2323 zeta.license_detection_watchers
2324 .entry(worktree_id)
2325 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2326 }
2327
2328 zeta
2329 });
2330
2331 (zeta, captured_request, completion_response)
2332 }
2333
2334 fn to_completion_edits(
2335 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2336 buffer: &Entity<Buffer>,
2337 cx: &App,
2338 ) -> Vec<(Range<Anchor>, Arc<str>)> {
2339 let buffer = buffer.read(cx);
2340 iterator
2341 .into_iter()
2342 .map(|(range, text)| {
2343 (
2344 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2345 text,
2346 )
2347 })
2348 .collect()
2349 }
2350
2351 fn from_completion_edits(
2352 editor_edits: &[(Range<Anchor>, Arc<str>)],
2353 buffer: &Entity<Buffer>,
2354 cx: &App,
2355 ) -> Vec<(Range<usize>, Arc<str>)> {
2356 let buffer = buffer.read(cx);
2357 editor_edits
2358 .iter()
2359 .map(|(range, text)| {
2360 (
2361 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2362 text.clone(),
2363 )
2364 })
2365 .collect()
2366 }
2367
2368 #[ctor::ctor]
2369 fn init_logger() {
2370 zlog::init_test();
2371 }
2372}