1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::{Dismissable, KEY_VALUE_STORE};
11use edit_prediction::DataCollectionState;
12pub use init::*;
13use license_detection::LicenseDetectionWatcher;
14pub use rate_completion_modal::*;
15
16use anyhow::{Context as _, Result, anyhow};
17use arrayvec::ArrayVec;
18use client::{Client, EditPredictionUsage, UserStore};
19use cloud_llm_client::{
20 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
21 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
22};
23use collections::{HashMap, HashSet, VecDeque};
24use futures::AsyncReadExt;
25use gpui::{
26 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
27 SharedString, Subscription, Task, actions,
28};
29use http_client::{AsyncBody, HttpClient, Method, Request, Response};
30use input_excerpt::excerpt_for_cursor_position;
31use language::{
32 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
33};
34use language_model::{LlmApiToken, RefreshLlmTokenListener};
35use project::{Project, ProjectPath};
36use release_channel::AppVersion;
37use settings::WorktreeId;
38use std::collections::hash_map;
39use std::mem;
40use std::str::FromStr;
41use std::{
42 cmp,
43 fmt::Write,
44 future::Future,
45 ops::Range,
46 path::Path,
47 rc::Rc,
48 sync::Arc,
49 time::{Duration, Instant},
50};
51use telemetry_events::EditPredictionRating;
52use thiserror::Error;
53use util::ResultExt;
54use util::rel_path::RelPath;
55use uuid::Uuid;
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57use worktree::Worktree;
58
// Markers that appear in the model's output excerpt. `Zeta::parse_edits`
// strips the cursor marker and requires exactly one editable-region start and
// end marker (and at most one start-of-file marker).
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Consecutive edits to the same buffer recorded within this window are
/// coalesced into a single history event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Key-value store key holding the persisted data collection choice
/// (`"true"` or `"false"`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Approximate token budgets. The first two are passed to
// `excerpt_for_cursor_position`; the last bounds how much edit history is
// placed into the prompt.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
72
// Actions registered under the `edit_prediction` namespace. The `///` comment
// inside the macro is part of the action definition, so keep it as-is.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
80
81#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
82pub struct EditPredictionId(Uuid);
83
84impl From<EditPredictionId> for gpui::ElementId {
85 fn from(value: EditPredictionId) -> Self {
86 gpui::ElementId::Uuid(value.0)
87 }
88}
89
90impl std::fmt::Display for EditPredictionId {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{}", self.0)
93 }
94}
95
96struct ZedPredictUpsell;
97
98impl Dismissable for ZedPredictUpsell {
99 const KEY: &'static str = "dismissed-edit-predict-upsell";
100
101 fn dismissed() -> bool {
102 // To make this backwards compatible with older versions of Zed, we
103 // check if the user has seen the previous Edit Prediction Onboarding
104 // before, by checking the data collection choice which was written to
105 // the database once the user clicked on "Accept and Enable"
106 if KEY_VALUE_STORE
107 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
108 .log_err()
109 .is_some_and(|s| s.is_some())
110 {
111 return true;
112 }
113
114 KEY_VALUE_STORE
115 .read_kvp(Self::KEY)
116 .log_err()
117 .is_some_and(|s| s.is_some())
118 }
119}
120
121pub fn should_show_upsell_modal() -> bool {
122 !ZedPredictUpsell::dismissed()
123}
124
/// Global handle to the app-wide [`Zeta`] entity (installed by `Zeta::register`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
129
/// A prediction returned by the model, re-targeted onto a recent buffer
/// snapshot, together with the prompt inputs and raw output that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier taken from the response's `request_id`.
    id: EditPredictionId,
    /// Full path of the predicted file (`"untitled"` for buffers without a file).
    path: Arc<Path>,
    /// Editable range the model was asked to rewrite.
    /// NOTE(review): these offsets come from the snapshot taken when the
    /// request was made, which may be older than `snapshot` below.
    excerpt_range: Range<usize>,
    /// Cursor offset at request time.
    cursor_offset: usize,
    /// Predicted edits after interpolation onto `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto.
    snapshot: BufferSnapshot,
    /// Preview of `edits` applied to the buffer.
    edit_preview: EditPreview,
    // Prompt inputs sent with the request; reported again when the prediction
    // is rated.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    /// Raw model output excerpt.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// When this prediction was constructed (after parsing and preview).
    response_received_at: Instant,
}
146
147impl EditPrediction {
148 fn latency(&self) -> Duration {
149 self.response_received_at
150 .duration_since(self.buffer_snapshotted_at)
151 }
152
153 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
154 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
155 }
156}
157
/// Re-targets model `current_edits`, computed against `old_snapshot`, onto
/// `new_snapshot` by walking the user edits made since `old_snapshot`.
///
/// Model edits that end before a user edit are kept unchanged. A user edit
/// whose old range equals the next model edit's old range, and whose inserted
/// text is a prefix of that model edit's replacement, is treated as the user
/// typing the prediction: the remaining suffix (if any) becomes an insertion at
/// `anchor_after` the end of that range. Any other user edit invalidates the
/// whole prediction and `None` is returned. `None` is also returned when no
/// edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    // NOTE(review): this merge assumes `current_edits` is sorted by position
    // and non-overlapping — TODO confirm against the producer.
    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over model edits that end strictly before this user edit starts.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // The user typed a prefix of the predicted replacement; keep
                // suggesting whatever is left of it.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // This user edit does not line up with the prediction.
        return None;
    }

    // No more user edits: every remaining model edit still applies unchanged.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
203
204impl std::fmt::Debug for EditPrediction {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("EditPrediction")
207 .field("id", &self.id)
208 .field("path", &self.path)
209 .field("edits", &self.edits)
210 .finish_non_exhaustive()
211 }
212}
213
/// App-wide edit prediction state: per-project edit history, recently shown
/// predictions, and the user's data collection settings.
pub struct Zeta {
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    client: Arc<Client>,
    /// Most recently shown predictions, newest first.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions that the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: DataCollectionChoice,
    llm_token: LlmApiToken,
    /// Keeps the LLM token refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// License watchers per worktree, consulted to decide whether a file is open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}

/// Edit prediction state kept for a single project.
struct ZetaProject {
    /// Recent buffer change events, oldest first; capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
232
233impl Zeta {
234 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
235 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
236 }
237
238 pub fn register(
239 worktree: Option<Entity<Worktree>>,
240 client: Arc<Client>,
241 user_store: Entity<UserStore>,
242 cx: &mut App,
243 ) -> Entity<Self> {
244 let this = Self::global(cx).unwrap_or_else(|| {
245 let entity = cx.new(|cx| Self::new(client, user_store, cx));
246 cx.set_global(ZetaGlobal(entity.clone()));
247 entity
248 });
249
250 this.update(cx, move |this, cx| {
251 if let Some(worktree) = worktree {
252 let worktree_id = worktree.read(cx).id();
253 this.license_detection_watchers
254 .entry(worktree_id)
255 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
256 }
257 });
258
259 this
260 }
261
262 pub fn clear_history(&mut self) {
263 for zeta_project in self.projects.values_mut() {
264 zeta_project.events.clear();
265 }
266 }
267
268 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
269 self.user_store.read(cx).edit_prediction_usage()
270 }
271
272 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
273 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
274 let data_collection_choice = Self::load_data_collection_choice();
275 Self {
276 projects: HashMap::default(),
277 client,
278 shown_completions: VecDeque::new(),
279 rated_completions: HashSet::default(),
280 data_collection_choice,
281 llm_token: LlmApiToken::default(),
282 _llm_token_subscription: cx.subscribe(
283 &refresh_llm_token_listener,
284 |this, _listener, _event, cx| {
285 let client = this.client.clone();
286 let llm_token = this.llm_token.clone();
287 cx.spawn(async move |_this, _cx| {
288 llm_token.refresh(&client).await?;
289 anyhow::Ok(())
290 })
291 .detach_and_log_err(cx);
292 },
293 ),
294 update_required: false,
295 license_detection_watchers: HashMap::default(),
296 user_store,
297 }
298 }
299
300 fn get_or_init_zeta_project(
301 &mut self,
302 project: &Entity<Project>,
303 cx: &mut Context<Self>,
304 ) -> &mut ZetaProject {
305 let project_id = project.entity_id();
306 match self.projects.entry(project_id) {
307 hash_map::Entry::Occupied(entry) => entry.into_mut(),
308 hash_map::Entry::Vacant(entry) => {
309 cx.observe_release(project, move |this, _, _cx| {
310 this.projects.remove(&project_id);
311 })
312 .detach();
313 entry.insert(ZetaProject {
314 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
315 registered_buffers: HashMap::default(),
316 })
317 }
318 }
319 }
320
321 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
322 let events = &mut zeta_project.events;
323
324 if let Some(Event::BufferChange {
325 new_snapshot: last_new_snapshot,
326 timestamp: last_timestamp,
327 ..
328 }) = events.back_mut()
329 {
330 // Coalesce edits for the same buffer when they happen one after the other.
331 let Event::BufferChange {
332 old_snapshot,
333 new_snapshot,
334 timestamp,
335 } = &event;
336
337 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
338 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
339 && old_snapshot.version == last_new_snapshot.version
340 {
341 *last_new_snapshot = new_snapshot.clone();
342 *last_timestamp = *timestamp;
343 return;
344 }
345 }
346
347 if events.len() >= MAX_EVENT_COUNT {
348 // These are halved instead of popping to improve prompt caching.
349 events.drain(..MAX_EVENT_COUNT / 2);
350 }
351
352 events.push_back(event);
353 }
354
355 pub fn register_buffer(
356 &mut self,
357 buffer: &Entity<Buffer>,
358 project: &Entity<Project>,
359 cx: &mut Context<Self>,
360 ) {
361 let zeta_project = self.get_or_init_zeta_project(project, cx);
362 Self::register_buffer_impl(zeta_project, buffer, project, cx);
363 }
364
365 fn register_buffer_impl<'a>(
366 zeta_project: &'a mut ZetaProject,
367 buffer: &Entity<Buffer>,
368 project: &Entity<Project>,
369 cx: &mut Context<Self>,
370 ) -> &'a mut RegisteredBuffer {
371 let buffer_id = buffer.entity_id();
372 match zeta_project.registered_buffers.entry(buffer_id) {
373 hash_map::Entry::Occupied(entry) => entry.into_mut(),
374 hash_map::Entry::Vacant(entry) => {
375 let snapshot = buffer.read(cx).snapshot();
376 let project_entity_id = project.entity_id();
377 entry.insert(RegisteredBuffer {
378 snapshot,
379 _subscriptions: [
380 cx.subscribe(buffer, {
381 let project = project.downgrade();
382 move |this, buffer, event, cx| {
383 if let language::BufferEvent::Edited = event
384 && let Some(project) = project.upgrade()
385 {
386 this.report_changes_for_buffer(&buffer, &project, cx);
387 }
388 }
389 }),
390 cx.observe_release(buffer, move |this, _buffer, _cx| {
391 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
392 else {
393 return;
394 };
395 zeta_project.registered_buffers.remove(&buffer_id);
396 }),
397 ],
398 })
399 }
400 }
401 }
402
403 fn request_completion_impl<F, R>(
404 &mut self,
405 project: &Entity<Project>,
406 buffer: &Entity<Buffer>,
407 cursor: language::Anchor,
408 cx: &mut Context<Self>,
409 perform_predict_edits: F,
410 ) -> Task<Result<Option<EditPrediction>>>
411 where
412 F: FnOnce(PerformPredictEditsParams) -> R + 'static,
413 R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
414 + Send
415 + 'static,
416 {
417 let buffer = buffer.clone();
418 let buffer_snapshotted_at = Instant::now();
419 let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
420 let zeta = cx.entity();
421 let client = self.client.clone();
422 let llm_token = self.llm_token.clone();
423 let app_version = AppVersion::global(cx);
424
425 let zeta_project = self.get_or_init_zeta_project(project, cx);
426 let mut events = Vec::with_capacity(zeta_project.events.len());
427 events.extend(zeta_project.events.iter().cloned());
428 let events = Arc::new(events);
429
430 let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
431 let can_collect_file = self.can_collect_file(file, cx);
432 let git_info = if can_collect_file {
433 git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
434 } else {
435 None
436 };
437 (git_info, can_collect_file)
438 } else {
439 (None, false)
440 };
441
442 let full_path: Arc<Path> = snapshot
443 .file()
444 .map(|f| Arc::from(f.full_path(cx).as_path()))
445 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
446 let full_path_str = full_path.to_string_lossy().into_owned();
447 let cursor_point = cursor.to_point(&snapshot);
448 let cursor_offset = cursor_point.to_offset(&snapshot);
449 let prompt_for_events = {
450 let events = events.clone();
451 move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
452 };
453 let gather_task = gather_context(
454 full_path_str,
455 &snapshot,
456 cursor_point,
457 prompt_for_events,
458 cx,
459 );
460
461 cx.spawn(async move |this, cx| {
462 let GatherContextOutput {
463 mut body,
464 editable_range,
465 included_events_count,
466 } = gather_task.await?;
467 let done_gathering_context_at = Instant::now();
468
469 let included_events = &events[events.len() - included_events_count..events.len()];
470 body.can_collect_data = can_collect_file
471 && this
472 .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
473 .unwrap_or(false);
474 if body.can_collect_data {
475 body.git_info = git_info;
476 }
477
478 log::debug!(
479 "Events:\n{}\nExcerpt:\n{:?}",
480 body.input_events,
481 body.input_excerpt
482 );
483
484 let input_outline = body.outline.clone().unwrap_or_default();
485 let input_events = body.input_events.clone();
486 let input_excerpt = body.input_excerpt.clone();
487
488 let response = perform_predict_edits(PerformPredictEditsParams {
489 client,
490 llm_token,
491 app_version,
492 body,
493 })
494 .await;
495 let (response, usage) = match response {
496 Ok(response) => response,
497 Err(err) => {
498 if err.is::<ZedUpdateRequiredError>() {
499 cx.update(|cx| {
500 zeta.update(cx, |zeta, _cx| {
501 zeta.update_required = true;
502 });
503
504 let error_message: SharedString = err.to_string().into();
505 show_app_notification(
506 NotificationId::unique::<ZedUpdateRequiredError>(),
507 cx,
508 move |cx| {
509 cx.new(|cx| {
510 ErrorMessagePrompt::new(error_message.clone(), cx)
511 .with_link_button(
512 "Update Zed",
513 "https://zed.dev/releases",
514 )
515 })
516 },
517 );
518 })
519 .ok();
520 }
521
522 return Err(err);
523 }
524 };
525
526 let received_response_at = Instant::now();
527 log::debug!("completion response: {}", &response.output_excerpt);
528
529 if let Some(usage) = usage {
530 this.update(cx, |this, cx| {
531 this.user_store.update(cx, |user_store, cx| {
532 user_store.update_edit_prediction_usage(usage, cx);
533 });
534 })
535 .ok();
536 }
537
538 let edit_prediction = Self::process_completion_response(
539 response,
540 buffer,
541 &snapshot,
542 editable_range,
543 cursor_offset,
544 full_path,
545 input_outline,
546 input_events,
547 input_excerpt,
548 buffer_snapshotted_at,
549 cx,
550 )
551 .await;
552
553 let finished_at = Instant::now();
554
555 // record latency for ~1% of requests
556 if rand::random::<u8>() <= 2 {
557 telemetry::event!(
558 "Edit Prediction Request",
559 context_latency = done_gathering_context_at
560 .duration_since(buffer_snapshotted_at)
561 .as_millis(),
562 request_latency = received_response_at
563 .duration_since(done_gathering_context_at)
564 .as_millis(),
565 process_latency = finished_at.duration_since(received_response_at).as_millis()
566 );
567 }
568
569 edit_prediction
570 })
571 }
572
573 #[cfg(any(test, feature = "test-support"))]
574 pub fn fake_completion(
575 &mut self,
576 project: &Entity<Project>,
577 buffer: &Entity<Buffer>,
578 position: language::Anchor,
579 response: PredictEditsResponse,
580 cx: &mut Context<Self>,
581 ) -> Task<Result<Option<EditPrediction>>> {
582 self.request_completion_impl(project, buffer, position, cx, |_params| {
583 std::future::ready(Ok((response, None)))
584 })
585 }
586
587 pub fn request_completion(
588 &mut self,
589 project: &Entity<Project>,
590 buffer: &Entity<Buffer>,
591 position: language::Anchor,
592 cx: &mut Context<Self>,
593 ) -> Task<Result<Option<EditPrediction>>> {
594 self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
595 }
596
597 pub fn perform_predict_edits(
598 params: PerformPredictEditsParams,
599 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
600 async move {
601 let PerformPredictEditsParams {
602 client,
603 llm_token,
604 app_version,
605 body,
606 ..
607 } = params;
608
609 let http_client = client.http_client();
610 let mut token = llm_token.acquire(&client).await?;
611 let mut did_retry = false;
612
613 loop {
614 let request_builder = http_client::Request::builder().method(Method::POST);
615 let request_builder =
616 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
617 request_builder.uri(predict_edits_url)
618 } else {
619 request_builder.uri(
620 http_client
621 .build_zed_llm_url("/predict_edits/v2", &[])?
622 .as_ref(),
623 )
624 };
625 let request = request_builder
626 .header("Content-Type", "application/json")
627 .header("Authorization", format!("Bearer {}", token))
628 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
629 .body(serde_json::to_string(&body)?.into())?;
630
631 let mut response = http_client.send(request).await?;
632
633 if let Some(minimum_required_version) = response
634 .headers()
635 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
636 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
637 {
638 anyhow::ensure!(
639 app_version >= minimum_required_version,
640 ZedUpdateRequiredError {
641 minimum_version: minimum_required_version
642 }
643 );
644 }
645
646 if response.status().is_success() {
647 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
648
649 let mut body = String::new();
650 response.body_mut().read_to_string(&mut body).await?;
651 return Ok((serde_json::from_str(&body)?, usage));
652 } else if !did_retry
653 && response
654 .headers()
655 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
656 .is_some()
657 {
658 did_retry = true;
659 token = llm_token.refresh(&client).await?;
660 } else {
661 let mut body = String::new();
662 response.body_mut().read_to_string(&mut body).await?;
663 anyhow::bail!(
664 "error predicting edits.\nStatus: {:?}\nBody: {}",
665 response.status(),
666 body
667 );
668 }
669 }
670 }
671 }
672
673 fn accept_edit_prediction(
674 &mut self,
675 request_id: EditPredictionId,
676 cx: &mut Context<Self>,
677 ) -> Task<Result<()>> {
678 let client = self.client.clone();
679 let llm_token = self.llm_token.clone();
680 let app_version = AppVersion::global(cx);
681 cx.spawn(async move |this, cx| {
682 let http_client = client.http_client();
683 let mut response = llm_token_retry(&llm_token, &client, |token| {
684 let request_builder = http_client::Request::builder().method(Method::POST);
685 let request_builder =
686 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
687 request_builder.uri(accept_prediction_url)
688 } else {
689 request_builder.uri(
690 http_client
691 .build_zed_llm_url("/predict_edits/accept", &[])?
692 .as_ref(),
693 )
694 };
695 Ok(request_builder
696 .header("Content-Type", "application/json")
697 .header("Authorization", format!("Bearer {}", token))
698 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
699 .body(
700 serde_json::to_string(&AcceptEditPredictionBody {
701 request_id: request_id.0,
702 })?
703 .into(),
704 )?)
705 })
706 .await?;
707
708 if let Some(minimum_required_version) = response
709 .headers()
710 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
711 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
712 && app_version < minimum_required_version
713 {
714 return Err(anyhow!(ZedUpdateRequiredError {
715 minimum_version: minimum_required_version
716 }));
717 }
718
719 if response.status().is_success() {
720 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
721 this.update(cx, |this, cx| {
722 this.user_store.update(cx, |user_store, cx| {
723 user_store.update_edit_prediction_usage(usage, cx);
724 });
725 })?;
726 }
727
728 Ok(())
729 } else {
730 let mut body = String::new();
731 response.body_mut().read_to_string(&mut body).await?;
732 Err(anyhow!(
733 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
734 response.status(),
735 body
736 ))
737 }
738 })
739 }
740
741 fn process_completion_response(
742 prediction_response: PredictEditsResponse,
743 buffer: Entity<Buffer>,
744 snapshot: &BufferSnapshot,
745 editable_range: Range<usize>,
746 cursor_offset: usize,
747 path: Arc<Path>,
748 input_outline: String,
749 input_events: String,
750 input_excerpt: String,
751 buffer_snapshotted_at: Instant,
752 cx: &AsyncApp,
753 ) -> Task<Result<Option<EditPrediction>>> {
754 let snapshot = snapshot.clone();
755 let request_id = prediction_response.request_id;
756 let output_excerpt = prediction_response.output_excerpt;
757 cx.spawn(async move |cx| {
758 let output_excerpt: Arc<str> = output_excerpt.into();
759
760 let edits: Arc<[(Range<Anchor>, String)]> = cx
761 .background_spawn({
762 let output_excerpt = output_excerpt.clone();
763 let editable_range = editable_range.clone();
764 let snapshot = snapshot.clone();
765 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
766 })
767 .await?
768 .into();
769
770 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
771 let edits = edits.clone();
772 |buffer, cx| {
773 let new_snapshot = buffer.snapshot();
774 let edits: Arc<[(Range<Anchor>, String)]> =
775 interpolate(&snapshot, &new_snapshot, edits)?.into();
776 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
777 }
778 })?
779 else {
780 return anyhow::Ok(None);
781 };
782
783 let edit_preview = edit_preview.await;
784
785 Ok(Some(EditPrediction {
786 id: EditPredictionId(request_id),
787 path,
788 excerpt_range: editable_range,
789 cursor_offset,
790 edits,
791 edit_preview,
792 snapshot,
793 input_outline: input_outline.into(),
794 input_events: input_events.into(),
795 input_excerpt: input_excerpt.into(),
796 output_excerpt,
797 buffer_snapshotted_at,
798 response_received_at: Instant::now(),
799 }))
800 })
801 }
802
803 fn parse_edits(
804 output_excerpt: Arc<str>,
805 editable_range: Range<usize>,
806 snapshot: &BufferSnapshot,
807 ) -> Result<Vec<(Range<Anchor>, String)>> {
808 let content = output_excerpt.replace(CURSOR_MARKER, "");
809
810 let start_markers = content
811 .match_indices(EDITABLE_REGION_START_MARKER)
812 .collect::<Vec<_>>();
813 anyhow::ensure!(
814 start_markers.len() == 1,
815 "expected exactly one start marker, found {}",
816 start_markers.len()
817 );
818
819 let end_markers = content
820 .match_indices(EDITABLE_REGION_END_MARKER)
821 .collect::<Vec<_>>();
822 anyhow::ensure!(
823 end_markers.len() == 1,
824 "expected exactly one end marker, found {}",
825 end_markers.len()
826 );
827
828 let sof_markers = content
829 .match_indices(START_OF_FILE_MARKER)
830 .collect::<Vec<_>>();
831 anyhow::ensure!(
832 sof_markers.len() <= 1,
833 "expected at most one start-of-file marker, found {}",
834 sof_markers.len()
835 );
836
837 let codefence_start = start_markers[0].0;
838 let content = &content[codefence_start..];
839
840 let newline_ix = content.find('\n').context("could not find newline")?;
841 let content = &content[newline_ix + 1..];
842
843 let codefence_end = content
844 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
845 .context("could not find end marker")?;
846 let new_text = &content[..codefence_end];
847
848 let old_text = snapshot
849 .text_for_range(editable_range.clone())
850 .collect::<String>();
851
852 Ok(Self::compute_edits(
853 old_text,
854 new_text,
855 editable_range.start,
856 snapshot,
857 ))
858 }
859
860 pub fn compute_edits(
861 old_text: String,
862 new_text: &str,
863 offset: usize,
864 snapshot: &BufferSnapshot,
865 ) -> Vec<(Range<Anchor>, String)> {
866 text_diff(&old_text, new_text)
867 .into_iter()
868 .map(|(mut old_range, new_text)| {
869 old_range.start += offset;
870 old_range.end += offset;
871
872 let prefix_len = common_prefix(
873 snapshot.chars_for_range(old_range.clone()),
874 new_text.chars(),
875 );
876 old_range.start += prefix_len;
877
878 let suffix_len = common_prefix(
879 snapshot.reversed_chars_for_range(old_range.clone()),
880 new_text[prefix_len..].chars().rev(),
881 );
882 old_range.end = old_range.end.saturating_sub(suffix_len);
883
884 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
885 let range = if old_range.is_empty() {
886 let anchor = snapshot.anchor_after(old_range.start);
887 anchor..anchor
888 } else {
889 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
890 };
891 (range, new_text)
892 })
893 .collect()
894 }
895
896 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
897 self.rated_completions.contains(&completion_id)
898 }
899
900 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
901 self.shown_completions.push_front(completion.clone());
902 if self.shown_completions.len() > 50 {
903 let completion = self.shown_completions.pop_back().unwrap();
904 self.rated_completions.remove(&completion.id);
905 }
906 cx.notify();
907 }
908
909 pub fn rate_completion(
910 &mut self,
911 completion: &EditPrediction,
912 rating: EditPredictionRating,
913 feedback: String,
914 cx: &mut Context<Self>,
915 ) {
916 self.rated_completions.insert(completion.id);
917 telemetry::event!(
918 "Edit Prediction Rated",
919 rating,
920 input_events = completion.input_events,
921 input_excerpt = completion.input_excerpt,
922 input_outline = completion.input_outline,
923 output_excerpt = completion.output_excerpt,
924 feedback
925 );
926 self.client.telemetry().flush_events().detach();
927 cx.notify();
928 }
929
930 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
931 self.shown_completions.iter()
932 }
933
934 pub fn shown_completions_len(&self) -> usize {
935 self.shown_completions.len()
936 }
937
938 fn report_changes_for_buffer(
939 &mut self,
940 buffer: &Entity<Buffer>,
941 project: &Entity<Project>,
942 cx: &mut Context<Self>,
943 ) -> BufferSnapshot {
944 let zeta_project = self.get_or_init_zeta_project(project, cx);
945 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
946
947 let new_snapshot = buffer.read(cx).snapshot();
948 if new_snapshot.version != registered_buffer.snapshot.version {
949 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
950 Self::push_event(
951 zeta_project,
952 Event::BufferChange {
953 old_snapshot,
954 new_snapshot: new_snapshot.clone(),
955 timestamp: Instant::now(),
956 },
957 );
958 }
959
960 new_snapshot
961 }
962
963 fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
964 self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
965 }
966
967 fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
968 if !self.data_collection_choice.is_enabled() {
969 return false;
970 }
971 let mut last_checked_file = None;
972 for event in events {
973 match event {
974 Event::BufferChange {
975 old_snapshot,
976 new_snapshot,
977 ..
978 } => {
979 if let Some(old_file) = old_snapshot.file()
980 && let Some(new_file) = new_snapshot.file()
981 {
982 if let Some(last_checked_file) = last_checked_file
983 && Arc::ptr_eq(last_checked_file, old_file)
984 && Arc::ptr_eq(last_checked_file, new_file)
985 {
986 continue;
987 }
988 if !self.can_collect_file(old_file, cx) {
989 return false;
990 }
991 if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
992 {
993 return false;
994 }
995 last_checked_file = Some(new_file);
996 } else {
997 return false;
998 }
999 }
1000 }
1001 }
1002 true
1003 }
1004
1005 fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1006 if !file.is_local() || file.is_private() {
1007 return false;
1008 }
1009 self.license_detection_watchers
1010 .get(&file.worktree_id(cx))
1011 .is_some_and(|watcher| watcher.is_project_open_source())
1012 }
1013
1014 fn load_data_collection_choice() -> DataCollectionChoice {
1015 let choice = KEY_VALUE_STORE
1016 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1017 .log_err()
1018 .flatten();
1019
1020 match choice.as_deref() {
1021 Some("true") => DataCollectionChoice::Enabled,
1022 Some("false") => DataCollectionChoice::Disabled,
1023 Some(_) => {
1024 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1025 DataCollectionChoice::NotAnswered
1026 }
1027 None => DataCollectionChoice::NotAnswered,
1028 }
1029 }
1030
1031 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1032 self.data_collection_choice = self.data_collection_choice.toggle();
1033 let new_choice = self.data_collection_choice;
1034 db::write_and_log(cx, move || {
1035 KEY_VALUE_STORE.write_kvp(
1036 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1037 new_choice.is_enabled().to_string(),
1038 )
1039 });
1040 }
1041}
1042
/// Inputs handed to a predict-edits implementation: `Zeta::perform_predict_edits`
/// for real requests, or a canned responder in `Zeta::fake_completion`.
pub struct PerformPredictEditsParams {
    pub client: Arc<Client>,
    /// Source of the bearer token used for the `Authorization` header.
    pub llm_token: LlmApiToken,
    /// Sent in the version header and compared against the server's minimum.
    pub app_version: SemanticVersion,
    /// Serialized as JSON for the request body.
    pub body: PredictEditsBody,
}

/// Returned when the server reports a minimum required Zed version newer than
/// the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1057
/// Returns the length, in UTF-8 bytes, of the longest common prefix of the
/// two character streams.
///
/// Stops at the first mismatch or when either stream is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1064
1065fn git_info_for_file(
1066 project: &Entity<Project>,
1067 project_path: &ProjectPath,
1068 cx: &App,
1069) -> Option<PredictEditsGitInfo> {
1070 let git_store = project.read(cx).git_store().read(cx);
1071 if let Some((repository, _repo_path)) =
1072 git_store.repository_and_path_for_project_path(project_path, cx)
1073 {
1074 let repository = repository.read(cx);
1075 let head_sha = repository
1076 .head_commit
1077 .as_ref()
1078 .map(|head_commit| head_commit.sha.to_string());
1079 let remote_origin_url = repository.remote_origin_url.clone();
1080 let remote_upstream_url = repository.remote_upstream_url.clone();
1081 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1082 return None;
1083 }
1084 Some(PredictEditsGitInfo {
1085 head_sha,
1086 remote_origin_url,
1087 remote_upstream_url,
1088 })
1089 } else {
1090 None
1091 }
1092}
1093
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// Request body with the events and excerpt filled in. `can_collect_data`
    /// is `false` and `git_info`, `outline` and the other optional fields are
    /// `None`; the caller decides whether to fill them.
    pub body: PredictEditsBody,
    /// Offset range of the editable region within the snapshot.
    pub editable_range: Range<usize>,
    /// How many of the most recent events were included in `body.input_events`.
    pub included_events_count: usize,
}
1099
1100pub fn gather_context(
1101 full_path_str: String,
1102 snapshot: &BufferSnapshot,
1103 cursor_point: language::Point,
1104 prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1105 cx: &App,
1106) -> Task<Result<GatherContextOutput>> {
1107 cx.background_spawn({
1108 let snapshot = snapshot.clone();
1109 async move {
1110 let input_excerpt = excerpt_for_cursor_position(
1111 cursor_point,
1112 &full_path_str,
1113 &snapshot,
1114 MAX_REWRITE_TOKENS,
1115 MAX_CONTEXT_TOKENS,
1116 );
1117 let (input_events, included_events_count) = prompt_for_events();
1118 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1119
1120 let body = PredictEditsBody {
1121 input_events,
1122 input_excerpt: input_excerpt.prompt,
1123 can_collect_data: false,
1124 diagnostic_groups: None,
1125 git_info: None,
1126 outline: None,
1127 speculated_output: None,
1128 };
1129
1130 Ok(GatherContextOutput {
1131 body,
1132 editable_range,
1133 included_events_count,
1134 })
1135 }
1136 })
1137}
1138
1139fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1140 let mut result = String::new();
1141 for (ix, event) in events.iter().rev().enumerate() {
1142 let event_string = event.to_prompt();
1143 let event_tokens = guess_token_count(event_string.len());
1144 if event_tokens > remaining_tokens {
1145 return (result, ix);
1146 }
1147
1148 if !result.is_empty() {
1149 result.insert_str(0, "\n\n");
1150 }
1151 result.insert_str(0, &event_string);
1152 remaining_tokens -= event_tokens;
1153 }
1154 return (result, events.len());
1155}
1156
/// Per-buffer state kept for a buffer registered with Zeta.
struct RegisteredBuffer {
    // NOTE(review): presumably the most recently observed snapshot, used as the
    // "old" side of the next `Event::BufferChange` — confirm where it is updated.
    snapshot: BufferSnapshot,
    // Held only to keep both subscriptions alive; never read directly.
    _subscriptions: [gpui::Subscription; 2],
}
1161
/// An edit-history entry that can be rendered into the model prompt
/// (see `Event::to_prompt`).
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        // NOTE(review): presumably when the change was recorded — confirm at
        // the construction site. Not used by `to_prompt`.
        timestamp: Instant,
    },
}
1170
1171impl Event {
1172 fn to_prompt(&self) -> String {
1173 match self {
1174 Event::BufferChange {
1175 old_snapshot,
1176 new_snapshot,
1177 ..
1178 } => {
1179 let mut prompt = String::new();
1180
1181 let old_path = old_snapshot
1182 .file()
1183 .map(|f| f.path().as_ref())
1184 .unwrap_or(RelPath::unix("untitled").unwrap());
1185 let new_path = new_snapshot
1186 .file()
1187 .map(|f| f.path().as_ref())
1188 .unwrap_or(RelPath::unix("untitled").unwrap());
1189 if old_path != new_path {
1190 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1191 }
1192
1193 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1194 if !diff.is_empty() {
1195 write!(
1196 prompt,
1197 "User edited {:?}:\n```diff\n{}\n```",
1198 new_path, diff
1199 )
1200 .unwrap();
1201 }
1202
1203 prompt
1204 }
1205 }
1206 }
1207}
1208
/// The prediction currently on offer, tagged with the buffer it was requested for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction returned by Zeta.
    completion: EditPrediction,
}
1214
1215impl CurrentEditPrediction {
1216 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1217 if self.buffer_id != old_completion.buffer_id {
1218 return true;
1219 }
1220
1221 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1222 return true;
1223 };
1224 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1225 return false;
1226 };
1227
1228 if old_edits.len() == 1 && new_edits.len() == 1 {
1229 let (old_range, old_text) = &old_edits[0];
1230 let (new_range, new_text) = &new_edits[0];
1231 new_range == old_range && new_text.starts_with(old_text)
1232 } else {
1233 true
1234 }
1235 }
1236}
1237
/// An in-flight prediction request tracked by `ZetaEditPredictionProvider`.
struct PendingCompletion {
    /// Sequence number assigned when the request was issued.
    id: usize,
    // Owns the request task; removing this entry drops the task.
    _task: Task<()>,
}
1242
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has chosen either option.
    pub fn is_answered(self) -> bool {
        matches!(self, Self::Enabled | Self::Disabled)
    }

    /// Returns the opposite choice. An unanswered choice toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::NotAnswered | Self::Disabled => Self::Enabled,
        }
    }
}

/// `true` maps to `Enabled` and `false` to `Disabled`.
impl From<bool> for DataCollectionChoice {
    fn from(enabled: bool) -> Self {
        if enabled {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1283
1284async fn llm_token_retry(
1285 llm_token: &LlmApiToken,
1286 client: &Arc<Client>,
1287 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1288) -> Result<Response<AsyncBody>> {
1289 let mut did_retry = false;
1290 let http_client = client.http_client();
1291 let mut token = llm_token.acquire(client).await?;
1292 loop {
1293 let request = build_request(token.clone())?;
1294 let response = http_client.send(request).await?;
1295
1296 if !did_retry
1297 && !response.status().is_success()
1298 && response
1299 .headers()
1300 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1301 .is_some()
1302 {
1303 did_retry = true;
1304 token = llm_token.refresh(client).await?;
1305 continue;
1306 }
1307
1308 return Ok(response);
1309 }
1310}
1311
/// Editor-facing edit-prediction provider that requests predictions from a
/// `Zeta` entity and tracks the prediction currently on offer.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    // Buffer whose file determines the reported data-collection state; `None`
    // presumably when the editor is not showing a single buffer — confirm.
    singleton_buffer: Option<Entity<Buffer>>,
    // At most two in-flight requests, oldest at index 0.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    // Id handed to the next request issued by `refresh`.
    next_pending_completion_id: usize,
    current_completion: Option<CurrentEditPrediction>,
    // When the most recent request was sent; used to throttle `refresh`.
    last_request_timestamp: Instant,
    project: Entity<Project>,
}
1321
1322impl ZetaEditPredictionProvider {
1323 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1324
1325 pub fn new(
1326 zeta: Entity<Zeta>,
1327 project: Entity<Project>,
1328 singleton_buffer: Option<Entity<Buffer>>,
1329 ) -> Self {
1330 Self {
1331 zeta,
1332 singleton_buffer,
1333 pending_completions: ArrayVec::new(),
1334 next_pending_completion_id: 0,
1335 current_completion: None,
1336 last_request_timestamp: Instant::now(),
1337 project,
1338 }
1339 }
1340}
1341
/// Exposes Zeta predictions to the editor through the generic
/// `EditPredictionProvider` interface.
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Stable identifier of this provider.
    fn name() -> &'static str {
        "zed-predict"
    }

    /// Human-readable provider name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is on and whether the project looks
    /// open source, based on the singleton buffer's file.
    ///
    /// Without a singleton buffer that has a file, this reports "disabled" with
    /// `is_project_open_source: false`.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        if let Some(buffer) = &self.singleton_buffer
            && let Some(file) = buffer.read(cx).file()
        {
            let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
            if self.zeta.read(cx).data_collection_choice.is_enabled() {
                DataCollectionState::Enabled {
                    is_project_open_source,
                }
            } else {
                DataCollectionState::Disabled {
                    is_project_open_source,
                }
            }
        } else {
            return DataCollectionState::Disabled {
                is_project_open_source: false,
            };
        }
    }

    /// Flips the user's data-collection choice on the shared `Zeta` entity.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.zeta
            .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
    }

    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// Predictions are offered for every buffer and position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }

    /// `true` while at least one prediction request is in flight.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Requests a new prediction for `position` in `buffer` when one is needed
    /// and allowed.
    ///
    /// Returns without requesting in three cases:
    /// - Zeta reports that a Zed update is required;
    /// - the user's account is too young or has overdue invoices;
    /// - the current prediction can still be interpolated onto the buffer's
    ///   latest snapshot.
    ///
    /// Otherwise it spawns a task that:
    /// 1. waits out the rest of `THROTTLE_TIMEOUT` since the last request;
    /// 2. asks Zeta for a prediction;
    /// 3. installs the result as the current prediction.
    ///
    /// `_debounce` is ignored.
    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction while it still applies to the buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let last_request_timestamp = self.last_request_timestamp;

        let project = self.project.clone();
        let task = cx.spawn(async move |this, cx| {
            // Throttle: sleep until THROTTLE_TIMEOUT has passed since the
            // previous request (no wait if it already has).
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(&project, &buffer, position, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged. An error or "no prediction" only retires this
            // request's pending entry.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this request is the oldest pending one, remove only it.
                    // Otherwise a newer request finished first, so drop all
                    // pending requests, including the older one.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                // The provider may already be gone; nothing to update then.
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same pending-entry bookkeeping as the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Install the new prediction unless the existing one should be
                // kept (see `should_replace_completion`).
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        // (Popping drops the replaced request's task.)
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current prediction to Zeta (the returned task
    /// is detached) and drops all pending requests.
    ///
    /// The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the group of edits nearest the cursor from the current
    /// prediction, interpolated onto the buffer's latest snapshot.
    ///
    /// The current prediction is cleared if it belongs to another buffer or no
    /// longer interpolates.
    ///
    /// Grouping works from the single edit closest to the cursor row (by the
    /// start or end row of its range). That edit is extended with immediately
    /// neighbouring edits that lie within one row of it. Distances are always
    /// measured from the closest edit, not chained from neighbour to neighbour.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        // Returns `None` (keeping the current prediction) when there are no edits.
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // NOTE(review): the row subtractions below assume `edits` are sorted and
        // non-overlapping; otherwise they would underflow. Confirm that
        // `interpolate` guarantees this.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction::Local {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1617
/// Assumed number of bytes per token when estimating how much text fits in
/// the model input.
///
/// The value is deliberately low, so token counts are overestimated and input
/// limits are respected conservatively.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1625
/// Tests for prediction interpolation, diff application, and data-collection
/// eligibility.
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use parking_lot::Mutex;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    use super::*;

    const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test(cx);

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
        );

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
        );
    }

    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        assert_eq!(
            apply_edit_prediction(buffer_content, completion_response, cx).await,
            "lorem\nipsum"
        );
    }

    #[gpui::test]
    async fn test_can_collect_data(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/src/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            true
        );

        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Disabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );
    }

    #[gpui::test]
    async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let buffer = cx.new(|_cx| {
            Buffer::remote(
                language::BufferId::new(1).unwrap(),
                1,
                language::Capability::ReadWrite,
                "fn main() {\n    println!(\"Hello\");\n}",
            )
        });

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );
    }

    #[gpui::test]
    async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/project"),
            json!({
                "LICENSE": BSD_0_TXT,
                ".env": "SECRET_KEY=secret"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        // Use `path!` so the path is valid on every platform (e.g. Windows).
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/.env"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );
    }

    #[gpui::test]
    async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;
        let buffer = cx.new(|cx| Buffer::local("", cx));

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );
    }

    #[gpui::test]
    async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        // Use `path!` so the path is valid on every platform (e.g. Windows).
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );
    }

    #[gpui::test]
    async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/open_source_worktree"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [
                path!("/open_source_worktree").as_ref(),
                path!("/closed_source_worktree").as_ref(),
            ],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            true
        );

        let closed_source_file = project
            .update(cx, |project, cx| {
                let worktree2 = project
                    .worktree_for_root_name("closed_source_worktree", cx)
                    .unwrap();
                worktree2.update(cx, |worktree2, cx| {
                    worktree2.load_file(rel_path("main.rs"), cx)
                })
            })
            .await
            .unwrap()
            .file;

        buffer.update(cx, |buffer, cx| {
            buffer.file_updated(closed_source_file, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );
    }

    #[gpui::test]
    async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/worktree1"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/worktree1/main.rs"), cx)
            })
            .await
            .unwrap();
        // Open the file the fixture actually defines in the closed-source worktree.
        let private_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/worktree2/private.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            true
        );

        // this has a side effect of registering the buffer to watch for edits
        run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );

        private_buffer.update(cx, |private_buffer, cx| {
            private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            false
        );

        // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
        // included
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
                None,
                cx,
            );
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert_eq!(
            captured_request.lock().clone().unwrap().can_collect_data,
            true
        );
    }

    /// Installs the test settings store and the settings the tests depend on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            client::init_settings(cx);
            Project::init_settings(cx);
        });
    }

    /// Runs one prediction against a buffer holding `buffer_content`, with the
    /// fake server answering `completion_response`, applies the predicted edits,
    /// and returns the resulting text.
    async fn apply_edit_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> String {
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let (zeta, _, response) = make_test_zeta(&project, cx).await;
        *response.lock() = completion_response.to_string();
        let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
        buffer.update(cx, |buffer, cx| {
            buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
        });
        buffer.read_with(cx, |buffer, _| buffer.text())
    }

    /// Registers `buffer` with `zeta` and requests a prediction at row 1,
    /// column 0, panicking if the request fails or yields nothing.
    async fn run_edit_prediction(
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        zeta: &Entity<Zeta>,
        cx: &mut TestAppContext,
    ) -> EditPrediction {
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, project, cx));
        cx.background_executor.run_until_parked();
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(project, buffer, cursor, cx)
        });
        completion_task.await.unwrap().unwrap()
    }

    /// Builds a `Zeta` backed by a fake HTTP client.
    ///
    /// The fake client issues LLM tokens, records each predict-edits request
    /// body, and answers with the current contents of the returned response
    /// slot. License watchers are installed for every worktree in `project`.
    ///
    /// Returns the entity, the captured-request slot, and the response slot.
    async fn make_test_zeta(
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        Arc<Mutex<Option<PredictEditsBody>>>,
        Arc<Mutex<String>>,
    ) {
        let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
        };
        let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
        let completion_response: Arc<Mutex<String>> =
            Arc::new(Mutex::new(default_response.to_string()));
        let http_client = FakeHttpClient::create({
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            move |req| {
                let captured_request = captured_request.clone();
                let completion_response = completion_response.clone();
                async move {
                    match (req.method(), req.uri().path()) {
                        (&Method::POST, "/client/llm_tokens") => {
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&CreateLlmTokenResponse {
                                        token: LlmToken("the-llm-token".to_string()),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        (&Method::POST, "/predict_edits/v2") => {
                            let mut request_body = String::new();
                            req.into_body().read_to_string(&mut request_body).await?;
                            *captured_request.lock() =
                                Some(serde_json::from_str(&request_body).unwrap());
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&PredictEditsResponse {
                                        request_id: Uuid::new_v4(),
                                        output_excerpt: completion_response.lock().clone(),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        _ => Ok(http_client::Response::builder()
                            .status(404)
                            .body("Not Found".into())
                            .unwrap()),
                    }
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let _server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new(|cx| {
            let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);

            let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
            for worktree in worktrees {
                let worktree_id = worktree.read(cx).id();
                zeta.license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
            }

            zeta
        });

        (zeta, captured_request, completion_response)
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset-based edits for `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}