1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13use editor::Editor;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16use project::git_store::Repository;
17pub use rate_completion_modal::*;
18
19use anyhow::{Context as _, Result, anyhow};
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
23 PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
24 PredictEditsResponse, ZED_VERSION_HEADER_NAME,
25};
26use collections::{HashMap, HashSet, VecDeque};
27use futures::AsyncReadExt;
28use gpui::{
29 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
30 SharedString, Subscription, Task, WeakEntity, actions,
31};
32use http_client::{AsyncBody, HttpClient, Method, Request, Response};
33use input_excerpt::excerpt_for_cursor_position;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
36};
37use language_model::{LlmApiToken, RefreshLlmTokenListener};
38use multi_buffer::MultiBufferPoint;
39use project::{Project, ProjectPath};
40use release_channel::AppVersion;
41use settings::WorktreeId;
42use std::str::FromStr;
43use std::{
44 cmp,
45 fmt::Write,
46 future::Future,
47 mem,
48 ops::Range,
49 path::Path,
50 rc::Rc,
51 sync::Arc,
52 time::{Duration, Instant},
53};
54use telemetry_events::EditPredictionRating;
55use thiserror::Error;
56use util::{ResultExt, maybe};
57use uuid::Uuid;
58use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
59use worktree::Worktree;
60
/// Marker for the user's cursor position; it is stripped from model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; model output may contain at most one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in model output; exactly one is expected.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region in model output; exactly one is expected.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer that arrive within this window are merged into a
/// single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value-store key recording the user's data collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): these two token budgets are not referenced in the visible part of this file;
// presumably they size the context/editable excerpt — confirm against their users.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget passed to `prompt_for_events` when rendering recent edit events.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum time between recent file entries.
const MIN_TIME_BETWEEN_RECENT_FILES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of JSON bytes for diagnostics in additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

/// Interval between polls tracking time editing files.
const ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(10);

/// Interval between polls of whether data collection is enabled, when it is disabled.
const DISABLED_ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(60 * 5);
95
// Declares the edit-prediction actions via gpui's `actions!` macro, grouped under
// `edit_prediction`.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
103
104#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
105pub struct EditPredictionId(Uuid);
106
107impl From<EditPredictionId> for gpui::ElementId {
108 fn from(value: EditPredictionId) -> Self {
109 gpui::ElementId::Uuid(value.0)
110 }
111}
112
113impl std::fmt::Display for EditPredictionId {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 write!(f, "{}", self.0)
116 }
117}
118
119struct ZedPredictUpsell;
120
121impl Dismissable for ZedPredictUpsell {
122 const KEY: &'static str = "dismissed-edit-predict-upsell";
123
124 fn dismissed() -> bool {
125 // To make this backwards compatible with older versions of Zed, we
126 // check if the user has seen the previous Edit Prediction Onboarding
127 // before, by checking the data collection choice which was written to
128 // the database once the user clicked on "Accept and Enable"
129 if KEY_VALUE_STORE
130 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
131 .log_err()
132 .is_some_and(|s| s.is_some())
133 {
134 return true;
135 }
136
137 KEY_VALUE_STORE
138 .read_kvp(Self::KEY)
139 .log_err()
140 .is_some_and(|s| s.is_some())
141 }
142}
143
144pub fn should_show_upsell_modal() -> bool {
145 !ZedPredictUpsell::dismissed()
146}
147
/// App-global slot holding the shared [`Zeta`] entity; set by `Zeta::register` and read by
/// `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
152
/// A processed edit prediction: the model's proposed edits plus the request/response data
/// kept for latency measurement and user feedback.
#[derive(Clone)]
pub struct EditPrediction {
    /// Server-assigned request id for this prediction.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Byte range of the editable region in the buffer at request time.
    excerpt_range: Range<usize>,
    /// Byte offset of the cursor in the buffer at request time.
    cursor_offset: usize,
    /// Proposed edits, already interpolated against `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot taken when the response was processed (after interpolation).
    snapshot: BufferSnapshot,
    /// Preview of `edits` applied to the buffer.
    edit_preview: EditPreview,
    /// Rendered edit-history prompt sent with the request.
    input_events: Arc<str>,
    /// Rendered excerpt prompt sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt text returned by the model.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted to build the request.
    buffer_snapshotted_at: Instant,
    /// Stamped once the response has been parsed and previewed, so it includes local
    /// processing time, not only network time.
    response_received_at: Instant,
}
168
169impl EditPrediction {
170 fn latency(&self) -> Duration {
171 self.response_received_at
172 .duration_since(self.buffer_snapshotted_at)
173 }
174
175 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
176 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
177 }
178}
179
/// Rebases model-predicted edits (computed against `old_snapshot`) onto `new_snapshot`,
/// accounting for edits the user made in between.
///
/// Model edits that end strictly before a user edit are kept unchanged. A user edit is only
/// tolerated if its old range equals the next pending model edit's old range and the text the
/// user typed is a prefix of the model's text; in that case, the remaining suffix (if any) is
/// kept as an insertion after the user's text. Any other user edit invalidates the prediction.
///
/// Returns `None` when the prediction was invalidated or when no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // The user typed a prefix of what the model predicted: keep the rest.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        // `anchor_after` so the insertion lands after the user's typed text.
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit diverges from the prediction; drop the whole prediction.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
225
226impl std::fmt::Debug for EditPrediction {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("EditPrediction")
229 .field("id", &self.id)
230 .field("path", &self.path)
231 .field("edits", &self.edits)
232 .finish_non_exhaustive()
233 }
234}
235
/// Edit prediction state shared across the app: recent edit events, registered buffers,
/// predictions shown to the user, and the credentials used to talk to the prediction service.
pub struct Zeta {
    client: Arc<Client>,
    /// Recent buffer-change events, capped at `MAX_EVENT_COUNT` (oldest half dropped when full).
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by entity id; removed on buffer release.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions that were shown, newest first, capped at `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // NOTE(review): `recent_editors` / `last_activity_state` are not read in the visible part
    // of this file; presumably maintained by the activity polling below — confirm.
    recent_editors: VecDeque<RecentEditor>,
    last_activity_state: Option<ActivityState>,
    /// Background polling loop spawned in `new`, if any; kept alive by this field.
    _activity_poll_task: Option<Task<Result<()>>>,
}
253
// NOTE(review): neither struct is constructed in the visible part of this file; the field
// descriptions below are inferred from names only — confirm against their users.

/// Per-editor activity bookkeeping (presumably for recent-file tracking).
struct RecentEditor {
    editor: WeakEntity<Editor>,
    last_active_at: Instant,
    activation_count: u32,
    cumulative_time_editing: Duration,
    cumulative_time_navigating: Duration,
}

/// Snapshot of editor state compared between activity polls (presumably to detect whether
/// the user scrolled, moved the cursor, or edited).
#[derive(Debug)]
struct ActivityState {
    scroll_position: gpui::Point<f32>,
    cursor_point: MultiBufferPoint,
    singleton_version: Option<clock::Global>,
}
268
269impl Zeta {
270 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
271 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
272 }
273
274 pub fn register(
275 worktree: Option<Entity<Worktree>>,
276 client: Arc<Client>,
277 user_store: Entity<UserStore>,
278 cx: &mut App,
279 ) -> Entity<Self> {
280 let this = Self::global(cx).unwrap_or_else(|| {
281 let entity = cx.new(|cx| Self::new(client, user_store, cx));
282 cx.set_global(ZetaGlobal(entity.clone()));
283 entity
284 });
285
286 this.update(cx, move |this, cx| {
287 if let Some(worktree) = worktree {
288 let worktree_id = worktree.read(cx).id();
289 this.license_detection_watchers
290 .entry(worktree_id)
291 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
292 }
293 });
294
295 this
296 }
297
298 pub fn clear_history(&mut self) {
299 self.events.clear();
300 }
301
302 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
303 self.user_store.read(cx).edit_prediction_usage()
304 }
305
    /// Builds the `Zeta` state: loads the data collection choice, subscribes to LLM token
    /// refresh events, and (when a workspace is available) starts activity polling.
    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        let data_collection_choice = Self::load_data_collection_choices();
        let data_collection_choice = cx.new(|_| data_collection_choice);

        let mut activity_poll_task = None;

        // NOTE(review): `workspace` is neither a parameter of `new` nor defined in this
        // function, so this does not compile as written. It presumably needs to be threaded
        // through from `register` (which only receives a worktree) — TODO confirm the source.
        // Also, `entry_id` below is bound but unused.
        if let Some(workspace) = &workspace {
            let project = workspace.read(cx).project().clone();
            cx.subscribe(&project, |this, _project, event, cx| match event {
                project::Event::ActiveEntryChanged(entry_id) => {
                    this.handle_active_project_entry_changed(cx)
                }
                _ => {}
            })
            .detach();

            // TODO: ideally this would attend to window focus when tracking time, and pause the
            // loop for efficiency when not focused.
            activity_poll_task = Some(cx.spawn(async move |this, cx| {
                // Time of the previous poll; reset while data collection is disabled so the
                // next enabled poll does not count the disabled interval.
                let mut instant_before_delay = None;
                loop {
                    let data_collection_is_enabled = this.read_with(cx, |this, cx| {
                        this.data_collection_choice.read(cx).is_enabled()
                    })?;
                    let interval = if data_collection_is_enabled {
                        ACTIVITY_POLL_INTERVAL
                    } else {
                        instant_before_delay = None;
                        DISABLED_ACTIVITY_POLL_INTERVAL
                    };
                    cx.background_executor().timer(interval).await;
                    // `?` ends the loop once the entity has been dropped.
                    this.update(cx, |this, cx| {
                        let now = Instant::now();
                        this.handle_activity_poll(instant_before_delay, now, cx);
                        instant_before_delay = Some(now);
                    })?
                }
            }));
        }

        Self {
            client,
            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
            rated_completions: HashSet::default(),
            registered_buffers: HashMap::default(),
            data_collection_choice,
            llm_token: LlmApiToken::default(),
            // Refresh the cached LLM token whenever the global listener emits.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            license_detection_watchers: HashMap::default(),
            user_store,
            recent_editors: VecDeque::new(),
            last_activity_state: None,
            _activity_poll_task: activity_poll_task,
        }
    }
376
377 fn push_event(&mut self, event: Event) {
378 if let Some(Event::BufferChange {
379 new_snapshot: last_new_snapshot,
380 timestamp: last_timestamp,
381 ..
382 }) = self.events.back_mut()
383 {
384 // Coalesce edits for the same buffer when they happen one after the other.
385 let Event::BufferChange {
386 old_snapshot,
387 new_snapshot,
388 timestamp,
389 } = &event;
390
391 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
392 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
393 && old_snapshot.version == last_new_snapshot.version
394 {
395 *last_new_snapshot = new_snapshot.clone();
396 *last_timestamp = *timestamp;
397 return;
398 }
399 }
400
401 if self.events.len() >= MAX_EVENT_COUNT {
402 // These are halved instead of popping to improve prompt caching.
403 self.events.drain(..MAX_EVENT_COUNT / 2);
404 }
405
406 self.events.push_back(event);
407 }
408
409 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
410 let buffer_id = buffer.entity_id();
411 let weak_buffer = buffer.downgrade();
412
413 if let std::collections::hash_map::Entry::Vacant(entry) =
414 self.registered_buffers.entry(buffer_id)
415 {
416 let snapshot = buffer.read(cx).snapshot();
417
418 entry.insert(RegisteredBuffer {
419 snapshot,
420 _subscriptions: [
421 cx.subscribe(buffer, move |this, buffer, event, cx| {
422 this.handle_buffer_event(buffer, event, cx);
423 }),
424 cx.observe_release(buffer, move |this, _buffer, _cx| {
425 this.registered_buffers.remove(&weak_buffer.entity_id());
426 }),
427 ],
428 });
429 };
430 }
431
432 fn handle_buffer_event(
433 &mut self,
434 buffer: Entity<Buffer>,
435 event: &language::BufferEvent,
436 cx: &mut Context<Self>,
437 ) {
438 if let language::BufferEvent::Edited = event {
439 self.report_changes_for_buffer(&buffer, cx);
440 }
441 }
442
    /// Shared implementation of prediction requests.
    ///
    /// Snapshots the buffer, gathers prompt context in the background, calls
    /// `perform_predict_edits` to obtain the model response, and turns it into an
    /// `EditPrediction` (or `None` if nothing applicable remains). When `can_collect_data` is
    /// set and the buffer has a project file, extra context is gathered concurrently and sent
    /// as a separate telemetry event. A `ZedUpdateRequiredError` sets `update_required` and
    /// shows an "Update Zed" notification before the error is returned.
    fn request_completion_impl<F, R>(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = cx.background_spawn(gather_context(
            full_path_str,
            snapshot.clone(),
            cursor_point,
            make_events_prompt,
            can_collect_data,
        ));

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
                && let Some(file) = snapshot.file()
                && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
            {
                // This is async to reduce latency of the edit predictions request. The downside is
                // that it will see a slightly later state than was used when gathering context.
                let snapshot = snapshot.clone();
                let this = this.clone();
                Some(cx.spawn(async move |cx| {
                    if let Ok(Some(task)) = this.update(cx, |this, cx| {
                        this.gather_additional_context(
                            cursor_point,
                            cursor_offset,
                            snapshot,
                            &buffer_snapshotted_at,
                            project_path,
                            project.as_ref(),
                            cx,
                        )
                    }) {
                        Some(task.await)
                    } else {
                        None
                    }
                }))
            } else {
                None
            };

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies for the resulting `EditPrediction`; `body` is moved into the request.
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let request_id = response.request_id.clone();
            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests
            // (`<= 2` accepts 3 of the 256 possible `u8` values, i.e. about 1.2%.)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            // Report the additional context off the main path so it never delays the result.
            if let Some(additional_context_task) = additional_context_task {
                cx.background_spawn(async move {
                    if let Some(additional_context) = additional_context_task.await {
                        telemetry::event!(
                            "Edit Prediction Additional Context",
                            request_id = request_id,
                            additional_context = additional_context
                        );
                    }
                })
                .detach();
            }

            edit_prediction
        })
    }
625
626 // Generates several example completions of various states to fill the Zeta completion modal
627 #[cfg(any(test, feature = "test-support"))]
628 pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
629 use language::Point;
630
631 let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
632 And maybe a short line
633
634 Then a few lines
635
636 and then another
637 "#};
638
639 let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
640 let position = buffer.read(cx).anchor_before(Point::new(1, 0));
641
642 let completion_tasks = vec![
643 self.fake_completion(
644 None,
645 &buffer,
646 position,
647 PredictEditsResponse {
648 request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
649 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
650a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
651[here's an edit]
652And maybe a short line
653Then a few lines
654and then another
655{EDITABLE_REGION_END_MARKER}
656 "),
657 },
658 cx,
659 ),
660 self.fake_completion(
661 None,
662 &buffer,
663 position,
664 PredictEditsResponse {
665 request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
666 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
667a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
668And maybe a short line
669[and another edit]
670Then a few lines
671and then another
672{EDITABLE_REGION_END_MARKER}
673 "#),
674 },
675 cx,
676 ),
677 self.fake_completion(
678 None,
679 &buffer,
680 position,
681 PredictEditsResponse {
682 request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
683 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
684a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
685And maybe a short line
686
687Then a few lines
688
689and then another
690{EDITABLE_REGION_END_MARKER}
691 "#),
692 },
693 cx,
694 ),
695 self.fake_completion(
696 None,
697 &buffer,
698 position,
699 PredictEditsResponse {
700 request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
701 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
702a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
703And maybe a short line
704
705Then a few lines
706
707and then another
708{EDITABLE_REGION_END_MARKER}
709 "#),
710 },
711 cx,
712 ),
713 self.fake_completion(
714 None,
715 &buffer,
716 position,
717 PredictEditsResponse {
718 request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
719 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
720a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
721And maybe a short line
722Then a few lines
723[a third completion]
724and then another
725{EDITABLE_REGION_END_MARKER}
726 "#),
727 },
728 cx,
729 ),
730 self.fake_completion(
731 None,
732 &buffer,
733 position,
734 PredictEditsResponse {
735 request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
736 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
737a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
738And maybe a short line
739and then another
740[fourth completion example]
741{EDITABLE_REGION_END_MARKER}
742 "#),
743 },
744 cx,
745 ),
746 self.fake_completion(
747 None,
748 &buffer,
749 position,
750 PredictEditsResponse {
751 request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
752 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
753a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
754And maybe a short line
755Then a few lines
756and then another
757[fifth and final completion]
758{EDITABLE_REGION_END_MARKER}
759 "#),
760 },
761 cx,
762 ),
763 ];
764
765 cx.spawn(async move |zeta, cx| {
766 for task in completion_tasks {
767 task.await.unwrap();
768 }
769
770 zeta.update(cx, |zeta, _cx| {
771 zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
772 zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
773 })
774 .ok();
775 })
776 }
777
778 #[cfg(any(test, feature = "test-support"))]
779 pub fn fake_completion(
780 &mut self,
781 project: Option<Entity<Project>>,
782 buffer: &Entity<Buffer>,
783 position: language::Anchor,
784 response: PredictEditsResponse,
785 cx: &mut Context<Self>,
786 ) -> Task<Result<Option<EditPrediction>>> {
787 use std::future::ready;
788
789 self.request_completion_impl(
790 project,
791 buffer,
792 position,
793 CanCollectData(false),
794 cx,
795 |_params| ready(Ok((response, None))),
796 )
797 }
798
799 pub fn request_completion(
800 &mut self,
801 project: Option<Entity<Project>>,
802 buffer: &Entity<Buffer>,
803 position: language::Anchor,
804 can_collect_data: CanCollectData,
805 cx: &mut Context<Self>,
806 ) -> Task<Result<Option<EditPrediction>>> {
807 self.request_completion_impl(
808 project,
809 buffer,
810 position,
811 can_collect_data,
812 cx,
813 Self::perform_predict_edits,
814 )
815 }
816
817 pub fn perform_predict_edits(
818 params: PerformPredictEditsParams,
819 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
820 async move {
821 let PerformPredictEditsParams {
822 client,
823 llm_token,
824 app_version,
825 body,
826 ..
827 } = params;
828
829 let http_client = client.http_client();
830 let mut token = llm_token.acquire(&client).await?;
831 let mut did_retry = false;
832
833 loop {
834 let request_builder = http_client::Request::builder().method(Method::POST);
835 let request_builder =
836 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
837 request_builder.uri(predict_edits_url)
838 } else {
839 request_builder.uri(
840 http_client
841 .build_zed_llm_url("/predict_edits/v2", &[])?
842 .as_ref(),
843 )
844 };
845 let request = request_builder
846 .header("Content-Type", "application/json")
847 .header("Authorization", format!("Bearer {}", token))
848 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
849 .body(serde_json::to_string(&body)?.into())?;
850
851 let mut response = http_client.send(request).await?;
852
853 if let Some(minimum_required_version) = response
854 .headers()
855 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
856 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
857 {
858 anyhow::ensure!(
859 app_version >= minimum_required_version,
860 ZedUpdateRequiredError {
861 minimum_version: minimum_required_version
862 }
863 );
864 }
865
866 if response.status().is_success() {
867 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
868
869 let mut body = String::new();
870 response.body_mut().read_to_string(&mut body).await?;
871 return Ok((serde_json::from_str(&body)?, usage));
872 } else if !did_retry
873 && response
874 .headers()
875 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
876 .is_some()
877 {
878 did_retry = true;
879 token = llm_token.refresh(&client).await?;
880 } else {
881 let mut body = String::new();
882 response.body_mut().read_to_string(&mut body).await?;
883 anyhow::bail!(
884 "error predicting edits.\nStatus: {:?}\nBody: {}",
885 response.status(),
886 body
887 );
888 }
889 }
890 }
891 }
892
893 fn accept_edit_prediction(
894 &mut self,
895 request_id: EditPredictionId,
896 cx: &mut Context<Self>,
897 ) -> Task<Result<()>> {
898 let client = self.client.clone();
899 let llm_token = self.llm_token.clone();
900 let app_version = AppVersion::global(cx);
901 cx.spawn(async move |this, cx| {
902 let http_client = client.http_client();
903 let mut response = llm_token_retry(&llm_token, &client, |token| {
904 let request_builder = http_client::Request::builder().method(Method::POST);
905 let request_builder =
906 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
907 request_builder.uri(accept_prediction_url)
908 } else {
909 request_builder.uri(
910 http_client
911 .build_zed_llm_url("/predict_edits/accept", &[])?
912 .as_ref(),
913 )
914 };
915 Ok(request_builder
916 .header("Content-Type", "application/json")
917 .header("Authorization", format!("Bearer {}", token))
918 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
919 .body(
920 serde_json::to_string(&AcceptEditPredictionBody {
921 request_id: request_id.0,
922 })?
923 .into(),
924 )?)
925 })
926 .await?;
927
928 if let Some(minimum_required_version) = response
929 .headers()
930 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
931 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
932 && app_version < minimum_required_version
933 {
934 return Err(anyhow!(ZedUpdateRequiredError {
935 minimum_version: minimum_required_version
936 }));
937 }
938
939 if response.status().is_success() {
940 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
941 this.update(cx, |this, cx| {
942 this.user_store.update(cx, |user_store, cx| {
943 user_store.update_edit_prediction_usage(usage, cx);
944 });
945 })?;
946 }
947
948 Ok(())
949 } else {
950 let mut body = String::new();
951 response.body_mut().read_to_string(&mut body).await?;
952 Err(anyhow!(
953 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
954 response.status(),
955 body
956 ))
957 }
958 })
959 }
960
    /// Turns a raw model response into an `EditPrediction`.
    ///
    /// Parses the edits from `output_excerpt` on a background thread (relative to `snapshot`,
    /// the buffer state the request was built from), rebases them onto the buffer's current
    /// snapshot via `interpolate`, and builds an edit preview. Resolves to `Ok(None)` when
    /// interpolation yields no applicable edits. Parse failures are returned as errors.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Diffing can be expensive, so it runs off the main thread.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // The user may have edited the buffer while the request was in flight.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                // Stamped after parsing and preview, so this includes local processing time.
                response_received_at: Instant::now(),
            }))
        })
    }
1020
1021 fn parse_edits(
1022 output_excerpt: Arc<str>,
1023 editable_range: Range<usize>,
1024 snapshot: &BufferSnapshot,
1025 ) -> Result<Vec<(Range<Anchor>, String)>> {
1026 let content = output_excerpt.replace(CURSOR_MARKER, "");
1027
1028 let start_markers = content
1029 .match_indices(EDITABLE_REGION_START_MARKER)
1030 .collect::<Vec<_>>();
1031 anyhow::ensure!(
1032 start_markers.len() == 1,
1033 "expected exactly one start marker, found {}",
1034 start_markers.len()
1035 );
1036
1037 let end_markers = content
1038 .match_indices(EDITABLE_REGION_END_MARKER)
1039 .collect::<Vec<_>>();
1040 anyhow::ensure!(
1041 end_markers.len() == 1,
1042 "expected exactly one end marker, found {}",
1043 end_markers.len()
1044 );
1045
1046 let sof_markers = content
1047 .match_indices(START_OF_FILE_MARKER)
1048 .collect::<Vec<_>>();
1049 anyhow::ensure!(
1050 sof_markers.len() <= 1,
1051 "expected at most one start-of-file marker, found {}",
1052 sof_markers.len()
1053 );
1054
1055 let codefence_start = start_markers[0].0;
1056 let content = &content[codefence_start..];
1057
1058 let newline_ix = content.find('\n').context("could not find newline")?;
1059 let content = &content[newline_ix + 1..];
1060
1061 let codefence_end = content
1062 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1063 .context("could not find end marker")?;
1064 let new_text = &content[..codefence_end];
1065
1066 let old_text = snapshot
1067 .text_for_range(editable_range.clone())
1068 .collect::<String>();
1069
1070 Ok(Self::compute_edits(
1071 old_text,
1072 new_text,
1073 editable_range.start,
1074 snapshot,
1075 ))
1076 }
1077
1078 pub fn compute_edits(
1079 old_text: String,
1080 new_text: &str,
1081 offset: usize,
1082 snapshot: &BufferSnapshot,
1083 ) -> Vec<(Range<Anchor>, String)> {
1084 text_diff(&old_text, new_text)
1085 .into_iter()
1086 .map(|(mut old_range, new_text)| {
1087 old_range.start += offset;
1088 old_range.end += offset;
1089
1090 let prefix_len = common_prefix(
1091 snapshot.chars_for_range(old_range.clone()),
1092 new_text.chars(),
1093 );
1094 old_range.start += prefix_len;
1095
1096 let suffix_len = common_prefix(
1097 snapshot.reversed_chars_for_range(old_range.clone()),
1098 new_text[prefix_len..].chars().rev(),
1099 );
1100 old_range.end = old_range.end.saturating_sub(suffix_len);
1101
1102 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1103 let range = if old_range.is_empty() {
1104 let anchor = snapshot.anchor_after(old_range.start);
1105 anchor..anchor
1106 } else {
1107 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1108 };
1109 (range, new_text)
1110 })
1111 .collect()
1112 }
1113
1114 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1115 self.rated_completions.contains(&completion_id)
1116 }
1117
1118 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1119 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1120 let completion = self.shown_completions.pop_back().unwrap();
1121 self.rated_completions.remove(&completion.id);
1122 }
1123 self.shown_completions.push_front(completion.clone());
1124 cx.notify();
1125 }
1126
1127 pub fn rate_completion(
1128 &mut self,
1129 completion: &EditPrediction,
1130 rating: EditPredictionRating,
1131 feedback: String,
1132 cx: &mut Context<Self>,
1133 ) {
1134 self.rated_completions.insert(completion.id);
1135 telemetry::event!(
1136 "Edit Prediction Rated",
1137 rating,
1138 input_events = completion.input_events,
1139 input_excerpt = completion.input_excerpt,
1140 output_excerpt = completion.output_excerpt,
1141 feedback
1142 );
1143 self.client.telemetry().flush_events().detach();
1144 cx.notify();
1145 }
1146
1147 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1148 self.shown_completions.iter()
1149 }
1150
1151 pub fn shown_completions_len(&self) -> usize {
1152 self.shown_completions.len()
1153 }
1154
1155 fn report_changes_for_buffer(
1156 &mut self,
1157 buffer: &Entity<Buffer>,
1158 cx: &mut Context<Self>,
1159 ) -> BufferSnapshot {
1160 self.register_buffer(buffer, cx);
1161
1162 let registered_buffer = self
1163 .registered_buffers
1164 .get_mut(&buffer.entity_id())
1165 .unwrap();
1166 let new_snapshot = buffer.read(cx).snapshot();
1167
1168 if new_snapshot.version != registered_buffer.snapshot.version {
1169 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1170 self.push_event(Event::BufferChange {
1171 old_snapshot,
1172 new_snapshot: new_snapshot.clone(),
1173 timestamp: Instant::now(),
1174 });
1175 }
1176
1177 new_snapshot
1178 }
1179
1180 fn load_data_collection_choices() -> DataCollectionChoice {
1181 let choice = KEY_VALUE_STORE
1182 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1183 .log_err()
1184 .flatten();
1185
1186 match choice.as_deref() {
1187 Some("true") => DataCollectionChoice::Enabled,
1188 Some("false") => DataCollectionChoice::Disabled,
1189 Some(_) => {
1190 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1191 DataCollectionChoice::NotAnswered
1192 }
1193 None => DataCollectionChoice::NotAnswered,
1194 }
1195 }
1196
1197 fn gather_additional_context(
1198 &mut self,
1199 cursor_point: language::Point,
1200 cursor_offset: usize,
1201 snapshot: BufferSnapshot,
1202 buffer_snapshotted_at: &Instant,
1203 project_path: ProjectPath,
1204 project: Option<&Entity<Project>>,
1205 cx: &mut Context<Self>,
1206 ) -> Option<Task<PredictEditsAdditionalContext>> {
1207 let project = project?.read(cx);
1208 let entry = project.entry_for_path(&project_path, cx)?;
1209 if !worktree_entry_is_eligible_for_collection(&entry) {
1210 return None;
1211 }
1212
1213 let git_store = project.git_store().read(cx);
1214 let (repository, repo_path) =
1215 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1216 let repo_path_string = repo_path.to_str()?.to_string();
1217
1218 let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
1219 snapshot
1220 .diagnostics
1221 .iter()
1222 .filter_map(|(language_server_id, diagnostics)| {
1223 let language_server =
1224 local_lsp_store.running_language_server_for_id(*language_server_id)?;
1225 Some((
1226 *language_server_id,
1227 language_server.name(),
1228 diagnostics.clone(),
1229 ))
1230 })
1231 .collect()
1232 } else {
1233 Vec::new()
1234 };
1235
1236 repository.update(cx, |repository, cx| {
1237 let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1238 let remote_origin_url = repository.remote_origin_url.clone();
1239 let remote_upstream_url = repository.remote_upstream_url.clone();
1240 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1241
1242 // group, resolve, and select diagnostics on a background thread
1243 Some(cx.background_spawn(async move {
1244 let mut diagnostic_groups_with_name = Vec::new();
1245 for (language_server_id, language_server_name, diagnostics) in
1246 diagnostics.into_iter()
1247 {
1248 let mut groups = Vec::new();
1249 diagnostics.groups(language_server_id, &mut groups, &snapshot);
1250 diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
1251 (
1252 language_server_name.clone(),
1253 group.resolve::<usize>(&snapshot),
1254 )
1255 }));
1256 }
1257
1258 // sort by proximity to cursor
1259 diagnostic_groups_with_name.sort_by_key(|(_, group)| {
1260 let range = &group.entries[group.primary_ix].range;
1261 if range.start >= cursor_offset {
1262 range.start - cursor_offset
1263 } else if cursor_offset >= range.end {
1264 cursor_offset - range.end
1265 } else {
1266 (cursor_offset - range.start).min(range.end - cursor_offset)
1267 }
1268 });
1269
1270 let mut diagnostic_groups = Vec::new();
1271 let mut diagnostic_groups_truncated = false;
1272 let mut diagnostics_byte_count = 0;
1273 for (name, group) in diagnostic_groups_with_name {
1274 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1275 diagnostics_byte_count += name.0.len() + raw_value.get().len();
1276 if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
1277 diagnostic_groups_truncated = true;
1278 break;
1279 }
1280 diagnostic_groups.push((name.to_string(), raw_value));
1281 }
1282
1283 PredictEditsAdditionalContext {
1284 input_path: repo_path_string,
1285 cursor_point: to_cloud_llm_client_point(cursor_point),
1286 cursor_offset: cursor_offset,
1287 git_info: PredictEditsGitInfo {
1288 head_sha: Some(head_sha),
1289 remote_origin_url,
1290 remote_upstream_url,
1291 },
1292 diagnostic_groups,
1293 diagnostic_groups_truncated,
1294 recent_files,
1295 }
1296 }))
1297 })
1298 }
1299
1300 fn handle_active_project_entry_changed(&mut self, cx: &mut Context<Self>) {
1301 if !self.data_collection_choice.read(cx).is_enabled() {
1302 self.recent_editors.clear();
1303 self.last_activity_state = None;
1304 return;
1305 }
1306 if let Some(active_editor) = self
1307 .workspace
1308 .read_with(cx, |workspace, cx| {
1309 workspace
1310 .active_item(cx)
1311 .and_then(|item| item.act_as::<Editor>(cx))
1312 })
1313 .ok()
1314 .flatten()
1315 {
1316 let now = Instant::now();
1317 let editor = active_editor.downgrade();
1318 let existing_recent_editor = if let Some(existing_ix) = self
1319 .recent_editors
1320 .iter()
1321 .rposition(|recent| &recent.editor == &editor)
1322 {
1323 if existing_ix + 1 != self.recent_editors.len() {
1324 self.last_activity_state = None;
1325 }
1326 self.recent_editors.remove(existing_ix)
1327 } else {
1328 None
1329 };
1330 let new_recent = RecentEditor {
1331 editor: active_editor.downgrade(),
1332 last_active_at: now,
1333 activation_count: existing_recent_editor
1334 .as_ref()
1335 .map_or(0, |recent| recent.activation_count + 1),
1336 cumulative_time_navigating: existing_recent_editor
1337 .as_ref()
1338 .map_or(Duration::ZERO, |recent| recent.cumulative_time_navigating),
1339 cumulative_time_editing: existing_recent_editor
1340 .map_or(Duration::ZERO, |recent| recent.cumulative_time_editing),
1341 };
1342 // filter out rapid changes in active item, particularly since this can happen rapidly when
1343 // a workspace is loaded.
1344 if let Some(previous_recent) = self.recent_editors.back_mut()
1345 && previous_recent.activation_count == 1
1346 && now.duration_since(previous_recent.last_active_at)
1347 < MIN_TIME_BETWEEN_RECENT_FILES
1348 {
1349 *previous_recent = new_recent;
1350 return;
1351 }
1352 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1353 self.recent_editors.pop_front();
1354 }
1355 self.recent_editors.push_back(new_recent);
1356 }
1357 }
1358
    /// Periodic activity sample for the most recently active editor.
    ///
    /// Compares the editor's scroll position, newest-selection head, and (for singleton
    /// buffers) buffer version against the previous sample:
    /// - if the scroll position or cursor moved, the elapsed time is added to
    ///   `cumulative_time_navigating`;
    /// - if the buffer version changed, the elapsed time is added to `cumulative_time_editing`.
    ///
    /// The first sample after `last_activity_state` is cleared only establishes a baseline.
    ///
    /// `instant_before_delay` is the start of the elapsed-time window. When it is `None`, no
    /// time is attributed for this poll. When data collection is disabled, the activity state
    /// is cleared and nothing is recorded.
    fn handle_activity_poll(
        &mut self,
        instant_before_delay: Option<Instant>,
        now: Instant,
        cx: &mut Context<Self>,
    ) {
        if !self.data_collection_choice.read(cx).is_enabled() {
            self.last_activity_state = None;
            return;
        }
        if let Some(recent_editor) = self.recent_editors.back()
            && let Some(editor) = recent_editor.editor.upgrade()
        {
            let (scroll_position, cursor_point, singleton_version) =
                editor.update(cx, |editor, cx| {
                    let scroll_position = editor.scroll_position(cx);
                    let cursor_point = editor.selections.newest(cx).head();
                    // Only singleton (single-buffer) editors contribute a buffer version, so
                    // editing time is never attributed to multibuffers.
                    let singleton_version = editor
                        .buffer()
                        .read(cx)
                        .as_singleton()
                        .map(|singleton_buffer| singleton_buffer.read(cx).version());
                    (scroll_position, cursor_point, singleton_version)
                });

            // Without a previous sample there is nothing to compare against.
            let navigated = if let Some(last_activity_state) = &self.last_activity_state {
                last_activity_state.scroll_position != scroll_position
                    || last_activity_state.cursor_point != cursor_point
            } else {
                false
            };

            let edited = if let Some(singleton_version) = &singleton_version
                && let Some(last_activity_state) = &self.last_activity_state
                && let Some(last_singleton_version) = &last_activity_state.singleton_version
            {
                singleton_version.changed_since(last_singleton_version)
            } else {
                false
            };

            self.last_activity_state = Some(ActivityState {
                scroll_position,
                cursor_point,
                singleton_version,
            });

            // NOTE(review): the start of the time window is clamped by the activation time of
            // the editor *before* the current one (second-to-last entry), not by the current
            // editor's own `last_active_at`. Confirm this is intended.
            let prior_recent_editor = if self.recent_editors.len() > 1 {
                Some(&self.recent_editors[self.recent_editors.len() - 2])
            } else {
                None
            };
            let additional_time: Option<Duration> =
                instant_before_delay.map(|instant_before_delay| {
                    now.duration_since(prior_recent_editor.map_or(
                        instant_before_delay,
                        |prior_recent_editor| {
                            prior_recent_editor.last_active_at.max(instant_before_delay)
                        },
                    ))
                });

            if let Some(additional_time) = additional_time {
                // `back()` returned `Some` above and nothing in between removes entries.
                let recent_editor = self.recent_editors.back_mut().unwrap();
                if navigated {
                    recent_editor.cumulative_time_navigating += additional_time;
                }
                if edited {
                    recent_editor.cumulative_time_editing += additional_time;
                }
            }
        }
    }
1432
    /// Builds the recent-file entries for `repository` from `recent_editors`, most recently
    /// active first, and prunes entries that can never be reported.
    ///
    /// For each recent editor, the file under its newest cursor is reported with its
    /// repository-relative path, cursor point, and activity counters. `now` is the reference
    /// time for `active_to_now_ms`.
    ///
    /// An entry is removed from `recent_editors` when any of these holds:
    /// - its editor is gone or its cursor cannot be mapped to a buffer;
    /// - the file is missing or ineligible for collection;
    /// - the worktree entry is ineligible;
    /// - its repository path is not UTF-8 or exceeds `MAX_RECENT_FILE_PATH_LENGTH`;
    /// - a duration overflows its target integer type.
    ///
    /// Files outside `repository` are skipped but kept, since a later request may concern a
    /// different repository. Returns an empty list if the workspace is no longer available.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &mut App,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::with_capacity(self.recent_editors.len());
        // Iterate from the back (most recent) so that `remove(ix)` doesn't shift indices that
        // are still to be visited.
        for ix in (0..self.recent_editors.len()).rev() {
            let recent_editor = &self.recent_editors[ix];
            // `Some(())` keeps the entry; `None` (including a dropped editor) removes it.
            let keep_entry = recent_editor
                .editor
                .update(cx, |editor, cx| {
                    maybe!({
                        let cursor = editor.selections.newest::<MultiBufferPoint>(cx).head();
                        let multibuffer = editor.buffer().read(cx);
                        let (buffer, cursor_point, _) =
                            multibuffer.point_to_buffer_point(cursor, cx)?;
                        let file = buffer.read(cx).file()?;
                        if !file_is_eligible_for_collection(file.as_ref()) {
                            return None;
                        }
                        let project_path = ProjectPath {
                            worktree_id: file.worktree_id(cx),
                            path: file.path().clone(),
                        };
                        let entry = project.read(cx).entry_for_path(&project_path, cx)?;
                        if !worktree_entry_is_eligible_for_collection(entry) {
                            return None;
                        }
                        let Some(repo_path) =
                            repository.project_path_to_repo_path(&project_path, cx)
                        else {
                            // entry not removed since later queries may involve other repositories
                            return Some(());
                        };
                        // paths may not be valid UTF-8
                        let repo_path_str = repo_path.to_str()?;
                        if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                            return None;
                        }
                        // Millisecond counters that don't fit the target field type drop the
                        // entry rather than being truncated.
                        let active_to_now_ms = now
                            .duration_since(recent_editor.last_active_at)
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_editing_ms = recent_editor
                            .cumulative_time_editing
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_navigating_ms = recent_editor
                            .cumulative_time_navigating
                            .as_millis()
                            .try_into()
                            .ok()?;
                        results.push(PredictEditsRecentFile {
                            path: repo_path_str.to_string(),
                            cursor_point: to_cloud_llm_client_point(cursor_point),
                            active_to_now_ms,
                            activation_count: recent_editor.activation_count,
                            cumulative_time_editing_ms,
                            cumulative_time_navigating_ms,
                            is_multibuffer: !multibuffer.is_singleton(),
                        });
                        Some(())
                    })
                })
                .ok()
                .flatten();
            if keep_entry.is_none() {
                self.recent_editors.remove(ix);
            }
        }
        results
    }
1514}
1515
1516fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1517 cloud_llm_client::Point {
1518 row: point.row,
1519 column: point.column,
1520 }
1521}
1522
1523fn file_is_eligible_for_collection(file: &dyn File) -> bool {
1524 file.is_local() && !file.is_private()
1525}
1526
1527fn worktree_entry_is_eligible_for_collection(entry: &worktree::Entry) -> bool {
1528 entry.is_file()
1529 && entry.is_created()
1530 && !entry.is_ignored
1531 && !entry.is_private
1532 && !entry.is_external
1533 && !entry.is_fifo
1534}
1535
/// Inputs for issuing a predict-edits request.
///
/// NOTE(review): this struct is not constructed or consumed in this part of the file; the
/// field descriptions below are inferred from the field types. Confirm against the request
/// code.
pub struct PerformPredictEditsParams {
    /// Client used to send the request.
    pub client: Arc<Client>,
    /// LLM API token, presumably used to authenticate the request.
    pub llm_token: LlmApiToken,
    /// Version of the running app, presumably reported with the request.
    pub app_version: SemanticVersion,
    /// The request payload.
    pub body: PredictEditsBody,
}
1542
/// Error reporting that this Zed version is too old to keep using edit predictions.
///
/// Its message tells the user to update to at least `minimum_version`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The oldest Zed version that may continue using edit predictions.
    minimum_version: SemanticVersion,
}
1550
/// Returns the length, in UTF-8 bytes, of the longest run of equal characters at the start of
/// both iterators.
///
/// Pass reversed iterators to measure a common suffix instead. The byte length is computed
/// from `a`'s characters, which equal `b`'s over the shared run.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1557
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// Request body holding the events prompt and the excerpt around the cursor.
    pub body: PredictEditsBody,
    /// Buffer offsets of the region the model is allowed to rewrite.
    pub editable_range: Range<usize>,
}
1562
1563pub async fn gather_context(
1564 full_path_str: String,
1565 snapshot: BufferSnapshot,
1566 cursor_point: language::Point,
1567 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1568 can_collect_data: CanCollectData,
1569) -> Result<GatherContextOutput> {
1570 let input_excerpt = excerpt_for_cursor_position(
1571 cursor_point,
1572 &full_path_str,
1573 &snapshot,
1574 MAX_REWRITE_TOKENS,
1575 MAX_CONTEXT_TOKENS,
1576 );
1577 let input_events = make_events_prompt();
1578 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1579
1580 let body = PredictEditsBody {
1581 input_events,
1582 input_excerpt: input_excerpt.prompt,
1583 can_collect_data: can_collect_data.0,
1584 diagnostic_groups: None,
1585 git_info: None,
1586 };
1587
1588 Ok(GatherContextOutput {
1589 body,
1590 editable_range,
1591 })
1592}
1593
1594fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1595 let mut result = String::new();
1596 for event in events.iter().rev() {
1597 let event_string = event.to_prompt();
1598 let event_tokens = tokens_for_bytes(event_string.len());
1599 if event_tokens > remaining_tokens {
1600 break;
1601 }
1602
1603 if !result.is_empty() {
1604 result.insert_str(0, "\n\n");
1605 }
1606 result.insert_str(0, &event_string);
1607 remaining_tokens -= event_tokens;
1608 }
1609 result
1610}
1611
/// Change-tracking state for a buffer registered with Zeta.
struct RegisteredBuffer {
    /// The snapshot as of the last time changes were reported; the next change event is
    /// diffed against it.
    snapshot: BufferSnapshot,
    /// Held only to keep the buffer's subscriptions alive while it is registered.
    _subscriptions: [gpui::Subscription; 2],
}
1616
/// An entry in the edit history that is rendered into the prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed between two snapshots.
    BufferChange {
        /// Snapshot from the previous time changes were reported for the buffer.
        old_snapshot: BufferSnapshot,
        /// Snapshot reflecting the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1625
1626impl Event {
1627 fn to_prompt(&self) -> String {
1628 match self {
1629 Event::BufferChange {
1630 old_snapshot,
1631 new_snapshot,
1632 ..
1633 } => {
1634 let mut prompt = String::new();
1635
1636 let old_path = old_snapshot
1637 .file()
1638 .map(|f| f.path().as_ref())
1639 .unwrap_or(Path::new("untitled"));
1640 let new_path = new_snapshot
1641 .file()
1642 .map(|f| f.path().as_ref())
1643 .unwrap_or(Path::new("untitled"));
1644 if old_path != new_path {
1645 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1646 }
1647
1648 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1649 if !diff.is_empty() {
1650 write!(
1651 prompt,
1652 "User edited {:?}:\n```diff\n{}\n```",
1653 new_path, diff
1654 )
1655 .unwrap();
1656 }
1657
1658 prompt
1659 }
1660 }
1661 }
1662}
1663
/// The prediction a provider is currently offering.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
}
1669
1670impl CurrentEditPrediction {
1671 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1672 if self.buffer_id != old_completion.buffer_id {
1673 return true;
1674 }
1675
1676 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1677 return true;
1678 };
1679 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1680 return false;
1681 };
1682
1683 if old_edits.len() == 1 && new_edits.len() == 1 {
1684 let (old_range, old_text) = &old_edits[0];
1685 let (new_range, new_text) = &new_edits[0];
1686 new_range == old_range && new_text.starts_with(old_text)
1687 } else {
1688 true
1689 }
1690 }
1691}
1692
/// An in-flight prediction request tracked by the provider.
struct PendingCompletion {
    /// Sequence number assigned when the request was issued. It is compared on completion to
    /// tell whether this request was the oldest one still pending.
    id: usize,
    /// The request's task, held so it stays alive while pending. gpui tasks are cancelled
    /// when dropped, so removing this entry presumably cancels the request.
    _task: Task<()>,
}
1697
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted under `ZED_PREDICT_DATA_COLLECTION_CHOICE` as `"true"` or `"false"`. A missing
/// or unrecognized stored value loads as `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been made yet; collection stays disabled.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
1704
1705impl DataCollectionChoice {
1706 pub fn is_enabled(self) -> bool {
1707 match self {
1708 Self::Enabled => true,
1709 Self::NotAnswered | Self::Disabled => false,
1710 }
1711 }
1712
1713 pub fn is_answered(self) -> bool {
1714 match self {
1715 Self::Enabled | Self::Disabled => true,
1716 Self::NotAnswered => false,
1717 }
1718 }
1719
1720 pub fn toggle(&self) -> DataCollectionChoice {
1721 match self {
1722 Self::Enabled => Self::Disabled,
1723 Self::Disabled => Self::Enabled,
1724 Self::NotAnswered => Self::Enabled,
1725 }
1726 }
1727}
1728
1729impl From<bool> for DataCollectionChoice {
1730 fn from(value: bool) -> Self {
1731 match value {
1732 true => DataCollectionChoice::Enabled,
1733 false => DataCollectionChoice::Disabled,
1734 }
1735 }
1736}
1737
/// Per-provider data-collection state: the shared user choice plus the license watcher for
/// the worktree of the provider's buffer.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Used to decide whether the project is open source. It is `None` exactly when `choice`
    /// is `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1743
/// Whether data from a prediction request may be collected. This is true only when data
/// collection is enabled and the project appears to be open source.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1746
1747impl ProviderDataCollection {
1748 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1749 let choice_and_watcher = buffer.and_then(|buffer| {
1750 let file = buffer.read(cx).file()?;
1751
1752 if !file_is_eligible_for_collection(file.as_ref()) {
1753 return None;
1754 }
1755
1756 let zeta = zeta.read(cx);
1757 let choice = zeta.data_collection_choice.clone();
1758
1759 let license_detection_watcher = zeta
1760 .license_detection_watchers
1761 .get(&file.worktree_id(cx))
1762 .cloned()?;
1763
1764 Some((choice, license_detection_watcher))
1765 });
1766
1767 if let Some((choice, watcher)) = choice_and_watcher {
1768 ProviderDataCollection {
1769 choice: Some(choice),
1770 license_detection_watcher: Some(watcher),
1771 }
1772 } else {
1773 ProviderDataCollection {
1774 choice: None,
1775 license_detection_watcher: None,
1776 }
1777 }
1778 }
1779
1780 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1781 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1782 }
1783
1784 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1785 self.choice
1786 .as_ref()
1787 .is_some_and(|choice| choice.read(cx).is_enabled())
1788 }
1789
1790 fn is_project_open_source(&self) -> bool {
1791 self.license_detection_watcher
1792 .as_ref()
1793 .is_some_and(|watcher| watcher.is_project_open_source())
1794 }
1795
1796 pub fn toggle(&mut self, cx: &mut App) {
1797 if let Some(choice) = self.choice.as_mut() {
1798 let new_choice = choice.update(cx, |choice, _cx| {
1799 let new_choice = choice.toggle();
1800 *choice = new_choice;
1801 new_choice
1802 });
1803
1804 db::write_and_log(cx, move || {
1805 KEY_VALUE_STORE.write_kvp(
1806 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1807 new_choice.is_enabled().to_string(),
1808 )
1809 });
1810 }
1811 }
1812}
1813
1814async fn llm_token_retry(
1815 llm_token: &LlmApiToken,
1816 client: &Arc<Client>,
1817 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1818) -> Result<Response<AsyncBody>> {
1819 let mut did_retry = false;
1820 let http_client = client.http_client();
1821 let mut token = llm_token.acquire(client).await?;
1822 loop {
1823 let request = build_request(token.clone())?;
1824 let response = http_client.send(request).await?;
1825
1826 if !did_retry
1827 && !response.status().is_success()
1828 && response
1829 .headers()
1830 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1831 .is_some()
1832 {
1833 did_retry = true;
1834 token = llm_token.refresh(client).await?;
1835 continue;
1836 }
1837
1838 return Ok(response);
1839 }
1840}
1841
/// Edit-prediction provider that requests predictions from [`Zeta`].
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests, oldest first. At most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id assigned to the next request, so a finished request can find its entry.
    next_pending_completion_id: usize,
    /// The prediction currently being offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider. Its inner choice is `None` when data
    /// collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
1851
1852impl ZetaEditPredictionProvider {
1853 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1854
1855 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1856 Self {
1857 zeta,
1858 pending_completions: ArrayVec::new(),
1859 next_pending_completion_id: 0,
1860 current_completion: None,
1861 provider_data_collection,
1862 last_request_timestamp: Instant::now(),
1863 }
1864 }
1865}
1866
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Provider identifier.
    fn name() -> &'static str {
        "zed-predict"
    }

    /// Human-readable provider name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }

    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled for this provider's buffer, together with
    /// whether the project appears to be open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Flips and persists the user's data-collection choice. This is a no-op when collection
    /// is unavailable for this buffer.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    /// Delegates to Zeta's usage reporting.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// Always enabled, regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// True while any prediction request is still pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Requests a new prediction for `position` in `buffer`, unless one of these holds:
    /// - an update is required;
    /// - the account is too young or has overdue invoices;
    /// - the current prediction still applies to the buffer.
    ///
    /// The request is throttled to at most one per `THROTTLE_TIMEOUT`. At most two requests
    /// are kept pending; when two are already pending, the newer one is replaced.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction while it can still be interpolated onto the buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait until THROTTLE_TIMEOUT has passed since the previous request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Failures are logged; a failed or empty response only updates the pending list.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                // If this was the oldest pending request, only it is removed. Otherwise every
                // pending request is dropped.
                this.update(cx, |this, cx| {
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same pending-list bookkeeping as in the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Show the new prediction unless the existing one should be kept.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    /// Cycling between alternative predictions is not supported.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current prediction (if any) through a detached task, then
    /// drops all pending requests. The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the part of the current prediction to show near `cursor_position`.
    ///
    /// The current prediction is discarded if it belongs to another buffer or can no longer
    /// be interpolated onto this one. Otherwise the result contains:
    /// - the edit whose start or end row is closest to the cursor row;
    /// - the neighboring edits on either side, for as long as each is within one row of that
    ///   closest edit.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards while earlier edits end within one row of the closest edit's start.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards while later edits start within one row of the closest edit's end.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
2135
/// Estimates how many model tokens a string of `bytes` bytes consumes, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    // A deliberately small bytes-per-token ratio: it overestimates the token count, so limits
    // derived from it err on the side of underestimating how much fits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes / BYTES_PER_TOKEN_GUESS
}
2142
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // `_server` is bound (not `_`) so it stays alive until the request completes.
        let (zeta, _server) = zeta_with_fake_predictions(completion_response, cx).await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Installs a test `SettingsStore` as a global and registers the client settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });
    }

    /// Builds a [`Zeta`] whose HTTP client is faked. The fake handles three cases:
    ///
    /// - `POST /client/llm_tokens` returns a fixed LLM token.
    /// - `POST /predict_edits/v2` returns `completion_response` as the `output_excerpt`, with a
    ///   random request id.
    /// - Any other request gets a 404.
    ///
    /// A [`FakeServer`] is created for user id 42 to authenticate the client. It is returned so
    /// the caller can hold it for as long as requests are made. Presumably, dropping it early
    /// would tear down the authenticated session (unverified).
    async fn zeta_with_fake_predictions(
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> (Entity<Zeta>, FakeServer) {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |req| {
            // Each request gets its own copy so the returned future can own it.
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
        (zeta, server)
    }

    /// Requests a prediction for `buffer_content` at the start of row 1.
    ///
    /// The fake server answers with `completion_response`. The resulting edits are returned as
    /// point ranges, resolved against the buffer snapshot taken before the request.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        // `_server` is bound (not `_`) so it stays alive until the request completes.
        let (zeta, _server) = zeta_with_fake_predictions(completion_response, cx).await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`.
    ///
    /// Range starts use `anchor_after` and range ends use `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in the current state of `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}