1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13use editor::Editor;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16use project::git_store::Repository;
17pub use rate_completion_modal::*;
18
19use anyhow::{Context as _, Result, anyhow};
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
23 PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
24 PredictEditsResponse, ZED_VERSION_HEADER_NAME,
25};
26use collections::{HashMap, HashSet, VecDeque};
27use futures::AsyncReadExt;
28use gpui::{
29 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
30 Subscription, Task, WeakEntity, actions,
31};
32use http_client::{AsyncBody, HttpClient, Method, Request, Response};
33use input_excerpt::excerpt_for_cursor_position;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
36};
37use language_model::{LlmApiToken, RefreshLlmTokenListener};
38use multi_buffer::MultiBufferPoint;
39use project::{Project, ProjectPath};
40use release_channel::AppVersion;
41use settings::WorktreeId;
42use std::str::FromStr;
43use std::{
44 cmp,
45 fmt::Write,
46 future::Future,
47 mem,
48 ops::Range,
49 path::Path,
50 rc::Rc,
51 sync::Arc,
52 time::{Duration, Instant},
53};
54use telemetry_events::EditPredictionRating;
55use thiserror::Error;
56use util::{ResultExt, maybe};
57use uuid::Uuid;
58use workspace::Workspace;
59use workspace::notifications::{ErrorMessagePrompt, NotificationId};
60use worktree::Worktree;
61
/// Cursor marker token; removed from the model output in `parse_edits`.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `parse_edits` rejects output containing more than one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in model output; `parse_edits` requires exactly one.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region in model output; `parse_edits` requires exactly one.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer that arrive within this interval are
/// coalesced into a single event (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the user's data collection choice. Its presence
/// also counts as having dismissed the upsell (see `ZedPredictUpsell::dismissed`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets for prompt construction.
// NOTE(review): MAX_CONTEXT_TOKENS and MAX_REWRITE_TOKENS are not used in this
// part of the file; confirm their exact role where they are referenced.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the events section of the prompt (passed to `prompt_for_events`).
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track. When reached, the oldest half is dropped.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum time between recent file entries.
const MIN_TIME_BETWEEN_RECENT_FILES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of JSON bytes for diagnostics in additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// Maximum number of edit predictions to store for feedback. When full, the
/// oldest one is evicted (see `Zeta::completion_shown`).
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

/// Interval between activity polls (used to track time spent editing files)
/// while data collection is enabled.
const ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(10);

/// Interval between activity polls while data collection is disabled. The loop
/// keeps polling in this state so it notices when collection gets enabled.
const DISABLED_ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(60 * 5);

// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
104
105#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
106pub struct EditPredictionId(Uuid);
107
108impl From<EditPredictionId> for gpui::ElementId {
109 fn from(value: EditPredictionId) -> Self {
110 gpui::ElementId::Uuid(value.0)
111 }
112}
113
114impl std::fmt::Display for EditPredictionId {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 write!(f, "{}", self.0)
117 }
118}
119
120struct ZedPredictUpsell;
121
122impl Dismissable for ZedPredictUpsell {
123 const KEY: &'static str = "dismissed-edit-predict-upsell";
124
125 fn dismissed() -> bool {
126 // To make this backwards compatible with older versions of Zed, we
127 // check if the user has seen the previous Edit Prediction Onboarding
128 // before, by checking the data collection choice which was written to
129 // the database once the user clicked on "Accept and Enable"
130 if KEY_VALUE_STORE
131 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
132 .log_err()
133 .is_some_and(|s| s.is_some())
134 {
135 return true;
136 }
137
138 KEY_VALUE_STORE
139 .read_kvp(Self::KEY)
140 .log_err()
141 .is_some_and(|s| s.is_some())
142 }
143}
144
145pub fn should_show_upsell_modal() -> bool {
146 !ZedPredictUpsell::dismissed()
147}
148
/// App-global slot holding the shared `Zeta` entity; set by `Zeta::register`
/// and read by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
153
/// A processed edit prediction, together with the request/response data kept
/// for display and feedback.
#[derive(Clone)]
pub struct EditPrediction {
    /// Request id returned by the prediction service.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region in the buffer as it was when the
    /// request was built.
    excerpt_range: Range<usize>,
    /// Cursor offset in the buffer as it was when the request was built.
    cursor_offset: usize,
    /// Predicted edits, already interpolated onto `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto while the response was
    /// being processed.
    snapshot: BufferSnapshot,
    /// Preview of `edits` applied to the buffer.
    edit_preview: EditPreview,
    /// Events section of the request body sent to the model.
    input_events: Arc<str>,
    /// Excerpt section of the request body sent to the model.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// Despite the name, this is set after the response has been parsed and
    /// the edit preview computed, not when the HTTP response arrived.
    response_received_at: Instant,
}
169
170impl EditPrediction {
171 fn latency(&self) -> Duration {
172 self.response_received_at
173 .duration_since(self.buffer_snapshotted_at)
174 }
175
176 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
177 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
178 }
179}
180
/// Re-bases model-predicted `current_edits` (created against `old_snapshot`)
/// onto `new_snapshot`, taking into account edits the user made in between.
///
/// Model edits that end before a user edit starts are kept unchanged. Every
/// user edit must replace exactly the same old range as the next pending model
/// edit, with text that is a prefix of the model's replacement. In that case
/// only the rest of the model's text is kept, inserted after the user's edit.
/// Any other user edit invalidates the whole prediction.
///
/// Returns `None` if the prediction was invalidated or if no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep, unchanged, every model edit that ends before this user edit starts.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit must replace exactly the next model edit's range...
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // ...with a prefix of the model's text; keep only what the user
                // has not typed yet, inserted after the user's edit.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit invalidates the prediction.
        return None;
    }

    // Model edits after the last user edit are kept unchanged.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
226
227impl std::fmt::Debug for EditPrediction {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 f.debug_struct("EditPrediction")
230 .field("id", &self.id)
231 .field("path", &self.path)
232 .field("edits", &self.edits)
233 .finish_non_exhaustive()
234 }
235}
236
/// Edit prediction ("Zeta") state shared across the app: event history,
/// registered buffers, shown/rated predictions, and request credentials.
pub struct Zeta {
    /// Workspace this instance was created with; an invalid weak handle when
    /// created without one.
    workspace: WeakEntity<Workspace>,
    client: Arc<Client>,
    /// Recent buffer-change events used to build prompts, capped at
    /// `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers observed for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions shown to the user, newest first, capped at
    /// `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions that have been rated.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: Entity<DataCollectionChoice>,
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the global refresh listener emits.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license detection watcher per worktree, created in `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // NOTE(review): maintained outside this chunk; presumably the recently
    // active editors used for activity tracking — confirm.
    recent_editors: VecDeque<RecentEditor>,
    // NOTE(review): presumably the state captured at the previous activity
    // poll — confirm in `handle_activity_poll`.
    last_activity_state: Option<ActivityState>,
    /// Activity poll loop; only spawned when a workspace is present.
    _activity_poll_task: Option<Task<Result<()>>>,
}
255
/// Bookkeeping for one entry of `Zeta::recent_editors`.
///
/// NOTE(review): this struct is populated outside this chunk. The field notes
/// below are inferred from names only — confirm against the code that updates them.
struct RecentEditor {
    editor: WeakEntity<Editor>,
    // Presumably when this editor was last active.
    last_active_at: Instant,
    // Presumably how many times this editor has been activated.
    activation_count: u32,
    // Presumably accumulated time spent editing vs. navigating in this editor.
    cumulative_time_editing: Duration,
    cumulative_time_navigating: Duration,
}
263
/// Editor state captured for activity tracking (stored in
/// `Zeta::last_activity_state`).
///
/// NOTE(review): presumably compared between polls to tell editing from
/// navigation — confirm in `handle_activity_poll`.
#[derive(Debug)]
struct ActivityState {
    scroll_position: gpui::Point<f32>,
    cursor_point: MultiBufferPoint,
    // Presumably the buffer version when the multibuffer is a singleton, else `None`.
    singleton_version: Option<clock::Global>,
}
270
271impl Zeta {
272 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
273 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
274 }
275
276 pub fn register(
277 workspace: Option<Entity<Workspace>>,
278 worktree: Option<Entity<Worktree>>,
279 client: Arc<Client>,
280 user_store: Entity<UserStore>,
281 cx: &mut App,
282 ) -> Entity<Self> {
283 let this = Self::global(cx).unwrap_or_else(|| {
284 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
285 cx.set_global(ZetaGlobal(entity.clone()));
286 entity
287 });
288
289 this.update(cx, move |this, cx| {
290 if let Some(worktree) = worktree {
291 let worktree_id = worktree.read(cx).id();
292 this.license_detection_watchers
293 .entry(worktree_id)
294 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
295 }
296 });
297
298 this
299 }
300
301 pub fn clear_history(&mut self) {
302 self.events.clear();
303 }
304
305 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
306 self.user_store.read(cx).edit_prediction_usage()
307 }
308
309 fn new(
310 workspace: Option<Entity<Workspace>>,
311 client: Arc<Client>,
312 user_store: Entity<UserStore>,
313 cx: &mut Context<Self>,
314 ) -> Self {
315 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
316
317 let data_collection_choice = Self::load_data_collection_choices();
318 let data_collection_choice = cx.new(|_| data_collection_choice);
319
320 let mut activity_poll_task = None;
321
322 if let Some(workspace) = &workspace {
323 let project = workspace.read(cx).project().clone();
324 cx.subscribe(&project, |this, _project, event, cx| match event {
325 project::Event::ActiveEntryChanged(entry_id) => {
326 this.handle_active_project_entry_changed(cx)
327 }
328 _ => {}
329 })
330 .detach();
331
332 // TODO: ideally this would attend to window focus when tracking time, and pause the
333 // loop for efficiency when not focused.
334 activity_poll_task = Some(cx.spawn(async move |this, cx| {
335 let mut instant_before_delay = None;
336 loop {
337 let data_collection_is_enabled = this.read_with(cx, |this, cx| {
338 this.data_collection_choice.read(cx).is_enabled()
339 })?;
340 let interval = if data_collection_is_enabled {
341 ACTIVITY_POLL_INTERVAL
342 } else {
343 instant_before_delay = None;
344 DISABLED_ACTIVITY_POLL_INTERVAL
345 };
346 cx.background_executor().timer(interval).await;
347 this.update(cx, |this, cx| {
348 let now = Instant::now();
349 this.handle_activity_poll(instant_before_delay, now, cx);
350 instant_before_delay = Some(now);
351 })?
352 }
353 }));
354 }
355
356 Self {
357 workspace: workspace.map_or_else(
358 || WeakEntity::new_invalid(),
359 |workspace| workspace.downgrade(),
360 ),
361 client,
362 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
363 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
364 rated_completions: HashSet::default(),
365 registered_buffers: HashMap::default(),
366 data_collection_choice,
367 llm_token: LlmApiToken::default(),
368 _llm_token_subscription: cx.subscribe(
369 &refresh_llm_token_listener,
370 |this, _listener, _event, cx| {
371 let client = this.client.clone();
372 let llm_token = this.llm_token.clone();
373 cx.spawn(async move |_this, _cx| {
374 llm_token.refresh(&client).await?;
375 anyhow::Ok(())
376 })
377 .detach_and_log_err(cx);
378 },
379 ),
380 update_required: false,
381 license_detection_watchers: HashMap::default(),
382 user_store,
383 recent_editors: VecDeque::new(),
384 last_activity_state: None,
385 _activity_poll_task: activity_poll_task,
386 }
387 }
388
389 fn push_event(&mut self, event: Event) {
390 if let Some(Event::BufferChange {
391 new_snapshot: last_new_snapshot,
392 timestamp: last_timestamp,
393 ..
394 }) = self.events.back_mut()
395 {
396 // Coalesce edits for the same buffer when they happen one after the other.
397 let Event::BufferChange {
398 old_snapshot,
399 new_snapshot,
400 timestamp,
401 } = &event;
402
403 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
404 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
405 && old_snapshot.version == last_new_snapshot.version
406 {
407 *last_new_snapshot = new_snapshot.clone();
408 *last_timestamp = *timestamp;
409 return;
410 }
411 }
412
413 if self.events.len() >= MAX_EVENT_COUNT {
414 // These are halved instead of popping to improve prompt caching.
415 self.events.drain(..MAX_EVENT_COUNT / 2);
416 }
417
418 self.events.push_back(event);
419 }
420
421 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
422 let buffer_id = buffer.entity_id();
423 let weak_buffer = buffer.downgrade();
424
425 if let std::collections::hash_map::Entry::Vacant(entry) =
426 self.registered_buffers.entry(buffer_id)
427 {
428 let snapshot = buffer.read(cx).snapshot();
429
430 entry.insert(RegisteredBuffer {
431 snapshot,
432 _subscriptions: [
433 cx.subscribe(buffer, move |this, buffer, event, cx| {
434 this.handle_buffer_event(buffer, event, cx);
435 }),
436 cx.observe_release(buffer, move |this, _buffer, _cx| {
437 this.registered_buffers.remove(&weak_buffer.entity_id());
438 }),
439 ],
440 });
441 };
442 }
443
444 fn handle_buffer_event(
445 &mut self,
446 buffer: Entity<Buffer>,
447 event: &language::BufferEvent,
448 cx: &mut Context<Self>,
449 ) {
450 if let language::BufferEvent::Edited = event {
451 self.report_changes_for_buffer(&buffer, cx);
452 }
453 }
454
    /// Shared implementation of `request_completion` and `fake_completion`.
    ///
    /// Steps:
    /// 1. Snapshots the buffer.
    /// 2. Gathers the prompt context on the background executor.
    /// 3. Sends the request through `perform_predict_edits`.
    /// 4. Turns the response into an `EditPrediction` via
    ///    `process_completion_response`.
    ///
    /// Resolves to `Ok(None)` when the prediction no longer applies to the
    /// buffer.
    ///
    /// Additional context is gathered concurrently when `can_collect_data` is
    /// `CanCollectData(true)` and the buffer has a file. It is reported in an
    /// "Edit Prediction Additional Context" telemetry event after the request
    /// completes.
    ///
    /// If the request fails with `ZedUpdateRequiredError`, `update_required` is
    /// set. When `workspace` is given, an "Update Zed" notification is also
    /// shown before the error is returned.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        let buffer_snapshotted_at = Instant::now();
        // The returned snapshot is what the whole request is built from.
        // NOTE(review): presumably this also records pending changes into
        // `events` (cloned just below) — confirm in `report_changes_for_buffer`.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        // Strong handle, used below to set `update_required` on version errors.
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Buffers without a file are reported as "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Handed to `gather_context`, which runs on the background executor.
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = cx.background_spawn(gather_context(
            full_path_str,
            snapshot.clone(),
            cursor_point,
            make_events_prompt,
            can_collect_data,
        ));

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            // Additional context is only gathered when data collection is allowed
            // and the buffer has a file that maps to a project path.
            let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
                && let Some(file) = snapshot.file()
                && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
            {
                // This is async to reduce latency of the edit predictions request. The downside is
                // that it will see a slightly later state than was used when gathering context.
                let snapshot = snapshot.clone();
                let this = this.clone();
                Some(cx.spawn(async move |cx| {
                    if let Ok(Some(task)) = this.update(cx, |this, cx| {
                        this.gather_additional_context(
                            cursor_point,
                            cursor_offset,
                            snapshot,
                            &buffer_snapshotted_at,
                            project_path,
                            project.as_ref(),
                            cx,
                        )
                    }) {
                        Some(task.await)
                    } else {
                        None
                    }
                }))
            } else {
                None
            };

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the prompt parts; `body` is moved into the request.
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // Remember that an update is required and, if possible, tell the user.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Forward any usage information from the response to the user store.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let request_id = response.request_id.clone();
            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // Record latency for ~1% of requests (3 of the 256 possible `u8` values).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            // Report additional context in a detached task so the prediction is
            // returned without waiting for it.
            if let Some(additional_context_task) = additional_context_task {
                cx.background_spawn(async move {
                    if let Some(additional_context) = additional_context_task.await {
                        telemetry::event!(
                            "Edit Prediction Additional Context",
                            request_id = request_id,
                            additional_context = additional_context
                        );
                    }
                })
                .detach();
            }

            edit_prediction
        })
    }
641
    /// Generates several example completions in various states to fill the
    /// Zeta completion-rating modal (test support only).
    ///
    /// Each fake response is sent through the normal request pipeline via
    /// `fake_completion`. Afterwards, the edits of the entries at indices 2 and
    /// 3 of `shown_completions` are cleared to simulate empty predictions.
    ///
    /// NOTE(review): neither this method nor `request_completion_impl` pushes
    /// into `shown_completions` (the awaited results are discarded). The
    /// `get_mut(2)` / `get_mut(3)` unwraps below therefore panic unless at
    /// least four completions were recorded elsewhere, e.g. via
    /// `completion_shown`. Confirm before relying on this.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    "),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await every fake request; the resulting predictions are discarded.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Simulate two predictions without edits.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
793
794 #[cfg(any(test, feature = "test-support"))]
795 pub fn fake_completion(
796 &mut self,
797 project: Option<Entity<Project>>,
798 buffer: &Entity<Buffer>,
799 position: language::Anchor,
800 response: PredictEditsResponse,
801 cx: &mut Context<Self>,
802 ) -> Task<Result<Option<EditPrediction>>> {
803 use std::future::ready;
804
805 self.request_completion_impl(
806 None,
807 project,
808 buffer,
809 position,
810 CanCollectData(false),
811 cx,
812 |_params| ready(Ok((response, None))),
813 )
814 }
815
816 pub fn request_completion(
817 &mut self,
818 project: Option<Entity<Project>>,
819 buffer: &Entity<Buffer>,
820 position: language::Anchor,
821 can_collect_data: CanCollectData,
822 cx: &mut Context<Self>,
823 ) -> Task<Result<Option<EditPrediction>>> {
824 self.request_completion_impl(
825 self.workspace.upgrade(),
826 project,
827 buffer,
828 position,
829 can_collect_data,
830 cx,
831 Self::perform_predict_edits,
832 )
833 }
834
835 pub fn perform_predict_edits(
836 params: PerformPredictEditsParams,
837 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
838 async move {
839 let PerformPredictEditsParams {
840 client,
841 llm_token,
842 app_version,
843 body,
844 ..
845 } = params;
846
847 let http_client = client.http_client();
848 let mut token = llm_token.acquire(&client).await?;
849 let mut did_retry = false;
850
851 loop {
852 let request_builder = http_client::Request::builder().method(Method::POST);
853 let request_builder =
854 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
855 request_builder.uri(predict_edits_url)
856 } else {
857 request_builder.uri(
858 http_client
859 .build_zed_llm_url("/predict_edits/v2", &[])?
860 .as_ref(),
861 )
862 };
863 let request = request_builder
864 .header("Content-Type", "application/json")
865 .header("Authorization", format!("Bearer {}", token))
866 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
867 .body(serde_json::to_string(&body)?.into())?;
868
869 let mut response = http_client.send(request).await?;
870
871 if let Some(minimum_required_version) = response
872 .headers()
873 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
874 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
875 {
876 anyhow::ensure!(
877 app_version >= minimum_required_version,
878 ZedUpdateRequiredError {
879 minimum_version: minimum_required_version
880 }
881 );
882 }
883
884 if response.status().is_success() {
885 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
886
887 let mut body = String::new();
888 response.body_mut().read_to_string(&mut body).await?;
889 return Ok((serde_json::from_str(&body)?, usage));
890 } else if !did_retry
891 && response
892 .headers()
893 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
894 .is_some()
895 {
896 did_retry = true;
897 token = llm_token.refresh(&client).await?;
898 } else {
899 let mut body = String::new();
900 response.body_mut().read_to_string(&mut body).await?;
901 anyhow::bail!(
902 "error predicting edits.\nStatus: {:?}\nBody: {}",
903 response.status(),
904 body
905 );
906 }
907 }
908 }
909 }
910
911 fn accept_edit_prediction(
912 &mut self,
913 request_id: EditPredictionId,
914 cx: &mut Context<Self>,
915 ) -> Task<Result<()>> {
916 let client = self.client.clone();
917 let llm_token = self.llm_token.clone();
918 let app_version = AppVersion::global(cx);
919 cx.spawn(async move |this, cx| {
920 let http_client = client.http_client();
921 let mut response = llm_token_retry(&llm_token, &client, |token| {
922 let request_builder = http_client::Request::builder().method(Method::POST);
923 let request_builder =
924 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
925 request_builder.uri(accept_prediction_url)
926 } else {
927 request_builder.uri(
928 http_client
929 .build_zed_llm_url("/predict_edits/accept", &[])?
930 .as_ref(),
931 )
932 };
933 Ok(request_builder
934 .header("Content-Type", "application/json")
935 .header("Authorization", format!("Bearer {}", token))
936 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
937 .body(
938 serde_json::to_string(&AcceptEditPredictionBody {
939 request_id: request_id.0,
940 })?
941 .into(),
942 )?)
943 })
944 .await?;
945
946 if let Some(minimum_required_version) = response
947 .headers()
948 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
949 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
950 && app_version < minimum_required_version
951 {
952 return Err(anyhow!(ZedUpdateRequiredError {
953 minimum_version: minimum_required_version
954 }));
955 }
956
957 if response.status().is_success() {
958 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
959 this.update(cx, |this, cx| {
960 this.user_store.update(cx, |user_store, cx| {
961 user_store.update_edit_prediction_usage(usage, cx);
962 });
963 })?;
964 }
965
966 Ok(())
967 } else {
968 let mut body = String::new();
969 response.body_mut().read_to_string(&mut body).await?;
970 Err(anyhow!(
971 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
972 response.status(),
973 body
974 ))
975 }
976 })
977 }
978
    /// Turns a raw prediction response into an `EditPrediction`.
    ///
    /// Steps:
    /// 1. Parses the edits from the response's output excerpt on the background
    ///    executor.
    /// 2. Interpolates them onto the buffer's current snapshot, which may have
    ///    changed since `snapshot` was taken.
    /// 3. Computes an edit preview.
    ///
    /// Resolves to `Ok(None)` when the intervening user edits invalidate the
    /// prediction or no edits remain. Parse errors are returned as `Err`.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse and diff on the background executor.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-base the edits onto the buffer's current contents. Bail out with
            // `Ok(None)` if they no longer apply. The block clones `edits` so the
            // closure can take ownership of its own copy.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                // Taken after parsing and preview, not when the HTTP response arrived.
                response_received_at: Instant::now(),
            }))
        })
    }
1038
1039 fn parse_edits(
1040 output_excerpt: Arc<str>,
1041 editable_range: Range<usize>,
1042 snapshot: &BufferSnapshot,
1043 ) -> Result<Vec<(Range<Anchor>, String)>> {
1044 let content = output_excerpt.replace(CURSOR_MARKER, "");
1045
1046 let start_markers = content
1047 .match_indices(EDITABLE_REGION_START_MARKER)
1048 .collect::<Vec<_>>();
1049 anyhow::ensure!(
1050 start_markers.len() == 1,
1051 "expected exactly one start marker, found {}",
1052 start_markers.len()
1053 );
1054
1055 let end_markers = content
1056 .match_indices(EDITABLE_REGION_END_MARKER)
1057 .collect::<Vec<_>>();
1058 anyhow::ensure!(
1059 end_markers.len() == 1,
1060 "expected exactly one end marker, found {}",
1061 end_markers.len()
1062 );
1063
1064 let sof_markers = content
1065 .match_indices(START_OF_FILE_MARKER)
1066 .collect::<Vec<_>>();
1067 anyhow::ensure!(
1068 sof_markers.len() <= 1,
1069 "expected at most one start-of-file marker, found {}",
1070 sof_markers.len()
1071 );
1072
1073 let codefence_start = start_markers[0].0;
1074 let content = &content[codefence_start..];
1075
1076 let newline_ix = content.find('\n').context("could not find newline")?;
1077 let content = &content[newline_ix + 1..];
1078
1079 let codefence_end = content
1080 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1081 .context("could not find end marker")?;
1082 let new_text = &content[..codefence_end];
1083
1084 let old_text = snapshot
1085 .text_for_range(editable_range.clone())
1086 .collect::<String>();
1087
1088 Ok(Self::compute_edits(
1089 old_text,
1090 new_text,
1091 editable_range.start,
1092 snapshot,
1093 ))
1094 }
1095
1096 pub fn compute_edits(
1097 old_text: String,
1098 new_text: &str,
1099 offset: usize,
1100 snapshot: &BufferSnapshot,
1101 ) -> Vec<(Range<Anchor>, String)> {
1102 text_diff(&old_text, new_text)
1103 .into_iter()
1104 .map(|(mut old_range, new_text)| {
1105 old_range.start += offset;
1106 old_range.end += offset;
1107
1108 let prefix_len = common_prefix(
1109 snapshot.chars_for_range(old_range.clone()),
1110 new_text.chars(),
1111 );
1112 old_range.start += prefix_len;
1113
1114 let suffix_len = common_prefix(
1115 snapshot.reversed_chars_for_range(old_range.clone()),
1116 new_text[prefix_len..].chars().rev(),
1117 );
1118 old_range.end = old_range.end.saturating_sub(suffix_len);
1119
1120 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1121 let range = if old_range.is_empty() {
1122 let anchor = snapshot.anchor_after(old_range.start);
1123 anchor..anchor
1124 } else {
1125 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1126 };
1127 (range, new_text)
1128 })
1129 .collect()
1130 }
1131
    /// Returns whether a rating has been recorded for `completion_id` via `rate_completion`.
    ///
    /// Ratings are forgotten once the completion is evicted from the shown-completions
    /// history in `completion_shown`.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
1135
1136 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1137 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1138 let completion = self.shown_completions.pop_back().unwrap();
1139 self.rated_completions.remove(&completion.id);
1140 }
1141 self.shown_completions.push_front(completion.clone());
1142 cx.notify();
1143 }
1144
    /// Records the user's `rating` and free-form `feedback` for `completion`.
    ///
    /// Marks the completion as rated and emits an "Edit Prediction Rated" telemetry event.
    /// The event carries the completion's input events, input excerpt and output excerpt.
    /// It then starts a telemetry flush without awaiting it (the task is detached) and
    /// notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1164
    /// Iterates over the recently shown completions, newest first (`completion_shown` pushes
    /// to the front).
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
1168
    /// Number of completions currently kept in the shown-completions history.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1172
1173 fn report_changes_for_buffer(
1174 &mut self,
1175 buffer: &Entity<Buffer>,
1176 cx: &mut Context<Self>,
1177 ) -> BufferSnapshot {
1178 self.register_buffer(buffer, cx);
1179
1180 let registered_buffer = self
1181 .registered_buffers
1182 .get_mut(&buffer.entity_id())
1183 .unwrap();
1184 let new_snapshot = buffer.read(cx).snapshot();
1185
1186 if new_snapshot.version != registered_buffer.snapshot.version {
1187 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1188 self.push_event(Event::BufferChange {
1189 old_snapshot,
1190 new_snapshot: new_snapshot.clone(),
1191 timestamp: Instant::now(),
1192 });
1193 }
1194
1195 new_snapshot
1196 }
1197
1198 fn load_data_collection_choices() -> DataCollectionChoice {
1199 let choice = KEY_VALUE_STORE
1200 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1201 .log_err()
1202 .flatten();
1203
1204 match choice.as_deref() {
1205 Some("true") => DataCollectionChoice::Enabled,
1206 Some("false") => DataCollectionChoice::Disabled,
1207 Some(_) => {
1208 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1209 DataCollectionChoice::NotAnswered
1210 }
1211 None => DataCollectionChoice::NotAnswered,
1212 }
1213 }
1214
1215 fn gather_additional_context(
1216 &mut self,
1217 cursor_point: language::Point,
1218 cursor_offset: usize,
1219 snapshot: BufferSnapshot,
1220 buffer_snapshotted_at: &Instant,
1221 project_path: ProjectPath,
1222 project: Option<&Entity<Project>>,
1223 cx: &mut Context<Self>,
1224 ) -> Option<Task<PredictEditsAdditionalContext>> {
1225 let project = project?.read(cx);
1226 let entry = project.entry_for_path(&project_path, cx)?;
1227 if !worktree_entry_is_eligible_for_collection(&entry) {
1228 return None;
1229 }
1230
1231 let git_store = project.git_store().read(cx);
1232 let (repository, repo_path) =
1233 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1234 let repo_path_string = repo_path.to_str()?.to_string();
1235
1236 let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
1237 snapshot
1238 .diagnostics
1239 .iter()
1240 .filter_map(|(language_server_id, diagnostics)| {
1241 let language_server =
1242 local_lsp_store.running_language_server_for_id(*language_server_id)?;
1243 Some((
1244 *language_server_id,
1245 language_server.name(),
1246 diagnostics.clone(),
1247 ))
1248 })
1249 .collect()
1250 } else {
1251 Vec::new()
1252 };
1253
1254 repository.update(cx, |repository, cx| {
1255 let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1256 let remote_origin_url = repository.remote_origin_url.clone();
1257 let remote_upstream_url = repository.remote_upstream_url.clone();
1258 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1259
1260 // group, resolve, and select diagnostics on a background thread
1261 Some(cx.background_spawn(async move {
1262 let mut diagnostic_groups_with_name = Vec::new();
1263 for (language_server_id, language_server_name, diagnostics) in
1264 diagnostics.into_iter()
1265 {
1266 let mut groups = Vec::new();
1267 diagnostics.groups(language_server_id, &mut groups, &snapshot);
1268 diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
1269 (
1270 language_server_name.clone(),
1271 group.resolve::<usize>(&snapshot),
1272 )
1273 }));
1274 }
1275
1276 // sort by proximity to cursor
1277 diagnostic_groups_with_name.sort_by_key(|(_, group)| {
1278 let range = &group.entries[group.primary_ix].range;
1279 if range.start >= cursor_offset {
1280 range.start - cursor_offset
1281 } else if cursor_offset >= range.end {
1282 cursor_offset - range.end
1283 } else {
1284 (cursor_offset - range.start).min(range.end - cursor_offset)
1285 }
1286 });
1287
1288 let mut diagnostic_groups = Vec::new();
1289 let mut diagnostic_groups_truncated = false;
1290 let mut diagnostics_byte_count = 0;
1291 for (name, group) in diagnostic_groups_with_name {
1292 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1293 diagnostics_byte_count += name.0.len() + raw_value.get().len();
1294 if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
1295 diagnostic_groups_truncated = true;
1296 break;
1297 }
1298 diagnostic_groups.push((name.to_string(), raw_value));
1299 }
1300
1301 PredictEditsAdditionalContext {
1302 input_path: repo_path_string,
1303 cursor_point: to_cloud_llm_client_point(cursor_point),
1304 cursor_offset: cursor_offset,
1305 git_info: PredictEditsGitInfo {
1306 head_sha: Some(head_sha),
1307 remote_origin_url,
1308 remote_upstream_url,
1309 },
1310 diagnostic_groups,
1311 diagnostic_groups_truncated,
1312 recent_files,
1313 }
1314 }))
1315 })
1316 }
1317
1318 fn handle_active_project_entry_changed(&mut self, cx: &mut Context<Self>) {
1319 if !self.data_collection_choice.read(cx).is_enabled() {
1320 self.recent_editors.clear();
1321 self.last_activity_state = None;
1322 return;
1323 }
1324 if let Some(active_editor) = self
1325 .workspace
1326 .read_with(cx, |workspace, cx| {
1327 workspace
1328 .active_item(cx)
1329 .and_then(|item| item.act_as::<Editor>(cx))
1330 })
1331 .ok()
1332 .flatten()
1333 {
1334 let now = Instant::now();
1335 let editor = active_editor.downgrade();
1336 let existing_recent_editor = if let Some(existing_ix) = self
1337 .recent_editors
1338 .iter()
1339 .rposition(|recent| &recent.editor == &editor)
1340 {
1341 if existing_ix + 1 != self.recent_editors.len() {
1342 self.last_activity_state = None;
1343 }
1344 self.recent_editors.remove(existing_ix)
1345 } else {
1346 None
1347 };
1348 let new_recent = RecentEditor {
1349 editor: active_editor.downgrade(),
1350 last_active_at: now,
1351 activation_count: existing_recent_editor
1352 .as_ref()
1353 .map_or(0, |recent| recent.activation_count + 1),
1354 cumulative_time_navigating: existing_recent_editor
1355 .as_ref()
1356 .map_or(Duration::ZERO, |recent| recent.cumulative_time_navigating),
1357 cumulative_time_editing: existing_recent_editor
1358 .map_or(Duration::ZERO, |recent| recent.cumulative_time_editing),
1359 };
1360 // filter out rapid changes in active item, particularly since this can happen rapidly when
1361 // a workspace is loaded.
1362 if let Some(previous_recent) = self.recent_editors.back_mut()
1363 && previous_recent.activation_count == 1
1364 && now.duration_since(previous_recent.last_active_at)
1365 < MIN_TIME_BETWEEN_RECENT_FILES
1366 {
1367 *previous_recent = new_recent;
1368 return;
1369 }
1370 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1371 self.recent_editors.pop_front();
1372 }
1373 self.recent_editors.push_back(new_recent);
1374 }
1375 }
1376
    /// Samples the most recently active editor to accumulate how long the user spent
    /// navigating and editing in it.
    ///
    /// `now` is the poll time. `instant_before_delay` is presumably the instant the caller
    /// started waiting before this poll (TODO confirm at the call site); when it is `None`, no
    /// time is accumulated. While data collection is disabled, the last sample is discarded
    /// and nothing is recorded.
    ///
    /// The editor counts as "navigated" when its scroll position or newest-selection head
    /// differs from the previous sample. It counts as "edited" when its singleton buffer's
    /// version changed since the previous sample. In either case the elapsed interval is added
    /// to the matching cumulative duration on the back entry of `recent_editors`.
    fn handle_activity_poll(
        &mut self,
        instant_before_delay: Option<Instant>,
        now: Instant,
        cx: &mut Context<Self>,
    ) {
        if !self.data_collection_choice.read(cx).is_enabled() {
            self.last_activity_state = None;
            return;
        }
        if let Some(recent_editor) = self.recent_editors.back()
            && let Some(editor) = recent_editor.editor.upgrade()
        {
            let (scroll_position, cursor_point, singleton_version) =
                editor.update(cx, |editor, cx| {
                    let scroll_position = editor.scroll_position(cx);
                    let cursor_point = editor.selections.newest(cx).head();
                    // Only singleton (non-multi) buffers have a version to compare.
                    let singleton_version = editor
                        .buffer()
                        .read(cx)
                        .as_singleton()
                        .map(|singleton_buffer| singleton_buffer.read(cx).version());
                    (scroll_position, cursor_point, singleton_version)
                });

            // Without a previous sample there is nothing to compare against.
            let navigated = if let Some(last_activity_state) = &self.last_activity_state {
                last_activity_state.scroll_position != scroll_position
                    || last_activity_state.cursor_point != cursor_point
            } else {
                false
            };

            // Edits are only detected when both this and the previous sample have a version.
            let edited = if let Some(singleton_version) = &singleton_version
                && let Some(last_activity_state) = &self.last_activity_state
                && let Some(last_singleton_version) = &last_activity_state.singleton_version
            {
                singleton_version.changed_since(last_singleton_version)
            } else {
                false
            };

            self.last_activity_state = Some(ActivityState {
                scroll_position,
                cursor_point,
                singleton_version,
            });

            // NOTE(review): the interval start is clamped by the *second-to-last* editor's
            // `last_active_at`, not the current editor's. If the active editor changed during
            // the delay, time before the switch may be credited to the current editor —
            // confirm this is intended.
            let prior_recent_editor = if self.recent_editors.len() > 1 {
                Some(&self.recent_editors[self.recent_editors.len() - 2])
            } else {
                None
            };
            let additional_time: Option<Duration> =
                instant_before_delay.map(|instant_before_delay| {
                    now.duration_since(prior_recent_editor.map_or(
                        instant_before_delay,
                        |prior_recent_editor| {
                            prior_recent_editor.last_active_at.max(instant_before_delay)
                        },
                    ))
                });

            if let Some(additional_time) = additional_time {
                // Non-empty: `recent_editors.back()` matched above.
                let recent_editor = self.recent_editors.back_mut().unwrap();
                if navigated {
                    recent_editor.cumulative_time_navigating += additional_time;
                }
                if edited {
                    recent_editor.cumulative_time_editing += additional_time;
                }
            }
        }
    }
1450
    /// Builds the recent-files context for `repository`, most recently activated editor first.
    ///
    /// Returns an empty list, leaving `recent_editors` untouched, when the workspace is gone.
    /// Otherwise each reported entry in `recent_editors` contributes:
    /// - the repository-relative path;
    /// - the newest cursor mapped into the editor's buffer;
    /// - milliseconds since the editor was activated, relative to `now`;
    /// - its activation count;
    /// - its cumulative editing and navigation times.
    ///
    /// An entry is removed from `recent_editors` when any of these holds:
    /// - the editor has been released;
    /// - its cursor cannot be mapped to a buffer with a file;
    /// - the file or its worktree entry is not eligible for collection;
    /// - the repository path is not UTF-8 or is longer than `MAX_RECENT_FILE_PATH_LENGTH`;
    /// - a duration does not fit the target integer type.
    ///
    /// Entries whose file lies outside `repository` are kept but not reported.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &mut App,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::with_capacity(self.recent_editors.len());
        // Walk newest to oldest. Iterating indices in reverse keeps the not-yet-visited
        // (lower) indices valid when an entry is removed.
        for ix in (0..self.recent_editors.len()).rev() {
            let recent_editor = &self.recent_editors[ix];
            // `Some(())` keeps the entry; `None`, including a released editor, prunes it.
            let keep_entry = recent_editor
                .editor
                .update(cx, |editor, cx| {
                    maybe!({
                        let cursor = editor.selections.newest::<MultiBufferPoint>(cx).head();
                        let multibuffer = editor.buffer().read(cx);
                        let (buffer, cursor_point, _) =
                            multibuffer.point_to_buffer_point(cursor, cx)?;
                        let file = buffer.read(cx).file()?;
                        if !file_is_eligible_for_collection(file.as_ref()) {
                            return None;
                        }
                        let project_path = ProjectPath {
                            worktree_id: file.worktree_id(cx),
                            path: file.path().clone(),
                        };
                        let entry = project.read(cx).entry_for_path(&project_path, cx)?;
                        if !worktree_entry_is_eligible_for_collection(entry) {
                            return None;
                        }
                        let Some(repo_path) =
                            repository.project_path_to_repo_path(&project_path, cx)
                        else {
                            // entry not removed since later queries may involve other repositories
                            return Some(());
                        };
                        // paths may not be valid UTF-8
                        let repo_path_str = repo_path.to_str()?;
                        if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                            return None;
                        }
                        let active_to_now_ms = now
                            .duration_since(recent_editor.last_active_at)
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_editing_ms = recent_editor
                            .cumulative_time_editing
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_navigating_ms = recent_editor
                            .cumulative_time_navigating
                            .as_millis()
                            .try_into()
                            .ok()?;
                        results.push(PredictEditsRecentFile {
                            path: repo_path_str.to_string(),
                            cursor_point: to_cloud_llm_client_point(cursor_point),
                            active_to_now_ms,
                            activation_count: recent_editor.activation_count,
                            cumulative_time_editing_ms,
                            cumulative_time_navigating_ms,
                            is_multibuffer: !multibuffer.is_singleton(),
                        });
                        Some(())
                    })
                })
                .ok()
                .flatten();
            if keep_entry.is_none() {
                self.recent_editors.remove(ix);
            }
        }
        results
    }
1532}
1533
1534fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1535 cloud_llm_client::Point {
1536 row: point.row,
1537 column: point.column,
1538 }
1539}
1540
1541fn file_is_eligible_for_collection(file: &dyn File) -> bool {
1542 file.is_local() && !file.is_private()
1543}
1544
1545fn worktree_entry_is_eligible_for_collection(entry: &worktree::Entry) -> bool {
1546 entry.is_file()
1547 && entry.is_created()
1548 && !entry.is_ignored
1549 && !entry.is_private
1550 && !entry.is_external
1551 && !entry.is_fifo
1552}
1553
/// Parameters for a predict-edits request: the client, an LLM API token, the app version, and
/// the request body.
///
/// NOTE(review): the consuming function is outside this view — confirm how each field is used.
pub struct PerformPredictEditsParams {
    pub client: Arc<Client>,
    /// LLM API token; `llm_token_retry` shows such a token being acquired and, when the server
    /// reports it expired, refreshed.
    pub llm_token: LlmApiToken,
    /// Version of the running app (presumably sent with the request — confirm at use site).
    pub app_version: SemanticVersion,
    pub body: PredictEditsBody,
}
1560
/// Error raised when this build is older than the minimum version required for edit
/// predictions; its message tells the user to update to `minimum_version` or newer.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1568
/// Returns the UTF-8 byte length of the longest common prefix of two `char` iterators.
///
/// Pass reversed iterators (`.rev()`) to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1575
/// Result of `gather_context`: the request body and the byte-offset range of the buffer that
/// the model is allowed to rewrite.
pub struct GatherContextOutput {
    pub body: PredictEditsBody,
    pub editable_range: Range<usize>,
}
1580
1581pub async fn gather_context(
1582 full_path_str: String,
1583 snapshot: BufferSnapshot,
1584 cursor_point: language::Point,
1585 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1586 can_collect_data: CanCollectData,
1587) -> Result<GatherContextOutput> {
1588 let input_excerpt = excerpt_for_cursor_position(
1589 cursor_point,
1590 &full_path_str,
1591 &snapshot,
1592 MAX_REWRITE_TOKENS,
1593 MAX_CONTEXT_TOKENS,
1594 );
1595 let input_events = make_events_prompt();
1596 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1597
1598 let body = PredictEditsBody {
1599 input_events,
1600 input_excerpt: input_excerpt.prompt,
1601 can_collect_data: can_collect_data.0,
1602 diagnostic_groups: None,
1603 git_info: None,
1604 };
1605
1606 Ok(GatherContextOutput {
1607 body,
1608 editable_range,
1609 })
1610}
1611
1612fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1613 let mut result = String::new();
1614 for event in events.iter().rev() {
1615 let event_string = event.to_prompt();
1616 let event_tokens = tokens_for_bytes(event_string.len());
1617 if event_tokens > remaining_tokens {
1618 break;
1619 }
1620
1621 if !result.is_empty() {
1622 result.insert_str(0, "\n\n");
1623 }
1624 result.insert_str(0, &event_string);
1625 remaining_tokens -= event_tokens;
1626 }
1627 result
1628}
1629
/// Tracking state for a registered buffer.
struct RegisteredBuffer {
    /// Last snapshot seen by `report_changes_for_buffer`; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Held but never read (presumably so the subscriptions stay active while the buffer is
    /// registered — confirm in `register_buffer`).
    _subscriptions: [gpui::Subscription; 2],
}
1634
/// A recorded user action, rendered into the prediction prompt by `prompt_for_events`.
#[derive(Clone)]
pub enum Event {
    /// A buffer moved from `old_snapshot` to `new_snapshot` (a text edit and/or a rename).
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1643
1644impl Event {
1645 fn to_prompt(&self) -> String {
1646 match self {
1647 Event::BufferChange {
1648 old_snapshot,
1649 new_snapshot,
1650 ..
1651 } => {
1652 let mut prompt = String::new();
1653
1654 let old_path = old_snapshot
1655 .file()
1656 .map(|f| f.path().as_ref())
1657 .unwrap_or(Path::new("untitled"));
1658 let new_path = new_snapshot
1659 .file()
1660 .map(|f| f.path().as_ref())
1661 .unwrap_or(Path::new("untitled"));
1662 if old_path != new_path {
1663 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1664 }
1665
1666 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1667 if !diff.is_empty() {
1668 write!(
1669 prompt,
1670 "User edited {:?}:\n```diff\n{}\n```",
1671 new_path, diff
1672 )
1673 .unwrap();
1674 }
1675
1676 prompt
1677 }
1678 }
1679 }
1680}
1681
/// The prediction currently offered by a provider, tagged with the entity id of the buffer
/// it was requested for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    buffer_id: EntityId,
    completion: EditPrediction,
}
1687
1688impl CurrentEditPrediction {
1689 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1690 if self.buffer_id != old_completion.buffer_id {
1691 return true;
1692 }
1693
1694 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1695 return true;
1696 };
1697 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1698 return false;
1699 };
1700
1701 if old_edits.len() == 1 && new_edits.len() == 1 {
1702 let (old_range, old_text) = &old_edits[0];
1703 let (new_range, new_text) = &new_edits[0];
1704 new_range == old_range && new_text.starts_with(old_text)
1705 } else {
1706 true
1707 }
1708 }
1709}
1710
/// An in-flight prediction request owned by the provider.
struct PendingCompletion {
    /// Request id assigned by `refresh`, used to tell which pending request finished.
    id: usize,
    /// Held so the request stays alive; removing the entry drops the task (presumably
    /// cancelling it, per gpui `Task` semantics — confirm).
    _task: Task<()>,
}
1715
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been made yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// True only for an explicit opt-in; an unanswered prompt counts as disabled.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True once the user has explicitly opted in or out.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// The opposite choice; toggling an unanswered prompt opts the user in.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::NotAnswered | Self::Disabled => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` opts in and `false` opts out; the result is never `NotAnswered`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1755
/// Per-buffer view of whether edit-prediction data may be collected.
///
/// `ProviderDataCollection::new` sets both fields together. They are `None` when the buffer
/// has no file eligible for collection, or its worktree has no license detection watcher.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Reports whether the buffer's project was detected as open source.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1761
/// Whether data from a prediction request may be collected.
///
/// Produced by `ProviderDataCollection::can_collect_data`: true only when the user opted in
/// and the project was detected as open source.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1764
1765impl ProviderDataCollection {
1766 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1767 let choice_and_watcher = buffer.and_then(|buffer| {
1768 let file = buffer.read(cx).file()?;
1769
1770 if !file_is_eligible_for_collection(file.as_ref()) {
1771 return None;
1772 }
1773
1774 let zeta = zeta.read(cx);
1775 let choice = zeta.data_collection_choice.clone();
1776
1777 let license_detection_watcher = zeta
1778 .license_detection_watchers
1779 .get(&file.worktree_id(cx))
1780 .cloned()?;
1781
1782 Some((choice, license_detection_watcher))
1783 });
1784
1785 if let Some((choice, watcher)) = choice_and_watcher {
1786 ProviderDataCollection {
1787 choice: Some(choice),
1788 license_detection_watcher: Some(watcher),
1789 }
1790 } else {
1791 ProviderDataCollection {
1792 choice: None,
1793 license_detection_watcher: None,
1794 }
1795 }
1796 }
1797
1798 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1799 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1800 }
1801
1802 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1803 self.choice
1804 .as_ref()
1805 .is_some_and(|choice| choice.read(cx).is_enabled())
1806 }
1807
1808 fn is_project_open_source(&self) -> bool {
1809 self.license_detection_watcher
1810 .as_ref()
1811 .is_some_and(|watcher| watcher.is_project_open_source())
1812 }
1813
1814 pub fn toggle(&mut self, cx: &mut App) {
1815 if let Some(choice) = self.choice.as_mut() {
1816 let new_choice = choice.update(cx, |choice, _cx| {
1817 let new_choice = choice.toggle();
1818 *choice = new_choice;
1819 new_choice
1820 });
1821
1822 db::write_and_log(cx, move || {
1823 KEY_VALUE_STORE.write_kvp(
1824 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1825 new_choice.is_enabled().to_string(),
1826 )
1827 });
1828 }
1829 }
1830}
1831
1832async fn llm_token_retry(
1833 llm_token: &LlmApiToken,
1834 client: &Arc<Client>,
1835 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1836) -> Result<Response<AsyncBody>> {
1837 let mut did_retry = false;
1838 let http_client = client.http_client();
1839 let mut token = llm_token.acquire(client).await?;
1840 loop {
1841 let request = build_request(token.clone())?;
1842 let response = http_client.send(request).await?;
1843
1844 if !did_retry
1845 && !response.status().is_success()
1846 && response
1847 .headers()
1848 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1849 .is_some()
1850 {
1851 did_retry = true;
1852 token = llm_token.refresh(client).await?;
1853 continue;
1854 }
1855
1856 return Ok(response);
1857 }
1858}
1859
/// Edit-prediction provider backed by a shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// At most two in-flight requests, oldest first; `refresh` replaces the newer one when full.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id assigned to the next request issued by `refresh`.
    next_pending_completion_id: usize,
    /// The prediction currently being offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider's buffer; its inner fields are `None` when data
    /// collection is not possible for that buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the last prediction request was issued; `refresh` throttles against it.
    last_request_timestamp: Instant,
}
1869
1870impl ZetaEditPredictionProvider {
1871 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1872
1873 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1874 Self {
1875 zeta,
1876 pending_completions: ArrayVec::new(),
1877 next_pending_completion_id: 0,
1878 current_completion: None,
1879 provider_data_collection,
1880 last_request_timestamp: Instant::now(),
1881 }
1882 }
1883}
1884
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled for this provider's buffer, together with
    /// whether its project was detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Toggles and persists the data-collection choice. This is a no-op when collection is
    /// not possible for this provider's buffer.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    /// Delegates to the shared `Zeta` instance.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// This provider is enabled for every buffer and cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// A refresh is in progress while any request is still pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Requests a new prediction for `position` in `buffer`.
    ///
    /// No request is made in any of these cases:
    /// - `Zeta` has `update_required` set;
    /// - the account is too young or has overdue invoices;
    /// - the current prediction can still be interpolated onto the buffer.
    ///
    /// Requests are spaced at least `THROTTLE_TIMEOUT` apart; `_debounce` is ignored. At most
    /// two requests stay pending: issuing another one while two are pending replaces the
    /// newer one, dropping its task.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        // Presumably set once the server rejected this client version (see
        // `ZedUpdateRequiredError`) — confirm where `update_required` is written.
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep offering the current prediction while it still applies to the buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out whatever remains of THROTTLE_TIMEOUT since the last request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged; an error or an empty result only retires this request.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this is the oldest pending request, retire only it. Otherwise an older
                    // request is still outstanding, and every pending request is dropped.
                    // NOTE(review): `[0]` assumes the list is non-empty while this task runs.
                    // That presumably holds because clearing the list drops this task —
                    // confirm.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same retirement rule as the failure path: retire only this request if it is
                // the oldest, otherwise drop all pending requests.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Offer the new prediction unless the current one should be kept, as decided
                // by `CurrentEditPrediction::should_replace_completion`.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current prediction (if any) to `Zeta` without awaiting it,
    /// then drops all pending requests. The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the part of the current prediction to show near `cursor_position`.
    ///
    /// The current prediction is discarded when it was made for a different buffer or no
    /// longer applies to the buffer. Otherwise the edit whose start or end row is closest to
    /// the cursor row is chosen, together with the contiguous run of neighbouring edits lying
    /// within one row of it. Returns `None` when there are no edits left.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards over earlier edits that end within one row of the closest edit's
        // start (edits are assumed to be ordered and non-overlapping — TODO confirm).
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards over later edits that start within one row of the closest edit's end.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
2153
/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
///
/// The bytes-per-token ratio is deliberately pessimistic, so token counts are over-estimated
/// and prompts stay safely within model input limits.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
2160
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    /// Checks `EditPrediction::interpolate` against a sequence of user edits: predicted edits
    /// shift with the buffer, parts the user already typed are dropped, undo restores the
    /// original prediction, and a diverging edit makes interpolation return `None`.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Checks that a model response rewriting whole lines is reduced to minimal insertions
    /// in the original buffer.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test(cx);

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// Checks that a prediction that appends a line at the very end of the buffer applies
    /// cleanly.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        let (zeta, _server) = fake_zeta(completion_response, cx).await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Installs a test `SettingsStore` global and registers the client settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });
    }

    /// Builds a `Zeta` backed by a fake HTTP client with these routes:
    /// - `POST /client/llm_tokens` returns a fixed LLM token.
    /// - `POST /predict_edits/v2` always returns `completion_response` as the output excerpt.
    /// - Any other route returns 404.
    ///
    /// The `FakeServer` created to authenticate the client is returned alongside the entity.
    /// Callers should bind it to a named variable (e.g. `_server`) so that it is not dropped
    /// while predictions are still being requested.
    async fn fake_zeta(
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> (Entity<Zeta>, FakeServer) {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |req| {
            // Each request gets its own copy, because the handler may be invoked many times.
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
        (zeta, server)
    }

    /// Requests a prediction for `buffer_content`, with the cursor at the start of row 1,
    /// against a fake server that answers with `completion_response`. Returns the predicted
    /// edits as point ranges in the untouched buffer.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let (zeta, _server) = fake_zeta(completion_response, cx).await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`'s current state.
    /// Each range starts with an `anchor_after` and ends with an `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to byte offsets in `buffer`'s current state.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}