1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
use arrayvec::ArrayVec;
pub(crate) use completion_diff_element::*;
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction::DataCollectionState;
use editor::Editor;
pub use init::*;
use license_detection::LicenseDetectionWatcher;
use project::git_store::Repository;
pub use rate_completion_modal::*;

use anyhow::{Context as _, Result, anyhow};
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
    PredictEditsResponse, ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt;
use gpui::{
    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
    SharedString, Subscription, Task, WeakEntity, actions,
};
use http_client::{AsyncBody, HttpClient, Method, Request, Response};
use input_excerpt::excerpt_for_cursor_position;
use language::{
    Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use multi_buffer::MultiBufferPoint;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use settings::WorktreeId;
use std::str::FromStr;
use std::{
    cmp,
    fmt::Write,
    future::Future,
    mem,
    ops::Range,
    path::Path,
    rc::Rc,
    sync::Arc,
    time::{Duration, Instant},
};
use telemetry_events::EditPredictionRating;
use thiserror::Error;
use util::{ResultExt, maybe};
use uuid::Uuid;
use workspace::Workspace;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use worktree::Worktree;
60
/// Cursor marker text; `parse_edits` strips every occurrence from the model output before
/// looking for the editable region.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker text; `parse_edits` rejects output that contains it more than once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in the model output; `parse_edits` requires exactly one.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region in the model output; `parse_edits` requires exactly one.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Buffer-change events that directly continue the previous event and arrive within this
/// interval are coalesced into it by `push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key; `ZedPredictUpsell::dismissed` treats a stored value under this key as
/// the upsell having already been dismissed.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): the first two limits are not referenced in this part of the file; they are
// presumably token budgets for the prompt excerpt -- confirm at the use sites.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget handed to `prompt_for_events` when building the events prompt.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum number of milliseconds between recent file entries.
const MIN_TIME_BETWEEN_RECENT_FILES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of JSON bytes for diagnostics in additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

/// Interval between polls tracking time editing files.
const ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(10);

/// Interval between polls of whether data collection is enabled, when it is disabled.
const DISABLED_ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(60 * 5);
95
// Declares the actions exposed under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
103
/// Identifier of a single edit prediction, wrapping the request's UUID.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
106
107impl From<EditPredictionId> for gpui::ElementId {
108 fn from(value: EditPredictionId) -> Self {
109 gpui::ElementId::Uuid(value.0)
110 }
111}
112
113impl std::fmt::Display for EditPredictionId {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 write!(f, "{}", self.0)
116 }
117}
118
119struct ZedPredictUpsell;
120
121impl Dismissable for ZedPredictUpsell {
122 const KEY: &'static str = "dismissed-edit-predict-upsell";
123
124 fn dismissed() -> bool {
125 // To make this backwards compatible with older versions of Zed, we
126 // check if the user has seen the previous Edit Prediction Onboarding
127 // before, by checking the data collection choice which was written to
128 // the database once the user clicked on "Accept and Enable"
129 if KEY_VALUE_STORE
130 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
131 .log_err()
132 .is_some_and(|s| s.is_some())
133 {
134 return true;
135 }
136
137 KEY_VALUE_STORE
138 .read_kvp(Self::KEY)
139 .log_err()
140 .is_some_and(|s| s.is_some())
141 }
142}
143
144pub fn should_show_upsell_modal() -> bool {
145 !ZedPredictUpsell::dismissed()
146}
147
/// App-global wrapper holding the shared `Zeta` entity (set in `Zeta::register`, read in
/// `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
152
/// A parsed prediction for one buffer, together with the inputs and timing that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Request id from the prediction response.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region that was sent to the model.
    excerpt_range: Range<usize>,
    /// Cursor offset at the time the request was made.
    cursor_offset: usize,
    /// Edits to apply, as anchor ranges plus replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` were interpolated against.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
    /// Events prompt that was sent with the request.
    input_events: Arc<str>,
    /// Buffer excerpt that was sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// When this value was constructed, i.e. after the response had been processed.
    response_received_at: Instant,
}
168
169impl EditPrediction {
170 fn latency(&self) -> Duration {
171 self.response_received_at
172 .duration_since(self.buffer_snapshotted_at)
173 }
174
175 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
176 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
177 }
178}
179
/// Adapts model edits computed against `old_snapshot` to the user edits made since then.
///
/// Walks the user edits between `old_snapshot` and `new_snapshot` in order:
/// - model edits that end strictly before a user edit are kept unchanged;
/// - a user edit that covers exactly the same old range as the next model edit, and whose
///   inserted text is a prefix of the model's text, consumes that model edit and keeps only the
///   remaining suffix (if any) as an insertion;
/// - any other user edit makes the prediction stale and the function returns `None`.
///
/// Returns `None` as well when no edits remain after interpolation.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Flush every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // The user typed the beginning of what the model predicted: keep the rest.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit neither precedes nor agrees with the model edits.
        return None;
    }

    // No more user edits: the remaining model edits apply as-is.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
225
226impl std::fmt::Debug for EditPrediction {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("EditPrediction")
229 .field("id", &self.id)
230 .field("path", &self.path)
231 .field("edits", &self.edits)
232 .finish_non_exhaustive()
233 }
234}
235
236pub struct Zeta {
237<<<<<<< HEAD
238 workspace: WeakEntity<Workspace>,
239=======
240>>>>>>> main
241 client: Arc<Client>,
242 events: VecDeque<Event>,
243 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
244 shown_completions: VecDeque<EditPrediction>,
245 rated_completions: HashSet<EditPredictionId>,
246 data_collection_choice: Entity<DataCollectionChoice>,
247 llm_token: LlmApiToken,
248 _llm_token_subscription: Subscription,
249 /// Whether an update to a newer version of Zed is required to continue using Zeta.
250 update_required: bool,
251 user_store: Entity<UserStore>,
252 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
253 recent_editors: VecDeque<RecentEditor>,
254 last_activity_state: Option<ActivityState>,
255 _activity_poll_task: Option<Task<Result<()>>>,
256}
257
/// Activity counters for one recently used editor.
// NOTE(review): the code that maintains these fields is outside this section; the field
// meanings below are read off the names and types only -- confirm against the update sites.
struct RecentEditor {
    /// Weak handle to the editor.
    editor: WeakEntity<Editor>,
    last_active_at: Instant,
    activation_count: u32,
    cumulative_time_editing: Duration,
    cumulative_time_navigating: Duration,
}
265
/// Snapshot of editor state stored in `Zeta::last_activity_state`.
// NOTE(review): presumably compared between activity polls to detect scrolling, cursor
// movement and edits -- the comparison is outside this section; confirm there.
#[derive(Debug)]
struct ActivityState {
    scroll_position: gpui::Point<f32>,
    cursor_point: MultiBufferPoint,
    singleton_version: Option<clock::Global>,
}
272
273impl Zeta {
274 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
275 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
276 }
277
278 pub fn register(
279<<<<<<< HEAD
280 workspace: Option<Entity<Workspace>>,
281=======
282>>>>>>> main
283 worktree: Option<Entity<Worktree>>,
284 client: Arc<Client>,
285 user_store: Entity<UserStore>,
286 cx: &mut App,
287 ) -> Entity<Self> {
288 let this = Self::global(cx).unwrap_or_else(|| {
289 let entity = cx.new(|cx| Self::new(client, user_store, cx));
290 cx.set_global(ZetaGlobal(entity.clone()));
291 entity
292 });
293
294 this.update(cx, move |this, cx| {
295 if let Some(worktree) = worktree {
296 let worktree_id = worktree.read(cx).id();
297 this.license_detection_watchers
298 .entry(worktree_id)
299 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
300 }
301 });
302
303 this
304 }
305
306 pub fn clear_history(&mut self) {
307 self.events.clear();
308 }
309
310 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
311 self.user_store.read(cx).edit_prediction_usage()
312 }
313
314<<<<<<< HEAD
315 fn new(
316 workspace: Option<Entity<Workspace>>,
317 client: Arc<Client>,
318 user_store: Entity<UserStore>,
319 cx: &mut Context<Self>,
320 ) -> Self {
321=======
322 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
323>>>>>>> main
324 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
325
326 let data_collection_choice = Self::load_data_collection_choices();
327 let data_collection_choice = cx.new(|_| data_collection_choice);
328
329 let mut activity_poll_task = None;
330
331 if let Some(workspace) = &workspace {
332 let project = workspace.read(cx).project().clone();
333 cx.subscribe(&project, |this, _project, event, cx| match event {
334 project::Event::ActiveEntryChanged(entry_id) => {
335 this.handle_active_project_entry_changed(cx)
336 }
337 _ => {}
338 })
339 .detach();
340
341 // TODO: ideally this would attend to window focus when tracking time, and pause the
342 // loop for efficiency when not focused.
343 activity_poll_task = Some(cx.spawn(async move |this, cx| {
344 let mut instant_before_delay = None;
345 loop {
346 let data_collection_is_enabled = this.read_with(cx, |this, cx| {
347 this.data_collection_choice.read(cx).is_enabled()
348 })?;
349 let interval = if data_collection_is_enabled {
350 ACTIVITY_POLL_INTERVAL
351 } else {
352 instant_before_delay = None;
353 DISABLED_ACTIVITY_POLL_INTERVAL
354 };
355 cx.background_executor().timer(interval).await;
356 this.update(cx, |this, cx| {
357 let now = Instant::now();
358 this.handle_activity_poll(instant_before_delay, now, cx);
359 instant_before_delay = Some(now);
360 })?
361 }
362 }));
363 }
364
365 Self {
366<<<<<<< HEAD
367 workspace: workspace.map_or_else(
368 || WeakEntity::new_invalid(),
369 |workspace| workspace.downgrade(),
370 ),
371=======
372>>>>>>> main
373 client,
374 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
375 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
376 rated_completions: HashSet::default(),
377 registered_buffers: HashMap::default(),
378 data_collection_choice,
379 llm_token: LlmApiToken::default(),
380 _llm_token_subscription: cx.subscribe(
381 &refresh_llm_token_listener,
382 |this, _listener, _event, cx| {
383 let client = this.client.clone();
384 let llm_token = this.llm_token.clone();
385 cx.spawn(async move |_this, _cx| {
386 llm_token.refresh(&client).await?;
387 anyhow::Ok(())
388 })
389 .detach_and_log_err(cx);
390 },
391 ),
392 update_required: false,
393 license_detection_watchers: HashMap::default(),
394 user_store,
395 recent_editors: VecDeque::new(),
396 last_activity_state: None,
397 _activity_poll_task: activity_poll_task,
398 }
399 }
400
401 fn push_event(&mut self, event: Event) {
402 if let Some(Event::BufferChange {
403 new_snapshot: last_new_snapshot,
404 timestamp: last_timestamp,
405 ..
406 }) = self.events.back_mut()
407 {
408 // Coalesce edits for the same buffer when they happen one after the other.
409 let Event::BufferChange {
410 old_snapshot,
411 new_snapshot,
412 timestamp,
413 } = &event;
414
415 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
416 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
417 && old_snapshot.version == last_new_snapshot.version
418 {
419 *last_new_snapshot = new_snapshot.clone();
420 *last_timestamp = *timestamp;
421 return;
422 }
423 }
424
425 if self.events.len() >= MAX_EVENT_COUNT {
426 // These are halved instead of popping to improve prompt caching.
427 self.events.drain(..MAX_EVENT_COUNT / 2);
428 }
429
430 self.events.push_back(event);
431 }
432
433 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
434 let buffer_id = buffer.entity_id();
435 let weak_buffer = buffer.downgrade();
436
437 if let std::collections::hash_map::Entry::Vacant(entry) =
438 self.registered_buffers.entry(buffer_id)
439 {
440 let snapshot = buffer.read(cx).snapshot();
441
442 entry.insert(RegisteredBuffer {
443 snapshot,
444 _subscriptions: [
445 cx.subscribe(buffer, move |this, buffer, event, cx| {
446 this.handle_buffer_event(buffer, event, cx);
447 }),
448 cx.observe_release(buffer, move |this, _buffer, _cx| {
449 this.registered_buffers.remove(&weak_buffer.entity_id());
450 }),
451 ],
452 });
453 };
454 }
455
456 fn handle_buffer_event(
457 &mut self,
458 buffer: Entity<Buffer>,
459 event: &language::BufferEvent,
460 cx: &mut Context<Self>,
461 ) {
462 if let language::BufferEvent::Edited = event {
463 self.report_changes_for_buffer(&buffer, cx);
464 }
465 }
466
467 fn request_completion_impl<F, R>(
468 &mut self,
469<<<<<<< HEAD
470 workspace: Option<Entity<Workspace>>,
471 project: Option<Entity<Project>>,
472=======
473 project: Option<&Entity<Project>>,
474>>>>>>> main
475 buffer: &Entity<Buffer>,
476 cursor: language::Anchor,
477 can_collect_data: CanCollectData,
478 cx: &mut Context<Self>,
479 perform_predict_edits: F,
480 ) -> Task<Result<Option<EditPrediction>>>
481 where
482 F: FnOnce(PerformPredictEditsParams) -> R + 'static,
483 R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
484 + Send
485 + 'static,
486 {
487 let buffer = buffer.clone();
488 let buffer_snapshotted_at = Instant::now();
489 let snapshot = self.report_changes_for_buffer(&buffer, cx);
490 let zeta = cx.entity();
491 let events = self.events.clone();
492 let client = self.client.clone();
493 let llm_token = self.llm_token.clone();
494 let app_version = AppVersion::global(cx);
495
496 let full_path: Arc<Path> = snapshot
497 .file()
498 .map(|f| Arc::from(f.full_path(cx).as_path()))
499 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
500 let full_path_str = full_path.to_string_lossy().to_string();
501 let cursor_point = cursor.to_point(&snapshot);
502 let cursor_offset = cursor_point.to_offset(&snapshot);
503 let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
504 let gather_task = cx.background_spawn(gather_context(
505 full_path_str,
506 snapshot.clone(),
507 cursor_point,
508 make_events_prompt,
509 can_collect_data,
510 ));
511
512 cx.spawn(async move |this, cx| {
513 let GatherContextOutput {
514 body,
515 editable_range,
516 } = gather_task.await?;
517 let done_gathering_context_at = Instant::now();
518
519 let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
520 && let Some(file) = snapshot.file()
521 && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
522 {
523 // This is async to reduce latency of the edit predictions request. The downside is
524 // that it will see a slightly later state than was used when gathering context.
525 let snapshot = snapshot.clone();
526 let this = this.clone();
527 Some(cx.spawn(async move |cx| {
528 if let Ok(Some(task)) = this.update(cx, |this, cx| {
529 this.gather_additional_context(
530 cursor_point,
531 cursor_offset,
532 snapshot,
533 &buffer_snapshotted_at,
534 project_path,
535 project.as_ref(),
536 cx,
537 )
538 }) {
539 Some(task.await)
540 } else {
541 None
542 }
543 }))
544 } else {
545 None
546 };
547
548 log::debug!(
549 "Events:\n{}\nExcerpt:\n{:?}",
550 body.input_events,
551 body.input_excerpt
552 );
553
554 let input_events = body.input_events.clone();
555 let input_excerpt = body.input_excerpt.clone();
556
557 let response = perform_predict_edits(PerformPredictEditsParams {
558 client,
559 llm_token,
560 app_version,
561 body,
562 })
563 .await;
564 let (response, usage) = match response {
565 Ok(response) => response,
566 Err(err) => {
567 if err.is::<ZedUpdateRequiredError>() {
568 cx.update(|cx| {
569 zeta.update(cx, |zeta, _cx| {
570 zeta.update_required = true;
571 });
572
573 let error_message: SharedString = err.to_string().into();
574 show_app_notification(
575 NotificationId::unique::<ZedUpdateRequiredError>(),
576 cx,
577 move |cx| {
578 cx.new(|cx| {
579 ErrorMessagePrompt::new(error_message.clone(), cx)
580 .with_link_button(
581 "Update Zed",
582 "https://zed.dev/releases",
583 )
584 })
585 },
586 );
587 })
588 .ok();
589 }
590
591 return Err(err);
592 }
593 };
594
595 let received_response_at = Instant::now();
596 log::debug!("completion response: {}", &response.output_excerpt);
597
598 if let Some(usage) = usage {
599 this.update(cx, |this, cx| {
600 this.user_store.update(cx, |user_store, cx| {
601 user_store.update_edit_prediction_usage(usage, cx);
602 });
603 })
604 .ok();
605 }
606
607 let request_id = response.request_id.clone();
608 let edit_prediction = Self::process_completion_response(
609 response,
610 buffer,
611 &snapshot,
612 editable_range,
613 cursor_offset,
614 full_path,
615 input_events,
616 input_excerpt,
617 buffer_snapshotted_at,
618 cx,
619 )
620 .await;
621
622 let finished_at = Instant::now();
623
624 // record latency for ~1% of requests
625 if rand::random::<u8>() <= 2 {
626 telemetry::event!(
627 "Edit Prediction Request",
628 context_latency = done_gathering_context_at
629 .duration_since(buffer_snapshotted_at)
630 .as_millis(),
631 request_latency = received_response_at
632 .duration_since(done_gathering_context_at)
633 .as_millis(),
634 process_latency = finished_at.duration_since(received_response_at).as_millis()
635 );
636 }
637
638 if let Some(additional_context_task) = additional_context_task {
639 cx.background_spawn(async move {
640 if let Some(additional_context) = additional_context_task.await {
641 telemetry::event!(
642 "Edit Prediction Additional Context",
643 request_id = request_id,
644 additional_context = additional_context
645 );
646 }
647 })
648 .detach();
649 }
650
651 edit_prediction
652 })
653 }
654
    // Generates several example completions of various states to fill the Zeta completion modal
    //
    // Creates a local test buffer, issues seven fake completions against it (each with a canned
    // `PredictEditsResponse`), awaits them in order, and finally clears the edits of the entries
    // at indices 2 and 3 of `shown_completions`.
    //
    // NOTE(review): the task panics if a fake completion fails or if `shown_completions` has
    // fewer than four entries afterwards; the code that fills `shown_completions` is outside
    // this section.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // All fake completions are requested at the start of the second line.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await in order; a failed fake completion panics here.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Force two of the stored entries to have no edits.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
806
807 #[cfg(any(test, feature = "test-support"))]
808 pub fn fake_completion(
809 &mut self,
810 project: Option<Entity<Project>>,
811 buffer: &Entity<Buffer>,
812 position: language::Anchor,
813 response: PredictEditsResponse,
814 cx: &mut Context<Self>,
815 ) -> Task<Result<Option<EditPrediction>>> {
816 use std::future::ready;
817
818<<<<<<< HEAD
819 self.request_completion_impl(
820 None,
821 project,
822 buffer,
823 position,
824 CanCollectData(false),
825 cx,
826 |_params| ready(Ok((response, None))),
827 )
828=======
829 self.request_completion_impl(project, buffer, position, false, cx, |_params| {
830 ready(Ok((response, None)))
831 })
832>>>>>>> main
833 }
834
835 pub fn request_completion(
836 &mut self,
837 project: Option<Entity<Project>>,
838 buffer: &Entity<Buffer>,
839 position: language::Anchor,
840 can_collect_data: CanCollectData,
841 cx: &mut Context<Self>,
842 ) -> Task<Result<Option<EditPrediction>>> {
843 self.request_completion_impl(
844<<<<<<< HEAD
845 self.workspace.upgrade(),
846=======
847>>>>>>> main
848 project,
849 buffer,
850 position,
851 can_collect_data,
852 cx,
853 Self::perform_predict_edits,
854 )
855 }
856
    /// Sends the predict-edits request and returns the parsed response together with any usage
    /// information found in the response headers.
    ///
    /// The request is a JSON `POST` to the URL in the `ZED_PREDICT_EDITS_URL` environment
    /// variable when set, otherwise to the `/predict_edits/v2` LLM endpoint.
    ///
    /// # Errors
    ///
    /// - `ZedUpdateRequiredError` when the response advertises a minimum version newer than the
    ///   running app version.
    /// - A generic error carrying the status and body for any other unsuccessful response,
    ///   except that an expired-token response triggers one token refresh and retry first.
    /// - Token, transport and (de)serialization failures are propagated.
    pub fn perform_predict_edits(
        params: PerformPredictEditsParams,
    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
        async move {
            let PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
                ..
            } = params;

            let http_client = client.http_client();
            let mut token = llm_token.acquire(&client).await?;
            // Ensures the expired-token path is taken at most once.
            let mut did_retry = false;

            loop {
                let request_builder = http_client::Request::builder().method(Method::POST);
                let request_builder =
                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                        request_builder.uri(predict_edits_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/v2", &[])?
                                .as_ref(),
                        )
                    };
                let request = request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(serde_json::to_string(&body)?.into())?;

                let mut response = http_client.send(request).await?;

                // The version check runs before the status check, so it applies to successful
                // and failed responses alike. Unparseable header values are ignored.
                if let Some(minimum_required_version) = response
                    .headers()
                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                {
                    anyhow::ensure!(
                        app_version >= minimum_required_version,
                        ZedUpdateRequiredError {
                            minimum_version: minimum_required_version
                        }
                    );
                }

                if response.status().is_success() {
                    // Usage headers are optional; a parse failure just yields `None`.
                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Ok((serde_json::from_str(&body)?, usage));
                } else if !did_retry
                    && response
                        .headers()
                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                        .is_some()
                {
                    // Token expired: refresh it and go around the loop once more.
                    did_retry = true;
                    token = llm_token.refresh(&client).await?;
                } else {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "error predicting edits.\nStatus: {:?}\nBody: {}",
                        response.status(),
                        body
                    );
                }
            }
        }
    }
932
    /// Tells the service that the prediction identified by `request_id` was accepted.
    ///
    /// The request is a JSON `POST` to the URL in the `ZED_ACCEPT_PREDICTION_URL` environment
    /// variable when set, otherwise to the `/predict_edits/accept` LLM endpoint; it is sent via
    /// `llm_token_retry`. On success, usage information from the response headers (if present)
    /// is forwarded to the user store.
    ///
    /// # Errors
    ///
    /// - `ZedUpdateRequiredError` when the response advertises a minimum version newer than the
    ///   running app version.
    /// - A generic error carrying the status and body for any other unsuccessful response.
    /// - An error if this entity has been dropped before usage could be recorded.
    fn accept_edit_prediction(
        &mut self,
        request_id: EditPredictionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let http_client = client.http_client();
            // The closure builds the request for a given token.
            let mut response = llm_token_retry(&llm_token, &client, |token| {
                let request_builder = http_client::Request::builder().method(Method::POST);
                let request_builder =
                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                        request_builder.uri(accept_prediction_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/accept", &[])?
                                .as_ref(),
                        )
                    };
                Ok(request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(
                        serde_json::to_string(&AcceptEditPredictionBody {
                            request_id: request_id.0,
                        })?
                        .into(),
                    )?)
            })
            .await?;

            // Unparseable minimum-version header values are ignored.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                && app_version < minimum_required_version
            {
                return Err(anyhow!(ZedUpdateRequiredError {
                    minimum_version: minimum_required_version
                }));
            }

            if response.status().is_success() {
                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })?;
                }

                Ok(())
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                Err(anyhow!(
                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                ))
            }
        })
    }
1000
    /// Turns a raw prediction response into an `EditPrediction` for `buffer`.
    ///
    /// The output excerpt is parsed into edits against `snapshot` on the background executor.
    /// Those edits are then interpolated against the buffer's current snapshot, so edits the
    /// user made while the request was in flight are taken into account.
    ///
    /// Resolves to `Ok(None)` when interpolation yields nothing (the prediction is stale or
    /// empty), and to an error when parsing fails or the buffer can no longer be read.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing and diffing happen on the background executor.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-base the edits onto the buffer's current contents and start the preview.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                // NOTE(review): taken here, after parsing and preview generation, so it is later
                // than the moment the HTTP response actually arrived.
                response_received_at: Instant::now(),
            }))
        })
    }
1060
1061 fn parse_edits(
1062 output_excerpt: Arc<str>,
1063 editable_range: Range<usize>,
1064 snapshot: &BufferSnapshot,
1065 ) -> Result<Vec<(Range<Anchor>, String)>> {
1066 let content = output_excerpt.replace(CURSOR_MARKER, "");
1067
1068 let start_markers = content
1069 .match_indices(EDITABLE_REGION_START_MARKER)
1070 .collect::<Vec<_>>();
1071 anyhow::ensure!(
1072 start_markers.len() == 1,
1073 "expected exactly one start marker, found {}",
1074 start_markers.len()
1075 );
1076
1077 let end_markers = content
1078 .match_indices(EDITABLE_REGION_END_MARKER)
1079 .collect::<Vec<_>>();
1080 anyhow::ensure!(
1081 end_markers.len() == 1,
1082 "expected exactly one end marker, found {}",
1083 end_markers.len()
1084 );
1085
1086 let sof_markers = content
1087 .match_indices(START_OF_FILE_MARKER)
1088 .collect::<Vec<_>>();
1089 anyhow::ensure!(
1090 sof_markers.len() <= 1,
1091 "expected at most one start-of-file marker, found {}",
1092 sof_markers.len()
1093 );
1094
1095 let codefence_start = start_markers[0].0;
1096 let content = &content[codefence_start..];
1097
1098 let newline_ix = content.find('\n').context("could not find newline")?;
1099 let content = &content[newline_ix + 1..];
1100
1101 let codefence_end = content
1102 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1103 .context("could not find end marker")?;
1104 let new_text = &content[..codefence_end];
1105
1106 let old_text = snapshot
1107 .text_for_range(editable_range.clone())
1108 .collect::<String>();
1109
1110 Ok(Self::compute_edits(
1111 old_text,
1112 new_text,
1113 editable_range.start,
1114 snapshot,
1115 ))
1116 }
1117
1118 pub fn compute_edits(
1119 old_text: String,
1120 new_text: &str,
1121 offset: usize,
1122 snapshot: &BufferSnapshot,
1123 ) -> Vec<(Range<Anchor>, String)> {
1124 text_diff(&old_text, new_text)
1125 .into_iter()
1126 .map(|(mut old_range, new_text)| {
1127 old_range.start += offset;
1128 old_range.end += offset;
1129
1130 let prefix_len = common_prefix(
1131 snapshot.chars_for_range(old_range.clone()),
1132 new_text.chars(),
1133 );
1134 old_range.start += prefix_len;
1135
1136 let suffix_len = common_prefix(
1137 snapshot.reversed_chars_for_range(old_range.clone()),
1138 new_text[prefix_len..].chars().rev(),
1139 );
1140 old_range.end = old_range.end.saturating_sub(suffix_len);
1141
1142 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1143 let range = if old_range.is_empty() {
1144 let anchor = snapshot.anchor_after(old_range.start);
1145 anchor..anchor
1146 } else {
1147 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1148 };
1149 (range, new_text)
1150 })
1151 .collect()
1152 }
1153
1154 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1155 self.rated_completions.contains(&completion_id)
1156 }
1157
1158 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1159 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1160 let completion = self.shown_completions.pop_back().unwrap();
1161 self.rated_completions.remove(&completion.id);
1162 }
1163 self.shown_completions.push_front(completion.clone());
1164 cx.notify();
1165 }
1166
1167 pub fn rate_completion(
1168 &mut self,
1169 completion: &EditPrediction,
1170 rating: EditPredictionRating,
1171 feedback: String,
1172 cx: &mut Context<Self>,
1173 ) {
1174 self.rated_completions.insert(completion.id);
1175 telemetry::event!(
1176 "Edit Prediction Rated",
1177 rating,
1178 input_events = completion.input_events,
1179 input_excerpt = completion.input_excerpt,
1180 output_excerpt = completion.output_excerpt,
1181 feedback
1182 );
1183 self.client.telemetry().flush_events().detach();
1184 cx.notify();
1185 }
1186
1187 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1188 self.shown_completions.iter()
1189 }
1190
1191 pub fn shown_completions_len(&self) -> usize {
1192 self.shown_completions.len()
1193 }
1194
1195 fn report_changes_for_buffer(
1196 &mut self,
1197 buffer: &Entity<Buffer>,
1198 cx: &mut Context<Self>,
1199 ) -> BufferSnapshot {
1200 self.register_buffer(buffer, cx);
1201
1202 let registered_buffer = self
1203 .registered_buffers
1204 .get_mut(&buffer.entity_id())
1205 .unwrap();
1206 let new_snapshot = buffer.read(cx).snapshot();
1207
1208 if new_snapshot.version != registered_buffer.snapshot.version {
1209 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1210 self.push_event(Event::BufferChange {
1211 old_snapshot,
1212 new_snapshot: new_snapshot.clone(),
1213 timestamp: Instant::now(),
1214 });
1215 }
1216
1217 new_snapshot
1218 }
1219
1220 fn load_data_collection_choices() -> DataCollectionChoice {
1221 let choice = KEY_VALUE_STORE
1222 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1223 .log_err()
1224 .flatten();
1225
1226 match choice.as_deref() {
1227 Some("true") => DataCollectionChoice::Enabled,
1228 Some("false") => DataCollectionChoice::Disabled,
1229 Some(_) => {
1230 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1231 DataCollectionChoice::NotAnswered
1232 }
1233 None => DataCollectionChoice::NotAnswered,
1234 }
1235 }
1236
    /// Builds the additional context (git info, diagnostics near the cursor, recent
    /// files) attached to a prediction request.
    ///
    /// Returns `None` when any prerequisite is missing: no project, no worktree entry
    /// for `project_path`, an entry that is not eligible for collection, no repository
    /// containing the path, a repository path that cannot be converted to a string,
    /// or a repository without a head commit. Otherwise returns a task that does the
    /// diagnostic grouping, sorting and serialization on the background executor.
    fn gather_additional_context(
        &mut self,
        cursor_point: language::Point,
        cursor_offset: usize,
        snapshot: BufferSnapshot,
        buffer_snapshotted_at: &Instant,
        project_path: ProjectPath,
        project: Option<&Entity<Project>>,
        cx: &mut Context<Self>,
    ) -> Option<Task<PredictEditsAdditionalContext>> {
        let project = project?.read(cx);
        let entry = project.entry_for_path(&project_path, cx)?;
        if !worktree_entry_is_eligible_for_collection(&entry) {
            return None;
        }

        let git_store = project.git_store().read(cx);
        let (repository, repo_path) =
            git_store.repository_and_path_for_project_path(&project_path, cx)?;
        let repo_path_string = repo_path.to_str()?.to_string();

        // Collect (server id, server name, diagnostics) for every diagnostic set in the
        // snapshot whose language server is currently running locally. Without a local
        // LSP store no diagnostics are included.
        let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
            snapshot
                .diagnostics
                .iter()
                .filter_map(|(language_server_id, diagnostics)| {
                    let language_server =
                        local_lsp_store.running_language_server_for_id(*language_server_id)?;
                    Some((
                        *language_server_id,
                        language_server.name(),
                        diagnostics.clone(),
                    ))
                })
                .collect()
        } else {
            Vec::new()
        };

        repository.update(cx, |repository, cx| {
            // A repository without a head commit yields no additional context at all.
            let head_sha = repository.head_commit.as_ref()?.sha.to_string();
            let remote_origin_url = repository.remote_origin_url.clone();
            let remote_upstream_url = repository.remote_upstream_url.clone();
            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);

            // group, resolve, and select diagnostics on a background thread
            Some(cx.background_spawn(async move {
                let mut diagnostic_groups_with_name = Vec::new();
                for (language_server_id, language_server_name, diagnostics) in
                    diagnostics.into_iter()
                {
                    let mut groups = Vec::new();
                    diagnostics.groups(language_server_id, &mut groups, &snapshot);
                    diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
                        (
                            language_server_name.clone(),
                            group.resolve::<usize>(&snapshot),
                        )
                    }));
                }

                // sort by proximity to cursor: distance from the cursor offset to the
                // primary entry's range, or to its nearer edge when the cursor is inside it
                diagnostic_groups_with_name.sort_by_key(|(_, group)| {
                    let range = &group.entries[group.primary_ix].range;
                    if range.start >= cursor_offset {
                        range.start - cursor_offset
                    } else if cursor_offset >= range.end {
                        cursor_offset - range.end
                    } else {
                        (cursor_offset - range.start).min(range.end - cursor_offset)
                    }
                });

                // Keep the closest groups until the running byte total (server name plus
                // serialized group) would exceed MAX_DIAGNOSTICS_BYTES; the group that
                // crosses the limit and everything after it are dropped.
                let mut diagnostic_groups = Vec::new();
                let mut diagnostic_groups_truncated = false;
                let mut diagnostics_byte_count = 0;
                for (name, group) in diagnostic_groups_with_name {
                    let raw_value = serde_json::value::to_raw_value(&group).unwrap();
                    diagnostics_byte_count += name.0.len() + raw_value.get().len();
                    if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
                        diagnostic_groups_truncated = true;
                        break;
                    }
                    diagnostic_groups.push((name.to_string(), raw_value));
                }

                PredictEditsAdditionalContext {
                    input_path: repo_path_string,
                    cursor_point: to_cloud_llm_client_point(cursor_point),
                    cursor_offset: cursor_offset,
                    git_info: PredictEditsGitInfo {
                        head_sha: Some(head_sha),
                        remote_origin_url,
                        remote_upstream_url,
                    },
                    diagnostic_groups,
                    diagnostic_groups_truncated,
                    recent_files,
                }
            }))
        })
    }
1339
1340 fn handle_active_project_entry_changed(&mut self, cx: &mut Context<Self>) {
1341 if !self.data_collection_choice.read(cx).is_enabled() {
1342 self.recent_editors.clear();
1343 self.last_activity_state = None;
1344 return;
1345 }
1346 if let Some(active_editor) = self
1347 .workspace
1348 .read_with(cx, |workspace, cx| {
1349 workspace
1350 .active_item(cx)
1351 .and_then(|item| item.act_as::<Editor>(cx))
1352 })
1353 .ok()
1354 .flatten()
1355 {
1356 let now = Instant::now();
1357 let editor = active_editor.downgrade();
1358 let existing_recent_editor = if let Some(existing_ix) = self
1359 .recent_editors
1360 .iter()
1361 .rposition(|recent| &recent.editor == &editor)
1362 {
1363 if existing_ix + 1 != self.recent_editors.len() {
1364 self.last_activity_state = None;
1365 }
1366 self.recent_editors.remove(existing_ix)
1367 } else {
1368 None
1369 };
1370 let new_recent = RecentEditor {
1371 editor: active_editor.downgrade(),
1372 last_active_at: now,
1373 activation_count: existing_recent_editor
1374 .as_ref()
1375 .map_or(0, |recent| recent.activation_count + 1),
1376 cumulative_time_navigating: existing_recent_editor
1377 .as_ref()
1378 .map_or(Duration::ZERO, |recent| recent.cumulative_time_navigating),
1379 cumulative_time_editing: existing_recent_editor
1380 .map_or(Duration::ZERO, |recent| recent.cumulative_time_editing),
1381 };
1382 // filter out rapid changes in active item, particularly since this can happen rapidly when
1383 // a workspace is loaded.
1384 if let Some(previous_recent) = self.recent_editors.back_mut()
1385 && previous_recent.activation_count == 1
1386 && now.duration_since(previous_recent.last_active_at)
1387 < MIN_TIME_BETWEEN_RECENT_FILES
1388 {
1389 *previous_recent = new_recent;
1390 return;
1391 }
1392 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1393 self.recent_editors.pop_front();
1394 }
1395 self.recent_editors.push_back(new_recent);
1396 }
1397 }
1398
    /// Samples the most recent editor's scroll position, cursor and buffer version,
    /// and credits the elapsed time to its navigating and/or editing totals when the
    /// sample differs from the previous one.
    ///
    /// `instant_before_delay` is the start of the interval being accounted for; when
    /// it is `None` the activity state is still refreshed but no time is credited.
    /// `now` is the end of the interval. With data collection disabled the stored
    /// activity state is cleared and nothing else happens.
    fn handle_activity_poll(
        &mut self,
        instant_before_delay: Option<Instant>,
        now: Instant,
        cx: &mut Context<Self>,
    ) {
        if !self.data_collection_choice.read(cx).is_enabled() {
            self.last_activity_state = None;
            return;
        }
        if let Some(recent_editor) = self.recent_editors.back()
            && let Some(editor) = recent_editor.editor.upgrade()
        {
            let (scroll_position, cursor_point, singleton_version) =
                editor.update(cx, |editor, cx| {
                    let scroll_position = editor.scroll_position(cx);
                    let cursor_point = editor.selections.newest(cx).head();
                    // Only singleton editors contribute a buffer version; for anything
                    // else this is `None` and edits are never detected below.
                    let singleton_version = editor
                        .buffer()
                        .read(cx)
                        .as_singleton()
                        .map(|singleton_buffer| singleton_buffer.read(cx).version());
                    (scroll_position, cursor_point, singleton_version)
                });

            // Navigation: scroll position or cursor moved since the previous sample.
            // Without a previous sample nothing counts as navigation.
            let navigated = if let Some(last_activity_state) = &self.last_activity_state {
                last_activity_state.scroll_position != scroll_position
                    || last_activity_state.cursor_point != cursor_point
            } else {
                false
            };

            // Editing: the buffer version changed since the previous sample. Requires a
            // version in both the current and the previous sample.
            let edited = if let Some(singleton_version) = &singleton_version
                && let Some(last_activity_state) = &self.last_activity_state
                && let Some(last_singleton_version) = &last_activity_state.singleton_version
            {
                singleton_version.changed_since(last_singleton_version)
            } else {
                false
            };

            self.last_activity_state = Some(ActivityState {
                scroll_position,
                cursor_point,
                singleton_version,
            });

            // The entry just before the most recent one, if any.
            let prior_recent_editor = if self.recent_editors.len() > 1 {
                Some(&self.recent_editors[self.recent_editors.len() - 2])
            } else {
                None
            };
            // Time to credit: from the later of `instant_before_delay` and the prior
            // entry's `last_active_at`, up to `now`.
            let additional_time: Option<Duration> =
                instant_before_delay.map(|instant_before_delay| {
                    now.duration_since(prior_recent_editor.map_or(
                        instant_before_delay,
                        |prior_recent_editor| {
                            prior_recent_editor.last_active_at.max(instant_before_delay)
                        },
                    ))
                });

            if let Some(additional_time) = additional_time {
                let recent_editor = self.recent_editors.back_mut().unwrap();
                // The same interval may be credited to both totals.
                if navigated {
                    recent_editor.cumulative_time_navigating += additional_time;
                }
                if edited {
                    recent_editor.cumulative_time_editing += additional_time;
                }
            }
        }
    }
1472
    /// Builds the list of recently active files to report, most recent first, and
    /// prunes `recent_editors` entries that can no longer be reported.
    ///
    /// `now` is the reference instant for `active_to_now_ms`. Only files whose path
    /// maps into `repository` are reported; entries belonging to other repositories
    /// are kept in `recent_editors` but produce no result. Returns an empty list when
    /// the workspace cannot be read.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &mut App,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::with_capacity(self.recent_editors.len());
        // Iterate by index from the back so that removing entry `ix` does not shift the
        // indices still to be visited.
        for ix in (0..self.recent_editors.len()).rev() {
            let recent_editor = &self.recent_editors[ix];
            // `Some(())` means "keep this entry"; `None` (from any `?` below, or from a
            // failed update of the weak editor handle) means "remove it".
            let keep_entry = recent_editor
                .editor
                .update(cx, |editor, cx| {
                    maybe!({
                        let cursor = editor.selections.newest::<MultiBufferPoint>(cx).head();
                        let multibuffer = editor.buffer().read(cx);
                        let (buffer, cursor_point, _) =
                            multibuffer.point_to_buffer_point(cursor, cx)?;
                        let file = buffer.read(cx).file()?;
                        if !file_is_eligible_for_collection(file.as_ref()) {
                            return None;
                        }
                        let project_path = ProjectPath {
                            worktree_id: file.worktree_id(cx),
                            path: file.path().clone(),
                        };
                        let entry = project.read(cx).entry_for_path(&project_path, cx)?;
                        if !worktree_entry_is_eligible_for_collection(entry) {
                            return None;
                        }
                        let Some(repo_path) =
                            repository.project_path_to_repo_path(&project_path, cx)
                        else {
                            // entry not removed since later queries may involve other repositories
                            return Some(());
                        };
                        // paths may not be valid UTF-8
                        let repo_path_str = repo_path.to_str()?;
                        if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                            return None;
                        }
                        // Each millisecond count must fit the target integer type; if a
                        // conversion fails the entry is dropped.
                        let active_to_now_ms = now
                            .duration_since(recent_editor.last_active_at)
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_editing_ms = recent_editor
                            .cumulative_time_editing
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_navigating_ms = recent_editor
                            .cumulative_time_navigating
                            .as_millis()
                            .try_into()
                            .ok()?;
                        results.push(PredictEditsRecentFile {
                            path: repo_path_str.to_string(),
                            cursor_point: to_cloud_llm_client_point(cursor_point),
                            active_to_now_ms,
                            activation_count: recent_editor.activation_count,
                            cumulative_time_editing_ms,
                            cumulative_time_navigating_ms,
                            is_multibuffer: !multibuffer.is_singleton(),
                        });
                        Some(())
                    })
                })
                .ok()
                .flatten();
            if keep_entry.is_none() {
                self.recent_editors.remove(ix);
            }
        }
        results
    }
1554}
1555
1556fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1557 cloud_llm_client::Point {
1558 row: point.row,
1559 column: point.column,
1560 }
1561}
1562
1563fn file_is_eligible_for_collection(file: &dyn File) -> bool {
1564 file.is_local() && !file.is_private()
1565}
1566
1567fn worktree_entry_is_eligible_for_collection(entry: &worktree::Entry) -> bool {
1568 entry.is_file()
1569 && entry.is_created()
1570 && !entry.is_ignored
1571 && !entry.is_private
1572 && !entry.is_external
1573 && !entry.is_fifo
1574}
1575
/// Bundle of values for a predict-edits request. The code that consumes it is not
/// part of this section of the file; field notes below are limited to what the types show.
pub struct PerformPredictEditsParams {
    /// Shared client handle.
    pub client: Arc<Client>,
    /// LLM API token handle.
    pub llm_token: LlmApiToken,
    /// Version of the running application.
    pub app_version: SemanticVersion,
    /// Request body.
    pub body: PredictEditsBody,
}
1582
/// Error whose display message tells the user that edit predictions require at
/// least `minimum_version` of Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest version named in the error message.
    minimum_version: SemanticVersion,
}
1590
/// Returns the length in UTF-8 bytes of the longest run of leading characters that
/// `a` and `b` have in common. Comparison stops at the first mismatch or when
/// either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1597
/// Result of `gather_context`.
pub struct GatherContextOutput {
    /// Request body built from the input excerpt and the events prompt.
    pub body: PredictEditsBody,
    /// The input excerpt's editable range, converted to offsets in the snapshot
    /// passed to `gather_context`.
    pub editable_range: Range<usize>,
}
1602
1603pub async fn gather_context(
1604 full_path_str: String,
1605 snapshot: BufferSnapshot,
1606 cursor_point: language::Point,
1607 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1608 can_collect_data: CanCollectData,
1609) -> Result<GatherContextOutput> {
1610 let input_excerpt = excerpt_for_cursor_position(
1611 cursor_point,
1612 &full_path_str,
1613 &snapshot,
1614 MAX_REWRITE_TOKENS,
1615 MAX_CONTEXT_TOKENS,
1616 );
1617 let input_events = make_events_prompt();
1618 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1619
1620 let body = PredictEditsBody {
1621 input_events,
1622 input_excerpt: input_excerpt.prompt,
1623 can_collect_data: can_collect_data.0,
1624 diagnostic_groups: None,
1625 git_info: None,
1626 };
1627
1628 Ok(GatherContextOutput {
1629 body,
1630 editable_range,
1631 })
1632}
1633
1634fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1635 let mut result = String::new();
1636 for event in events.iter().rev() {
1637 let event_string = event.to_prompt();
1638 let event_tokens = tokens_for_bytes(event_string.len());
1639 if event_tokens > remaining_tokens {
1640 break;
1641 }
1642
1643 if !result.is_empty() {
1644 result.insert_str(0, "\n\n");
1645 }
1646 result.insert_str(0, &event_string);
1647 remaining_tokens -= event_tokens;
1648 }
1649 result
1650}
1651
/// Per-buffer tracking state kept in `registered_buffers`.
struct RegisteredBuffer {
    /// Last snapshot recorded for the buffer; `report_changes_for_buffer` compares its
    /// version against the buffer's current snapshot and replaces it on change.
    snapshot: BufferSnapshot,
    /// Two subscriptions stored with the entry; presumably held only to keep them
    /// alive (never read here) — confirm against `register_buffer`.
    _subscriptions: [gpui::Subscription; 2],
}
1656
/// An event that can be rendered into the prediction prompt via `to_prompt`.
#[derive(Clone)]
pub enum Event {
    /// A buffer moved from one snapshot to another.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded (not used by `to_prompt`).
        timestamp: Instant,
    },
}
1665
1666impl Event {
1667 fn to_prompt(&self) -> String {
1668 match self {
1669 Event::BufferChange {
1670 old_snapshot,
1671 new_snapshot,
1672 ..
1673 } => {
1674 let mut prompt = String::new();
1675
1676 let old_path = old_snapshot
1677 .file()
1678 .map(|f| f.path().as_ref())
1679 .unwrap_or(Path::new("untitled"));
1680 let new_path = new_snapshot
1681 .file()
1682 .map(|f| f.path().as_ref())
1683 .unwrap_or(Path::new("untitled"));
1684 if old_path != new_path {
1685 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1686 }
1687
1688 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1689 if !diff.is_empty() {
1690 write!(
1691 prompt,
1692 "User edited {:?}:\n```diff\n{}\n```",
1693 new_path, diff
1694 )
1695 .unwrap();
1696 }
1697
1698 prompt
1699 }
1700 }
1701 }
1702}
1703
/// The prediction currently held by a provider, tagged with the buffer it belongs to.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
}
1709
1710impl CurrentEditPrediction {
1711 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1712 if self.buffer_id != old_completion.buffer_id {
1713 return true;
1714 }
1715
1716 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1717 return true;
1718 };
1719 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1720 return false;
1721 };
1722
1723 if old_edits.len() == 1 && new_edits.len() == 1 {
1724 let (old_range, old_text) = &old_edits[0];
1725 let (new_range, new_text) = &new_edits[0];
1726 new_range == old_range && new_text.starts_with(old_text)
1727 } else {
1728 true
1729 }
1730 }
1731}
1732
/// An in-flight prediction request tracked by the provider.
struct PendingCompletion {
    /// Id assigned from the provider's `next_pending_completion_id` counter.
    id: usize,
    /// The spawned request task; stored but never read.
    _task: Task<()>,
}
1737
/// The user's answer to whether prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded.
    NotAnswered,
    /// Collection was explicitly enabled.
    Enabled,
    /// Collection was explicitly disabled.
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only for `Enabled`.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// `true` for `Enabled` and `Disabled`, `false` for `NotAnswered`.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice: `Enabled` becomes `Disabled`; both `Disabled` and
    /// `NotAnswered` become `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` maps to `Enabled`, `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1777
/// Data-collection state for one provider. `new` sets both fields together: either
/// both are `Some` or both are `None`.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher consulted by `is_project_open_source`; when `None` the project is
    /// treated as not open source.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1783
/// Newtype around the boolean answer to "may data be collected for this request?".
/// `ProviderDataCollection::can_collect_data` produces it as
/// "collection enabled AND project is open source".
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1786
1787impl ProviderDataCollection {
1788 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1789 let choice_and_watcher = buffer.and_then(|buffer| {
1790 let file = buffer.read(cx).file()?;
1791
1792 if !file_is_eligible_for_collection(file.as_ref()) {
1793 return None;
1794 }
1795
1796 let zeta = zeta.read(cx);
1797 let choice = zeta.data_collection_choice.clone();
1798
1799 let license_detection_watcher = zeta
1800 .license_detection_watchers
1801 .get(&file.worktree_id(cx))
1802 .cloned()?;
1803
1804 Some((choice, license_detection_watcher))
1805 });
1806
1807 if let Some((choice, watcher)) = choice_and_watcher {
1808 ProviderDataCollection {
1809 choice: Some(choice),
1810 license_detection_watcher: Some(watcher),
1811 }
1812 } else {
1813 ProviderDataCollection {
1814 choice: None,
1815 license_detection_watcher: None,
1816 }
1817 }
1818 }
1819
1820 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1821 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1822 }
1823
1824 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1825 self.choice
1826 .as_ref()
1827 .is_some_and(|choice| choice.read(cx).is_enabled())
1828 }
1829
1830 fn is_project_open_source(&self) -> bool {
1831 self.license_detection_watcher
1832 .as_ref()
1833 .is_some_and(|watcher| watcher.is_project_open_source())
1834 }
1835
1836 pub fn toggle(&mut self, cx: &mut App) {
1837 if let Some(choice) = self.choice.as_mut() {
1838 let new_choice = choice.update(cx, |choice, _cx| {
1839 let new_choice = choice.toggle();
1840 *choice = new_choice;
1841 new_choice
1842 });
1843
1844 db::write_and_log(cx, move || {
1845 KEY_VALUE_STORE.write_kvp(
1846 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1847 new_choice.is_enabled().to_string(),
1848 )
1849 });
1850 }
1851 }
1852}
1853
1854async fn llm_token_retry(
1855 llm_token: &LlmApiToken,
1856 client: &Arc<Client>,
1857 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1858) -> Result<Response<AsyncBody>> {
1859 let mut did_retry = false;
1860 let http_client = client.http_client();
1861 let mut token = llm_token.acquire(client).await?;
1862 loop {
1863 let request = build_request(token.clone())?;
1864 let response = http_client.send(request).await?;
1865
1866 if !did_retry
1867 && !response.status().is_success()
1868 && response
1869 .headers()
1870 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1871 .is_some()
1872 {
1873 did_retry = true;
1874 token = llm_token.refresh(client).await?;
1875 continue;
1876 }
1877
1878 return Ok(response);
1879 }
1880}
1881
/// Edit prediction provider backed by a shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    /// Shared prediction state that requests are issued through.
    zeta: Entity<Zeta>,
    /// In-flight requests; holds at most two entries.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id to assign to the next pending request.
    next_pending_completion_id: usize,
    /// The prediction currently held, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider. This field is always present; when
    /// collection is not possible, its inner options are `None`.
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; initialised to the construction time.
    last_request_timestamp: Instant,
}
1891
1892impl ZetaEditPredictionProvider {
1893 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1894
1895 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1896 Self {
1897 zeta,
1898 pending_completions: ArrayVec::new(),
1899 next_pending_completion_id: 0,
1900 current_completion: None,
1901 provider_data_collection,
1902 last_request_timestamp: Instant::now(),
1903 }
1904 }
1905}
1906
1907impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1908 fn name() -> &'static str {
1909 "zed-predict"
1910 }
1911
1912 fn display_name() -> &'static str {
1913 "Zed's Edit Predictions"
1914 }
1915
1916 fn show_completions_in_menu() -> bool {
1917 true
1918 }
1919
1920 fn show_tab_accept_marker() -> bool {
1921 true
1922 }
1923
1924 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1925 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1926
1927 if self.provider_data_collection.is_data_collection_enabled(cx) {
1928 DataCollectionState::Enabled {
1929 is_project_open_source,
1930 }
1931 } else {
1932 DataCollectionState::Disabled {
1933 is_project_open_source,
1934 }
1935 }
1936 }
1937
1938 fn toggle_data_collection(&mut self, cx: &mut App) {
1939 self.provider_data_collection.toggle(cx);
1940 }
1941
1942 fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1943 self.zeta.read(cx).usage(cx)
1944 }
1945
1946 fn is_enabled(
1947 &self,
1948 _buffer: &Entity<Buffer>,
1949 _cursor_position: language::Anchor,
1950 _cx: &App,
1951 ) -> bool {
1952 true
1953 }
1954 fn is_refreshing(&self) -> bool {
1955 !self.pending_completions.is_empty()
1956 }
1957
    /// Starts a new prediction request for `buffer` at `position`, unless one is not
    /// needed or not allowed.
    ///
    /// No request is made when an update is required, when the user's account is too
    /// young or has overdue invoices, or when the current prediction can still be
    /// interpolated onto the buffer's present contents. Otherwise a task is spawned
    /// that waits out the throttle interval, issues the request, and on success
    /// installs the result as the current prediction if it should replace the existing
    /// one. The `_debounce` argument is ignored.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction as long as it still applies to the buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait until THROTTLE_TIMEOUT has passed since the last request
            // timestamp captured above.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Failure (logged) or an empty result: just retire this pending entry.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // NOTE(review): indexing assumes at least one pending entry exists
                    // while this task runs — confirm.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // If this task is the oldest pending entry, remove only it; otherwise
                // clear every pending entry.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }
2083
2084 fn cycle(
2085 &mut self,
2086 _buffer: Entity<Buffer>,
2087 _cursor_position: language::Anchor,
2088 _direction: edit_prediction::Direction,
2089 _cx: &mut Context<Self>,
2090 ) {
2091 // Right now we don't support cycling.
2092 }
2093
2094 fn accept(&mut self, cx: &mut Context<Self>) {
2095 let completion_id = self
2096 .current_completion
2097 .as_ref()
2098 .map(|completion| completion.completion.id);
2099 if let Some(completion_id) = completion_id {
2100 self.zeta
2101 .update(cx, |zeta, cx| {
2102 zeta.accept_edit_prediction(completion_id, cx)
2103 })
2104 .detach();
2105 }
2106 self.pending_completions.clear();
2107 }
2108
2109 fn discard(&mut self, _cx: &mut Context<Self>) {
2110 self.pending_completions.clear();
2111 self.current_completion.take();
2112 }
2113
2114 fn suggest(
2115 &mut self,
2116 buffer: &Entity<Buffer>,
2117 cursor_position: language::Anchor,
2118 cx: &mut Context<Self>,
2119 ) -> Option<edit_prediction::EditPrediction> {
2120 let CurrentEditPrediction {
2121 buffer_id,
2122 completion,
2123 ..
2124 } = self.current_completion.as_mut()?;
2125
2126 // Invalidate previous completion if it was generated for a different buffer.
2127 if *buffer_id != buffer.entity_id() {
2128 self.current_completion.take();
2129 return None;
2130 }
2131
2132 let buffer = buffer.read(cx);
2133 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
2134 self.current_completion.take();
2135 return None;
2136 };
2137
2138 let cursor_row = cursor_position.to_point(buffer).row;
2139 let (closest_edit_ix, (closest_edit_range, _)) =
2140 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
2141 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
2142 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
2143 cmp::min(distance_from_start, distance_from_end)
2144 })?;
2145
2146 let mut edit_start_ix = closest_edit_ix;
2147 for (range, _) in edits[..edit_start_ix].iter().rev() {
2148 let distance_from_closest_edit =
2149 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
2150 if distance_from_closest_edit <= 1 {
2151 edit_start_ix -= 1;
2152 } else {
2153 break;
2154 }
2155 }
2156
2157 let mut edit_end_ix = closest_edit_ix + 1;
2158 for (range, _) in &edits[edit_end_ix..] {
2159 let distance_from_closest_edit =
2160 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
2161 if distance_from_closest_edit <= 1 {
2162 edit_end_ix += 1;
2163 } else {
2164 break;
2165 }
2166 }
2167
2168 Some(edit_prediction::EditPrediction {
2169 id: Some(completion.id.to_string().into()),
2170 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
2171 edit_preview: Some(completion.edit_preview.clone()),
2172 })
2173 }
2174}
2175
/// Estimates how many tokens a string of `bytes` bytes occupies, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    let estimated_tokens = bytes / BYTES_PER_TOKEN_GUESS;
    estimated_tokens
}
2182
2183#[cfg(test)]
2184mod tests {
2185 use client::UserStore;
2186 use client::test::FakeServer;
2187 use clock::FakeSystemClock;
2188 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
2189 use gpui::TestAppContext;
2190 use http_client::FakeHttpClient;
2191 use indoc::indoc;
2192 use language::Point;
2193 use settings::SettingsStore;
2194
2195 use super::*;
2196
2197 #[gpui::test]
2198 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
2199 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
2200 let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
2201 to_completion_edits(
2202 [(2..5, "REM".to_string()), (9..11, "".to_string())],
2203 &buffer,
2204 cx,
2205 )
2206 .into()
2207 });
2208
2209 let edit_preview = cx
2210 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
2211 .await;
2212
2213 let completion = EditPrediction {
2214 edits,
2215 edit_preview,
2216 path: Path::new("").into(),
2217 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
2218 id: EditPredictionId(Uuid::new_v4()),
2219 excerpt_range: 0..0,
2220 cursor_offset: 0,
2221 input_events: "".into(),
2222 input_excerpt: "".into(),
2223 output_excerpt: "".into(),
2224 buffer_snapshotted_at: Instant::now(),
2225 response_received_at: Instant::now(),
2226 };
2227
2228 cx.update(|cx| {
2229 assert_eq!(
2230 from_completion_edits(
2231 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2232 &buffer,
2233 cx
2234 ),
2235 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2236 );
2237
2238 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2239 assert_eq!(
2240 from_completion_edits(
2241 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2242 &buffer,
2243 cx
2244 ),
2245 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
2246 );
2247
2248 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2249 assert_eq!(
2250 from_completion_edits(
2251 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2252 &buffer,
2253 cx
2254 ),
2255 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2256 );
2257
2258 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2259 assert_eq!(
2260 from_completion_edits(
2261 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2262 &buffer,
2263 cx
2264 ),
2265 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2266 );
2267
2268 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2269 assert_eq!(
2270 from_completion_edits(
2271 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2272 &buffer,
2273 cx
2274 ),
2275 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2276 );
2277
2278 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2279 assert_eq!(
2280 from_completion_edits(
2281 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2282 &buffer,
2283 cx
2284 ),
2285 vec![(9..11, "".to_string())]
2286 );
2287
2288 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2289 assert_eq!(
2290 from_completion_edits(
2291 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2292 &buffer,
2293 cx
2294 ),
2295 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2296 );
2297
2298 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2299 assert_eq!(
2300 from_completion_edits(
2301 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2302 &buffer,
2303 cx
2304 ),
2305 vec![(4..4, "M".to_string())]
2306 );
2307
2308 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2309 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2310 })
2311 }
2312
2313 #[gpui::test]
2314 async fn test_clean_up_diff(cx: &mut TestAppContext) {
2315 cx.update(|cx| {
2316 let settings_store = SettingsStore::test(cx);
2317 cx.set_global(settings_store);
2318 client::init_settings(cx);
2319 });
2320
2321 let edits = edits_for_prediction(
2322 indoc! {"
2323 fn main() {
2324 let word_1 = \"lorem\";
2325 let range = word.len()..word.len();
2326 }
2327 "},
2328 indoc! {"
2329 <|editable_region_start|>
2330 fn main() {
2331 let word_1 = \"lorem\";
2332 let range = word_1.len()..word_1.len();
2333 }
2334
2335 <|editable_region_end|>
2336 "},
2337 cx,
2338 )
2339 .await;
2340 assert_eq!(
2341 edits,
2342 [
2343 (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2344 (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2345 ]
2346 );
2347
2348 let edits = edits_for_prediction(
2349 indoc! {"
2350 fn main() {
2351 let story = \"the quick\"
2352 }
2353 "},
2354 indoc! {"
2355 <|editable_region_start|>
2356 fn main() {
2357 let story = \"the quick brown fox jumps over the lazy dog\";
2358 }
2359
2360 <|editable_region_end|>
2361 "},
2362 cx,
2363 )
2364 .await;
2365 assert_eq!(
2366 edits,
2367 [
2368 (
2369 Point::new(1, 26)..Point::new(1, 26),
2370 " brown fox jumps over the lazy dog".to_string()
2371 ),
2372 (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2373 ]
2374 );
2375 }
2376
2377 #[gpui::test]
2378 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2379 cx.update(|cx| {
2380 let settings_store = SettingsStore::test(cx);
2381 cx.set_global(settings_store);
2382 client::init_settings(cx);
2383 });
2384
2385 let buffer_content = "lorem\n";
2386 let completion_response = indoc! {"
2387 ```animals.js
2388 <|start_of_file|>
2389 <|editable_region_start|>
2390 lorem
2391 ipsum
2392 <|editable_region_end|>
2393 ```"};
2394
2395 let http_client = FakeHttpClient::create(move |req| async move {
2396 match (req.method(), req.uri().path()) {
2397 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2398 .status(200)
2399 .body(
2400 serde_json::to_string(&CreateLlmTokenResponse {
2401 token: LlmToken("the-llm-token".to_string()),
2402 })
2403 .unwrap()
2404 .into(),
2405 )
2406 .unwrap()),
2407 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2408 .status(200)
2409 .body(
2410 serde_json::to_string(&PredictEditsResponse {
2411 request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2412 .unwrap(),
2413 output_excerpt: completion_response.to_string(),
2414 })
2415 .unwrap()
2416 .into(),
2417 )
2418 .unwrap()),
2419 _ => Ok(http_client::Response::builder()
2420 .status(404)
2421 .body("Not Found".into())
2422 .unwrap()),
2423 }
2424 });
2425
2426 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2427 cx.update(|cx| {
2428 RefreshLlmTokenListener::register(client.clone(), cx);
2429 });
2430 // Construct the fake server to authenticate.
2431 let _server = FakeServer::for_client(42, &client, cx).await;
2432 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2433 let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2434
2435 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2436 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2437 let completion_task = zeta.update(cx, |zeta, cx| {
2438 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2439 });
2440
2441 let completion = completion_task.await.unwrap().unwrap();
2442 buffer.update(cx, |buffer, cx| {
2443 buffer.edit(completion.edits.iter().cloned(), None, cx)
2444 });
2445 assert_eq!(
2446 buffer.read_with(cx, |buffer, _| buffer.text()),
2447 "lorem\nipsum"
2448 );
2449 }
2450
2451 async fn edits_for_prediction(
2452 buffer_content: &str,
2453 completion_response: &str,
2454 cx: &mut TestAppContext,
2455 ) -> Vec<(Range<Point>, String)> {
2456 let completion_response = completion_response.to_string();
2457 let http_client = FakeHttpClient::create(move |req| {
2458 let completion = completion_response.clone();
2459 async move {
2460 match (req.method(), req.uri().path()) {
2461 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2462 .status(200)
2463 .body(
2464 serde_json::to_string(&CreateLlmTokenResponse {
2465 token: LlmToken("the-llm-token".to_string()),
2466 })
2467 .unwrap()
2468 .into(),
2469 )
2470 .unwrap()),
2471 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2472 .status(200)
2473 .body(
2474 serde_json::to_string(&PredictEditsResponse {
2475 request_id: Uuid::new_v4(),
2476 output_excerpt: completion,
2477 })
2478 .unwrap()
2479 .into(),
2480 )
2481 .unwrap()),
2482 _ => Ok(http_client::Response::builder()
2483 .status(404)
2484 .body("Not Found".into())
2485 .unwrap()),
2486 }
2487 }
2488 });
2489
2490 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2491 cx.update(|cx| {
2492 RefreshLlmTokenListener::register(client.clone(), cx);
2493 });
2494 // Construct the fake server to authenticate.
2495 let _server = FakeServer::for_client(42, &client, cx).await;
2496 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2497 let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2498
2499 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2500 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2501 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2502 let completion_task = zeta.update(cx, |zeta, cx| {
2503 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2504 });
2505
2506 let completion = completion_task.await.unwrap().unwrap();
2507 completion
2508 .edits
2509 .iter()
2510 .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2511 .collect::<Vec<_>>()
2512 }
2513
2514 fn to_completion_edits(
2515 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2516 buffer: &Entity<Buffer>,
2517 cx: &App,
2518 ) -> Vec<(Range<Anchor>, String)> {
2519 let buffer = buffer.read(cx);
2520 iterator
2521 .into_iter()
2522 .map(|(range, text)| {
2523 (
2524 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2525 text,
2526 )
2527 })
2528 .collect()
2529 }
2530
2531 fn from_completion_edits(
2532 editor_edits: &[(Range<Anchor>, String)],
2533 buffer: &Entity<Buffer>,
2534 cx: &App,
2535 ) -> Vec<(Range<usize>, String)> {
2536 let buffer = buffer.read(cx);
2537 editor_edits
2538 .iter()
2539 .map(|(range, text)| {
2540 (
2541 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2542 text.clone(),
2543 )
2544 })
2545 .collect()
2546 }
2547
2548 #[ctor::ctor]
2549 fn init_logger() {
2550 zlog::init_test();
2551 }
2552}