1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8mod training_data_uploader;
9
10use arrayvec::ArrayVec;
11pub(crate) use completion_diff_element::*;
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction::DataCollectionState;
14use editor::Editor;
15pub use init::*;
16use license_detection::LicenseDetectionWatcher;
17use project::git_store::Repository;
18pub use rate_completion_modal::*;
19
20use anyhow::{Context as _, Result, anyhow};
21use client::{Client, EditPredictionUsage, UserStore};
22use cloud_llm_client::{
23 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
24 PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
25 PredictEditsResponse, ZED_VERSION_HEADER_NAME,
26};
27use collections::{HashMap, HashSet, VecDeque};
28use futures::AsyncReadExt;
29use gpui::{
30 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
31 SharedString, Subscription, Task, WeakEntity, actions,
32};
33use http_client::{AsyncBody, HttpClient, Method, Request, Response};
34use input_excerpt::excerpt_for_cursor_position;
35use language::{
36 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
37};
38use language_model::{LlmApiToken, RefreshLlmTokenListener};
39use multi_buffer::MultiBufferPoint;
40use project::{Project, ProjectPath};
41use release_channel::AppVersion;
42use settings::WorktreeId;
43use std::collections::hash_map;
44use std::mem;
45use std::str::FromStr;
46use std::{
47 cmp,
48 fmt::Write,
49 future::Future,
50 ops::Range,
51 path::Path,
52 rc::Rc,
53 sync::Arc,
54 time::{Duration, Instant},
55};
56use telemetry_events::EditPredictionRating;
57use thiserror::Error;
58use util::{ResultExt, maybe};
59use uuid::Uuid;
60use workspace::Workspace;
61use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
62use worktree::Worktree;
63
64use crate::training_data_uploader::TrainingDataUploader;
65
/// Marker for the user's cursor position. It is stripped from the model output before the
/// output is parsed into edits.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker for the start of the file. The model output may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker that opens the editable region. The model output must contain it exactly once.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker that closes the editable region. The model output must contain it exactly once.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that arrive within this interval of each other are
/// coalesced into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key for the data collection choice. Its presence is also treated as the
/// edit prediction upsell having been dismissed.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets used when building the prediction request.
// NOTE(review): MAX_CONTEXT_TOKENS and MAX_REWRITE_TOKENS are not referenced in this part of the
// file; presumably they bound the excerpt around the cursor — confirm at their use sites.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget passed to `prompt_for_events` when rendering the events prompt.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum time between recent file entries.
const MIN_TIME_BETWEEN_RECENT_FILES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of JSON bytes for diagnostics in additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

/// Interval between polls tracking time editing files.
const ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(10);

/// Interval between polls of whether data collection is enabled, when it is disabled.
const DISABLED_ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(60 * 5);
100
// Declares the actions exposed under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
108
109#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
110pub struct EditPredictionId(Uuid);
111
112impl From<EditPredictionId> for gpui::ElementId {
113 fn from(value: EditPredictionId) -> Self {
114 gpui::ElementId::Uuid(value.0)
115 }
116}
117
118impl std::fmt::Display for EditPredictionId {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 write!(f, "{}", self.0)
121 }
122}
123
124struct ZedPredictUpsell;
125
126impl Dismissable for ZedPredictUpsell {
127 const KEY: &'static str = "dismissed-edit-predict-upsell";
128
129 fn dismissed() -> bool {
130 // To make this backwards compatible with older versions of Zed, we
131 // check if the user has seen the previous Edit Prediction Onboarding
132 // before, by checking the data collection choice which was written to
133 // the database once the user clicked on "Accept and Enable"
134 if KEY_VALUE_STORE
135 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
136 .log_err()
137 .is_some_and(|s| s.is_some())
138 {
139 return true;
140 }
141
142 KEY_VALUE_STORE
143 .read_kvp(Self::KEY)
144 .log_err()
145 .is_some_and(|s| s.is_some())
146 }
147}
148
149pub fn should_show_upsell_modal() -> bool {
150 !ZedPredictUpsell::dismissed()
151}
152
/// Newtype around the `Entity<Zeta>` stored as a gpui global. It is set by `Zeta::register`
/// and read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
157
/// A prediction returned by the model, together with the inputs that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Id of the request that produced this prediction.
    id: EditPredictionId,
    /// Path of the file the prediction applies to.
    path: Arc<Path>,
    /// Editable range (as offsets) that was sent to the model.
    excerpt_range: Range<usize>,
    /// Offset of the cursor when the prediction was requested.
    cursor_offset: usize,
    /// Predicted edits, as anchor ranges paired with their replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` are expressed against.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
    /// Events prompt that was sent with the request.
    input_events: Arc<str>,
    /// Excerpt that was sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// When the buffer snapshot for the request was taken.
    buffer_snapshotted_at: Instant,
    /// When the prediction finished being built from the response.
    response_received_at: Instant,
}
173
174impl EditPrediction {
175 fn latency(&self) -> Duration {
176 self.response_received_at
177 .duration_since(self.buffer_snapshotted_at)
178 }
179
180 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
181 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
182 }
183}
184
185fn interpolate(
186 old_snapshot: &BufferSnapshot,
187 new_snapshot: &BufferSnapshot,
188 current_edits: Arc<[(Range<Anchor>, String)]>,
189) -> Option<Vec<(Range<Anchor>, String)>> {
190 let mut edits = Vec::new();
191
192 let mut model_edits = current_edits.iter().peekable();
193 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
194 while let Some((model_old_range, _)) = model_edits.peek() {
195 let model_old_range = model_old_range.to_offset(old_snapshot);
196 if model_old_range.end < user_edit.old.start {
197 let (model_old_range, model_new_text) = model_edits.next().unwrap();
198 edits.push((model_old_range.clone(), model_new_text.clone()));
199 } else {
200 break;
201 }
202 }
203
204 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
205 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
206 if user_edit.old == model_old_offset_range {
207 let user_new_text = new_snapshot
208 .text_for_range(user_edit.new.clone())
209 .collect::<String>();
210
211 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
212 if !model_suffix.is_empty() {
213 let anchor = old_snapshot.anchor_after(user_edit.old.end);
214 edits.push((anchor..anchor, model_suffix.to_string()));
215 }
216
217 model_edits.next();
218 continue;
219 }
220 }
221 }
222
223 return None;
224 }
225
226 edits.extend(model_edits.cloned());
227
228 if edits.is_empty() { None } else { Some(edits) }
229}
230
231impl std::fmt::Debug for EditPrediction {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 f.debug_struct("EditPrediction")
234 .field("id", &self.id)
235 .field("path", &self.path)
236 .field("edits", &self.edits)
237 .finish_non_exhaustive()
238 }
239}
240
/// State of the edit prediction service, shared through the `ZetaGlobal` gpui global.
pub struct Zeta {
    /// Client used to reach the prediction endpoints and to acquire or refresh the LLM token.
    client: Arc<Client>,
    /// Predictions kept for feedback; created with capacity `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions. NOTE(review): presumably those the user has already rated — confirm.
    rated_completions: HashSet<EditPredictionId>,
    /// The data collection choice, loaded when `Zeta` is created.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token sent as the bearer credential on prediction requests.
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Store that receives edit prediction usage updates.
    user_store: Entity<UserStore>,
    /// One license detection watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Uploader for training data, created lazily on the first request that may collect data.
    /// This should be the only place where the `Entity<TrainingDataUploader>` is stored.
    training_data_uploader: Option<Entity<TrainingDataUploader>>,
}
256
/// State that `Zeta` keeps for a single project.
struct ZetaProject {
    /// Recent buffer-change events; capped at `MAX_EVENT_COUNT` by `push_event`.
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Bookkeeping for editors. NOTE(review): presumably the recently active ones — confirm.
    recent_editors: VecDeque<RecentEditor>,
    /// Most recently recorded activity state, if any.
    last_activity_state: Option<ActivityState>,
    /// Task for activity polling; it is currently always initialised to `None`.
    _activity_poll_task: Option<Task<Result<()>>>,
}
264
/// Usage bookkeeping for one editor, stored in `ZetaProject::recent_editors`.
struct RecentEditor {
    /// Weak handle, so this record does not keep the editor alive.
    editor: WeakEntity<Editor>,
    last_active_at: Instant,
    activation_count: u32,
    cumulative_time_editing: Duration,
    cumulative_time_navigating: Duration,
}
272
/// Editor state (scroll position, cursor point and an optional buffer version) kept in
/// `ZetaProject::last_activity_state`.
// NOTE(review): presumably compared between activity polls to tell editing from navigating;
// the polling code is currently disabled — confirm once it is restored.
#[derive(Debug)]
struct ActivityState {
    scroll_position: gpui::Point<f32>,
    cursor_point: MultiBufferPoint,
    singleton_version: Option<clock::Global>,
}
279
280impl Zeta {
281 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
282 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
283 }
284
285 pub fn register(
286 worktree: Option<Entity<Worktree>>,
287 client: Arc<Client>,
288 user_store: Entity<UserStore>,
289 cx: &mut App,
290 ) -> Entity<Self> {
291 let this = Self::global(cx).unwrap_or_else(|| {
292 let entity = cx.new(|cx| Self::new(client, user_store, cx));
293 cx.set_global(ZetaGlobal(entity.clone()));
294 entity
295 });
296
297 this.update(cx, move |this, cx| {
298 if let Some(worktree) = worktree {
299 let worktree_id = worktree.read(cx).id();
300 this.license_detection_watchers
301 .entry(worktree_id)
302 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
303 }
304 });
305
306 this
307 }
308
309 pub fn clear_history(&mut self) {
310 for zeta_project in self.projects.values_mut() {
311 zeta_project.events.clear();
312 }
313 }
314
315 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
316 self.user_store.read(cx).edit_prediction_usage()
317 }
318
319 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
320 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
321
322 let data_collection_choice = Self::load_data_collection_choices();
323 let data_collection_choice = cx.new(|_| data_collection_choice);
324
325 /* todo!
326 let mut activity_poll_task = None;
327
328 if let Some(workspace) = &workspace {
329 let project = workspace.read(cx).project().clone();
330 cx.subscribe(&project, |this, _project, event, cx| match event {
331 project::Event::ActiveEntryChanged(entry_id) => {
332 this.handle_active_project_entry_changed(cx)
333 }
334 _ => {}
335 })
336 .detach();
337
338 // TODO: ideally this would attend to window focus when tracking time, and pause the
339 // loop for efficiency when not focused.
340 activity_poll_task = Some(cx.spawn(async move |this, cx| {
341 let mut instant_before_delay = None;
342 loop {
343 let data_collection_is_enabled = this.read_with(cx, |this, cx| {
344 this.data_collection_choice.read(cx).is_enabled()
345 })?;
346 let interval = if data_collection_is_enabled {
347 ACTIVITY_POLL_INTERVAL
348 } else {
349 instant_before_delay = None;
350 DISABLED_ACTIVITY_POLL_INTERVAL
351 };
352 cx.background_executor().timer(interval).await;
353 this.update(cx, |this, cx| {
354 let now = Instant::now();
355 this.handle_activity_poll(instant_before_delay, now, cx);
356 instant_before_delay = Some(now);
357 })?
358 }
359 }));
360 }
361 */
362
363 Self {
364 client,
365 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
366 rated_completions: HashSet::default(),
367 data_collection_choice,
368 llm_token: LlmApiToken::default(),
369 _llm_token_subscription: cx.subscribe(
370 &refresh_llm_token_listener,
371 |this, _listener, _event, cx| {
372 let client = this.client.clone();
373 let llm_token = this.llm_token.clone();
374 cx.spawn(async move |_this, _cx| {
375 llm_token.refresh(&client).await?;
376 anyhow::Ok(())
377 })
378 .detach_and_log_err(cx);
379 },
380 ),
381 update_required: false,
382 license_detection_watchers: HashMap::default(),
383 user_store,
384 projects: HashMap::default(),
385 training_data_uploader: None,
386 }
387 }
388
389 fn get_mut_or_init_zeta_project(
390 &mut self,
391 project: &Entity<Project>,
392 cx: &mut Context<Self>,
393 ) -> &mut ZetaProject {
394 let project_id = project.entity_id();
395 match self.projects.entry(project_id) {
396 hash_map::Entry::Occupied(entry) => entry.into_mut(),
397 hash_map::Entry::Vacant(entry) => {
398 cx.observe_release(project, move |this, _, _cx| {
399 this.projects.remove(&project_id);
400 });
401 entry.insert(ZetaProject {
402 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
403 registered_buffers: HashMap::default(),
404 recent_editors: VecDeque::new(),
405 last_activity_state: None,
406 // todo!
407 _activity_poll_task: None,
408 })
409 }
410 }
411 }
412
413 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
414 let events = &mut zeta_project.events;
415
416 if let Some(Event::BufferChange {
417 new_snapshot: last_new_snapshot,
418 timestamp: last_timestamp,
419 ..
420 }) = events.back_mut()
421 {
422 // Coalesce edits for the same buffer when they happen one after the other.
423 let Event::BufferChange {
424 old_snapshot,
425 new_snapshot,
426 timestamp,
427 } = &event;
428
429 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
430 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
431 && old_snapshot.version == last_new_snapshot.version
432 {
433 *last_new_snapshot = new_snapshot.clone();
434 *last_timestamp = *timestamp;
435 return;
436 }
437 }
438
439 if events.len() >= MAX_EVENT_COUNT {
440 // These are halved instead of popping to improve prompt caching.
441 events.drain(..MAX_EVENT_COUNT / 2);
442 }
443
444 events.push_back(event);
445 }
446
447 pub fn register_buffer(
448 &mut self,
449 buffer: &Entity<Buffer>,
450 project: &Entity<Project>,
451 cx: &mut Context<Self>,
452 ) {
453 let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
454 Self::register_buffer_impl(zeta_project, buffer, project, cx);
455 }
456
457 fn register_buffer_impl<'a>(
458 zeta_project: &'a mut ZetaProject,
459 buffer: &Entity<Buffer>,
460 project: &Entity<Project>,
461 cx: &mut Context<Self>,
462 ) -> &'a mut RegisteredBuffer {
463 let buffer_id = buffer.entity_id();
464 match zeta_project.registered_buffers.entry(buffer_id) {
465 hash_map::Entry::Occupied(entry) => entry.into_mut(),
466 hash_map::Entry::Vacant(entry) => {
467 let snapshot = buffer.read(cx).snapshot();
468 let project_entity_id = project.entity_id();
469 entry.insert(RegisteredBuffer {
470 snapshot,
471 _subscriptions: [
472 cx.subscribe(buffer, {
473 let project = project.downgrade();
474 move |this, buffer, event, cx| {
475 if let language::BufferEvent::Edited = event
476 && let Some(project) = project.upgrade()
477 {
478 this.report_changes_for_buffer(&buffer, &project, cx);
479 }
480 }
481 }),
482 cx.observe_release(buffer, move |this, _buffer, _cx| {
483 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
484 else {
485 return;
486 };
487 zeta_project.registered_buffers.remove(&buffer_id);
488 }),
489 ],
490 })
491 }
492 }
493 }
494
495 fn request_completion_impl<F, R>(
496 &mut self,
497 project: &Entity<Project>,
498 buffer: &Entity<Buffer>,
499 cursor: language::Anchor,
500 can_collect_data: CanCollectData,
501 cx: &mut Context<Self>,
502 perform_predict_edits: F,
503 ) -> Task<Result<Option<EditPrediction>>>
504 where
505 F: FnOnce(PerformPredictEditsParams) -> R + 'static,
506 R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
507 + Send
508 + 'static,
509 {
510 let buffer = buffer.clone();
511 let buffer_snapshotted_at = Instant::now();
512 let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
513 let zeta = cx.entity();
514 let events = self
515 .get_mut_or_init_zeta_project(project, cx)
516 .events
517 .clone();
518 let client = self.client.clone();
519 let llm_token = self.llm_token.clone();
520 let app_version = AppVersion::global(cx);
521
522 let full_path: Arc<Path> = snapshot
523 .file()
524 .map(|f| Arc::from(f.full_path(cx).as_path()))
525 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
526 let full_path_str = full_path.to_string_lossy().to_string();
527 let cursor_point = cursor.to_point(&snapshot);
528 let cursor_offset = cursor_point.to_offset(&snapshot);
529 let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
530 let gather_task = cx.background_spawn(gather_context(
531 full_path_str,
532 snapshot.clone(),
533 cursor_point,
534 make_events_prompt,
535 can_collect_data,
536 ));
537
538 // todo! async
539 if matches!(can_collect_data, CanCollectData(true)) {
540 let training_data_uploader = match &self.training_data_uploader {
541 None => {
542 let training_data_uploader = cx.new(|cx| TrainingDataUploader::new(cx));
543 self.training_data_uploader = Some(training_data_uploader.clone());
544 &training_data_uploader
545 }
546 Some(training_data_uploader) => training_data_uploader,
547 };
548
549 training_data_uploader.update(cx, |training_data_uploader, cx| {
550 let project = project.read(cx);
551 let entry = project.entry_for_path(&project_path, cx)?;
552 if !worktree_entry_is_eligible_for_collection(&entry) {
553 return None;
554 }
555
556 let git_store = project.git_store().read(cx);
557 let (repository, repo_path) =
558 git_store.repository_and_path_for_project_path(&project_path, cx)?;
559 let repo_path_string = repo_path.to_str()?.to_string();
560 });
561 }
562
563 cx.spawn(async move |this, cx| {
564 let GatherContextOutput {
565 body,
566 editable_range,
567 } = gather_task.await?;
568 let done_gathering_context_at = Instant::now();
569
570 let additional_context_task: Option<Task<PredictEditsAdditionalContext>> = None;
571 /* todo!
572 let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
573 && let Some(file) = snapshot.file()
574 && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
575 {
576 // This is async to reduce latency of the edit predictions request. The downside is
577 // that it will see a slightly later state than was used when gathering context.
578 let snapshot = snapshot.clone();
579 let this = this.clone();
580 Some(cx.spawn(async move |cx| {
581 if let Ok(Some(task)) = this.update(cx, |this, cx| {
582 this.gather_additional_context(
583 cursor_point,
584 cursor_offset,
585 snapshot,
586 &buffer_snapshotted_at,
587 project_path,
588 &project,
589 cx,
590 )
591 }) {
592 Some(task.await)
593 } else {
594 None
595 }
596 }))
597 } else {
598 None
599 };
600 */
601
602 log::debug!(
603 "Events:\n{}\nExcerpt:\n{:?}",
604 body.input_events,
605 body.input_excerpt
606 );
607
608 let input_events = body.input_events.clone();
609 let input_excerpt = body.input_excerpt.clone();
610
611 let response = perform_predict_edits(PerformPredictEditsParams {
612 client,
613 llm_token,
614 app_version,
615 body,
616 })
617 .await;
618 let (response, usage) = match response {
619 Ok(response) => response,
620 Err(err) => {
621 if err.is::<ZedUpdateRequiredError>() {
622 cx.update(|cx| {
623 zeta.update(cx, |zeta, _cx| {
624 zeta.update_required = true;
625 });
626
627 let error_message: SharedString = err.to_string().into();
628 show_app_notification(
629 NotificationId::unique::<ZedUpdateRequiredError>(),
630 cx,
631 move |cx| {
632 cx.new(|cx| {
633 ErrorMessagePrompt::new(error_message.clone(), cx)
634 .with_link_button(
635 "Update Zed",
636 "https://zed.dev/releases",
637 )
638 })
639 },
640 );
641 })
642 .ok();
643 }
644
645 return Err(err);
646 }
647 };
648
649 let received_response_at = Instant::now();
650 log::debug!("completion response: {}", &response.output_excerpt);
651
652 if let Some(usage) = usage {
653 this.update(cx, |this, cx| {
654 this.user_store.update(cx, |user_store, cx| {
655 user_store.update_edit_prediction_usage(usage, cx);
656 });
657 })
658 .ok();
659 }
660
661 let request_id = response.request_id.clone();
662 let edit_prediction = Self::process_completion_response(
663 response,
664 buffer,
665 &snapshot,
666 editable_range,
667 cursor_offset,
668 full_path,
669 input_events,
670 input_excerpt,
671 buffer_snapshotted_at,
672 cx,
673 )
674 .await;
675
676 let finished_at = Instant::now();
677
678 // record latency for ~1% of requests
679 if rand::random::<u8>() <= 2 {
680 telemetry::event!(
681 "Edit Prediction Request",
682 context_latency = done_gathering_context_at
683 .duration_since(buffer_snapshotted_at)
684 .as_millis(),
685 request_latency = received_response_at
686 .duration_since(done_gathering_context_at)
687 .as_millis(),
688 process_latency = finished_at.duration_since(received_response_at).as_millis()
689 );
690 }
691
692 /* todo!
693 if let Some(additional_context_task) = additional_context_task {
694 cx.background_spawn(async move {
695 if let Some(additional_context) = additional_context_task.await {
696 telemetry::event!(
697 "Edit Prediction Additional Context",
698 request_id = request_id,
699 additional_context = additional_context
700 );
701 }
702 })
703 .detach();
704 }
705 */
706
707 edit_prediction
708 })
709 }
710
    /// Generates several example completions of various states to fill the Zeta completion modal.
    ///
    /// The implementation is currently disabled and calling this panics via `todo!()`. The
    /// disabled code passes `None` where `fake_completion` now takes a `&Entity<Project>`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        /*
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            for task in completion_tasks {
                task.await.unwrap();
            }

            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
        */
        todo!()
    }
865
866 #[cfg(any(test, feature = "test-support"))]
867 pub fn fake_completion(
868 &mut self,
869 project: &Entity<Project>,
870 buffer: &Entity<Buffer>,
871 position: language::Anchor,
872 response: PredictEditsResponse,
873 cx: &mut Context<Self>,
874 ) -> Task<Result<Option<EditPrediction>>> {
875 use std::future::ready;
876
877 self.request_completion_impl(
878 project,
879 buffer,
880 position,
881 CanCollectData(false),
882 cx,
883 |_params| ready(Ok((response, None))),
884 )
885 }
886
887 pub fn request_completion(
888 &mut self,
889 project: &Entity<Project>,
890 buffer: &Entity<Buffer>,
891 position: language::Anchor,
892 can_collect_data: CanCollectData,
893 cx: &mut Context<Self>,
894 ) -> Task<Result<Option<EditPrediction>>> {
895 self.request_completion_impl(
896 project,
897 buffer,
898 position,
899 can_collect_data,
900 cx,
901 Self::perform_predict_edits,
902 )
903 }
904
905 pub fn perform_predict_edits(
906 params: PerformPredictEditsParams,
907 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
908 async move {
909 let PerformPredictEditsParams {
910 client,
911 llm_token,
912 app_version,
913 body,
914 ..
915 } = params;
916
917 let http_client = client.http_client();
918 let mut token = llm_token.acquire(&client).await?;
919 let mut did_retry = false;
920
921 loop {
922 let request_builder = http_client::Request::builder().method(Method::POST);
923 let request_builder =
924 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
925 request_builder.uri(predict_edits_url)
926 } else {
927 request_builder.uri(
928 http_client
929 .build_zed_llm_url("/predict_edits/v2", &[])?
930 .as_ref(),
931 )
932 };
933 let request = request_builder
934 .header("Content-Type", "application/json")
935 .header("Authorization", format!("Bearer {}", token))
936 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
937 .body(serde_json::to_string(&body)?.into())?;
938
939 let mut response = http_client.send(request).await?;
940
941 if let Some(minimum_required_version) = response
942 .headers()
943 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
944 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
945 {
946 anyhow::ensure!(
947 app_version >= minimum_required_version,
948 ZedUpdateRequiredError {
949 minimum_version: minimum_required_version
950 }
951 );
952 }
953
954 if response.status().is_success() {
955 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
956
957 let mut body = String::new();
958 response.body_mut().read_to_string(&mut body).await?;
959 return Ok((serde_json::from_str(&body)?, usage));
960 } else if !did_retry
961 && response
962 .headers()
963 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
964 .is_some()
965 {
966 did_retry = true;
967 token = llm_token.refresh(&client).await?;
968 } else {
969 let mut body = String::new();
970 response.body_mut().read_to_string(&mut body).await?;
971 anyhow::bail!(
972 "error predicting edits.\nStatus: {:?}\nBody: {}",
973 response.status(),
974 body
975 );
976 }
977 }
978 }
979 }
980
981 fn accept_edit_prediction(
982 &mut self,
983 request_id: EditPredictionId,
984 cx: &mut Context<Self>,
985 ) -> Task<Result<()>> {
986 let client = self.client.clone();
987 let llm_token = self.llm_token.clone();
988 let app_version = AppVersion::global(cx);
989 cx.spawn(async move |this, cx| {
990 let http_client = client.http_client();
991 let mut response = llm_token_retry(&llm_token, &client, |token| {
992 let request_builder = http_client::Request::builder().method(Method::POST);
993 let request_builder =
994 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
995 request_builder.uri(accept_prediction_url)
996 } else {
997 request_builder.uri(
998 http_client
999 .build_zed_llm_url("/predict_edits/accept", &[])?
1000 .as_ref(),
1001 )
1002 };
1003 Ok(request_builder
1004 .header("Content-Type", "application/json")
1005 .header("Authorization", format!("Bearer {}", token))
1006 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
1007 .body(
1008 serde_json::to_string(&AcceptEditPredictionBody {
1009 request_id: request_id.0,
1010 })?
1011 .into(),
1012 )?)
1013 })
1014 .await?;
1015
1016 if let Some(minimum_required_version) = response
1017 .headers()
1018 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1019 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1020 && app_version < minimum_required_version
1021 {
1022 return Err(anyhow!(ZedUpdateRequiredError {
1023 minimum_version: minimum_required_version
1024 }));
1025 }
1026
1027 if response.status().is_success() {
1028 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
1029 this.update(cx, |this, cx| {
1030 this.user_store.update(cx, |user_store, cx| {
1031 user_store.update_edit_prediction_usage(usage, cx);
1032 });
1033 })?;
1034 }
1035
1036 Ok(())
1037 } else {
1038 let mut body = String::new();
1039 response.body_mut().read_to_string(&mut body).await?;
1040 Err(anyhow!(
1041 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
1042 response.status(),
1043 body
1044 ))
1045 }
1046 })
1047 }
1048
1049 fn process_completion_response(
1050 prediction_response: PredictEditsResponse,
1051 buffer: Entity<Buffer>,
1052 snapshot: &BufferSnapshot,
1053 editable_range: Range<usize>,
1054 cursor_offset: usize,
1055 path: Arc<Path>,
1056 input_events: String,
1057 input_excerpt: String,
1058 buffer_snapshotted_at: Instant,
1059 cx: &AsyncApp,
1060 ) -> Task<Result<Option<EditPrediction>>> {
1061 let snapshot = snapshot.clone();
1062 let request_id = prediction_response.request_id;
1063 let output_excerpt = prediction_response.output_excerpt;
1064 cx.spawn(async move |cx| {
1065 let output_excerpt: Arc<str> = output_excerpt.into();
1066
1067 let edits: Arc<[(Range<Anchor>, String)]> = cx
1068 .background_spawn({
1069 let output_excerpt = output_excerpt.clone();
1070 let editable_range = editable_range.clone();
1071 let snapshot = snapshot.clone();
1072 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
1073 })
1074 .await?
1075 .into();
1076
1077 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
1078 let edits = edits.clone();
1079 |buffer, cx| {
1080 let new_snapshot = buffer.snapshot();
1081 let edits: Arc<[(Range<Anchor>, String)]> =
1082 interpolate(&snapshot, &new_snapshot, edits)?.into();
1083 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
1084 }
1085 })?
1086 else {
1087 return anyhow::Ok(None);
1088 };
1089
1090 let edit_preview = edit_preview.await;
1091
1092 Ok(Some(EditPrediction {
1093 id: EditPredictionId(request_id),
1094 path,
1095 excerpt_range: editable_range,
1096 cursor_offset,
1097 edits,
1098 edit_preview,
1099 snapshot,
1100 input_events: input_events.into(),
1101 input_excerpt: input_excerpt.into(),
1102 output_excerpt,
1103 buffer_snapshotted_at,
1104 response_received_at: Instant::now(),
1105 }))
1106 })
1107 }
1108
1109 fn parse_edits(
1110 output_excerpt: Arc<str>,
1111 editable_range: Range<usize>,
1112 snapshot: &BufferSnapshot,
1113 ) -> Result<Vec<(Range<Anchor>, String)>> {
1114 let content = output_excerpt.replace(CURSOR_MARKER, "");
1115
1116 let start_markers = content
1117 .match_indices(EDITABLE_REGION_START_MARKER)
1118 .collect::<Vec<_>>();
1119 anyhow::ensure!(
1120 start_markers.len() == 1,
1121 "expected exactly one start marker, found {}",
1122 start_markers.len()
1123 );
1124
1125 let end_markers = content
1126 .match_indices(EDITABLE_REGION_END_MARKER)
1127 .collect::<Vec<_>>();
1128 anyhow::ensure!(
1129 end_markers.len() == 1,
1130 "expected exactly one end marker, found {}",
1131 end_markers.len()
1132 );
1133
1134 let sof_markers = content
1135 .match_indices(START_OF_FILE_MARKER)
1136 .collect::<Vec<_>>();
1137 anyhow::ensure!(
1138 sof_markers.len() <= 1,
1139 "expected at most one start-of-file marker, found {}",
1140 sof_markers.len()
1141 );
1142
1143 let codefence_start = start_markers[0].0;
1144 let content = &content[codefence_start..];
1145
1146 let newline_ix = content.find('\n').context("could not find newline")?;
1147 let content = &content[newline_ix + 1..];
1148
1149 let codefence_end = content
1150 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1151 .context("could not find end marker")?;
1152 let new_text = &content[..codefence_end];
1153
1154 let old_text = snapshot
1155 .text_for_range(editable_range.clone())
1156 .collect::<String>();
1157
1158 Ok(Self::compute_edits(
1159 old_text,
1160 new_text,
1161 editable_range.start,
1162 snapshot,
1163 ))
1164 }
1165
1166 pub fn compute_edits(
1167 old_text: String,
1168 new_text: &str,
1169 offset: usize,
1170 snapshot: &BufferSnapshot,
1171 ) -> Vec<(Range<Anchor>, String)> {
1172 text_diff(&old_text, new_text)
1173 .into_iter()
1174 .map(|(mut old_range, new_text)| {
1175 old_range.start += offset;
1176 old_range.end += offset;
1177
1178 let prefix_len = common_prefix(
1179 snapshot.chars_for_range(old_range.clone()),
1180 new_text.chars(),
1181 );
1182 old_range.start += prefix_len;
1183
1184 let suffix_len = common_prefix(
1185 snapshot.reversed_chars_for_range(old_range.clone()),
1186 new_text[prefix_len..].chars().rev(),
1187 );
1188 old_range.end = old_range.end.saturating_sub(suffix_len);
1189
1190 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1191 let range = if old_range.is_empty() {
1192 let anchor = snapshot.anchor_after(old_range.start);
1193 anchor..anchor
1194 } else {
1195 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1196 };
1197 (range, new_text)
1198 })
1199 .collect()
1200 }
1201
1202 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1203 self.rated_completions.contains(&completion_id)
1204 }
1205
1206 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1207 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1208 let completion = self.shown_completions.pop_back().unwrap();
1209 self.rated_completions.remove(&completion.id);
1210 }
1211 self.shown_completions.push_front(completion.clone());
1212 cx.notify();
1213 }
1214
1215 pub fn rate_completion(
1216 &mut self,
1217 completion: &EditPrediction,
1218 rating: EditPredictionRating,
1219 feedback: String,
1220 cx: &mut Context<Self>,
1221 ) {
1222 self.rated_completions.insert(completion.id);
1223 telemetry::event!(
1224 "Edit Prediction Rated",
1225 rating,
1226 input_events = completion.input_events,
1227 input_excerpt = completion.input_excerpt,
1228 output_excerpt = completion.output_excerpt,
1229 feedback
1230 );
1231 self.client.telemetry().flush_events().detach();
1232 cx.notify();
1233 }
1234
1235 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1236 self.shown_completions.iter()
1237 }
1238
1239 pub fn shown_completions_len(&self) -> usize {
1240 self.shown_completions.len()
1241 }
1242
1243 fn report_changes_for_buffer(
1244 &mut self,
1245 buffer: &Entity<Buffer>,
1246 project: &Entity<Project>,
1247 cx: &mut Context<Self>,
1248 ) -> BufferSnapshot {
1249 let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
1250 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
1251
1252 let new_snapshot = buffer.read(cx).snapshot();
1253 if new_snapshot.version != registered_buffer.snapshot.version {
1254 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1255 Self::push_event(
1256 zeta_project,
1257 Event::BufferChange {
1258 old_snapshot,
1259 new_snapshot: new_snapshot.clone(),
1260 timestamp: Instant::now(),
1261 },
1262 );
1263 }
1264
1265 new_snapshot
1266 }
1267
1268 fn load_data_collection_choices() -> DataCollectionChoice {
1269 let choice = KEY_VALUE_STORE
1270 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1271 .log_err()
1272 .flatten();
1273
1274 match choice.as_deref() {
1275 Some("true") => DataCollectionChoice::Enabled,
1276 Some("false") => DataCollectionChoice::Disabled,
1277 Some(_) => {
1278 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1279 DataCollectionChoice::NotAnswered
1280 }
1281 None => DataCollectionChoice::NotAnswered,
1282 }
1283 }
1284
1285 /*
1286 fn gather_additional_context(
1287 &mut self,
1288 cursor_point: language::Point,
1289 cursor_offset: usize,
1290 snapshot: BufferSnapshot,
1291 buffer_snapshotted_at: &Instant,
1292 project_path: ProjectPath,
1293 project: &WeakEntity<Project>,
1294 cx: &mut Context<Self>,
1295 ) -> Option<Task<PredictEditsAdditionalContext>> {
1296 let project = project.upgrade()?.read(cx);
1297 let entry = project.entry_for_path(&project_path, cx)?;
1298 if !worktree_entry_is_eligible_for_collection(&entry) {
1299 return None;
1300 }
1301
1302 let git_store = project.git_store().read(cx);
1303 let (repository, repo_path) =
1304 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1305 let repo_path_string = repo_path.to_str()?.to_string();
1306
1307 let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
1308 snapshot
1309 .diagnostics
1310 .iter()
1311 .filter_map(|(language_server_id, diagnostics)| {
1312 let language_server =
1313 local_lsp_store.running_language_server_for_id(*language_server_id)?;
1314 Some((
1315 *language_server_id,
1316 language_server.name(),
1317 diagnostics.clone(),
1318 ))
1319 })
1320 .collect()
1321 } else {
1322 Vec::new()
1323 };
1324
1325 repository.update(cx, |repository, cx| {
1326 let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1327 let remote_origin_url = repository.remote_origin_url.clone();
1328 let remote_upstream_url = repository.remote_upstream_url.clone();
1329 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1330
1331 // group, resolve, and select diagnostics on a background thread
1332 Some(cx.background_spawn(async move {
1333 let mut diagnostic_groups_with_name = Vec::new();
1334 for (language_server_id, language_server_name, diagnostics) in
1335 diagnostics.into_iter()
1336 {
1337 let mut groups = Vec::new();
1338 diagnostics.groups(language_server_id, &mut groups, &snapshot);
1339 diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
1340 (
1341 language_server_name.clone(),
1342 group.resolve::<usize>(&snapshot),
1343 )
1344 }));
1345 }
1346
1347 // sort by proximity to cursor
1348 diagnostic_groups_with_name.sort_by_key(|(_, group)| {
1349 let range = &group.entries[group.primary_ix].range;
1350 if range.start >= cursor_offset {
1351 range.start - cursor_offset
1352 } else if cursor_offset >= range.end {
1353 cursor_offset - range.end
1354 } else {
1355 (cursor_offset - range.start).min(range.end - cursor_offset)
1356 }
1357 });
1358
1359 let mut diagnostic_groups = Vec::new();
1360 let mut diagnostic_groups_truncated = false;
1361 let mut diagnostics_byte_count = 0;
1362 for (name, group) in diagnostic_groups_with_name {
1363 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1364 diagnostics_byte_count += name.0.len() + raw_value.get().len();
1365 if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
1366 diagnostic_groups_truncated = true;
1367 break;
1368 }
1369 diagnostic_groups.push((name.to_string(), raw_value));
1370 }
1371
1372 PredictEditsAdditionalContext {
1373 input_path: repo_path_string,
1374 cursor_point: to_cloud_llm_client_point(cursor_point),
1375 cursor_offset: cursor_offset,
1376 git_info: PredictEditsGitInfo {
1377 head_sha: Some(head_sha),
1378 remote_origin_url,
1379 remote_upstream_url,
1380 },
1381 diagnostic_groups,
1382 diagnostic_groups_truncated,
1383 recent_files,
1384 }
1385 }))
1386 })
1387 }
1388
1389 fn handle_active_project_entry_changed(&mut self, cx: &mut Context<Self>) {
1390 if !self.data_collection_choice.read(cx).is_enabled() {
1391 self.recent_editors.clear();
1392 self.last_activity_state = None;
1393 return;
1394 }
1395 if let Some(active_editor) = self
1396 .workspace
1397 .read_with(cx, |workspace, cx| {
1398 workspace
1399 .active_item(cx)
1400 .and_then(|item| item.act_as::<Editor>(cx))
1401 })
1402 .ok()
1403 .flatten()
1404 {
1405 let now = Instant::now();
1406 let editor = active_editor.downgrade();
1407 let existing_recent_editor = if let Some(existing_ix) = self
1408 .recent_editors
1409 .iter()
1410 .rposition(|recent| &recent.editor == &editor)
1411 {
1412 if existing_ix + 1 != self.recent_editors.len() {
1413 self.last_activity_state = None;
1414 }
1415 self.recent_editors.remove(existing_ix)
1416 } else {
1417 None
1418 };
1419 let new_recent = RecentEditor {
1420 editor: active_editor.downgrade(),
1421 last_active_at: now,
1422 activation_count: existing_recent_editor
1423 .as_ref()
1424 .map_or(0, |recent| recent.activation_count + 1),
1425 cumulative_time_navigating: existing_recent_editor
1426 .as_ref()
1427 .map_or(Duration::ZERO, |recent| recent.cumulative_time_navigating),
1428 cumulative_time_editing: existing_recent_editor
1429 .map_or(Duration::ZERO, |recent| recent.cumulative_time_editing),
1430 };
1431 // filter out rapid changes in active item, particularly since this can happen rapidly when
1432 // a workspace is loaded.
1433 if let Some(previous_recent) = self.recent_editors.back_mut()
1434 && previous_recent.activation_count == 1
1435 && now.duration_since(previous_recent.last_active_at)
1436 < MIN_TIME_BETWEEN_RECENT_FILES
1437 {
1438 *previous_recent = new_recent;
1439 return;
1440 }
1441 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1442 self.recent_editors.pop_front();
1443 }
1444 self.recent_editors.push_back(new_recent);
1445 }
1446 }
1447
1448 fn handle_activity_poll(
1449 &mut self,
1450 instant_before_delay: Option<Instant>,
1451 now: Instant,
1452 cx: &mut Context<Self>,
1453 ) {
1454 if !self.data_collection_choice.read(cx).is_enabled() {
1455 self.last_activity_state = None;
1456 return;
1457 }
1458 if let Some(recent_editor) = self.recent_editors.back()
1459 && let Some(editor) = recent_editor.editor.upgrade()
1460 {
1461 let (scroll_position, cursor_point, singleton_version) =
1462 editor.update(cx, |editor, cx| {
1463 let scroll_position = editor.scroll_position(cx);
1464 let cursor_point = editor.selections.newest(cx).head();
1465 let singleton_version = editor
1466 .buffer()
1467 .read(cx)
1468 .as_singleton()
1469 .map(|singleton_buffer| singleton_buffer.read(cx).version());
1470 (scroll_position, cursor_point, singleton_version)
1471 });
1472
1473 let navigated = if let Some(last_activity_state) = &self.last_activity_state {
1474 last_activity_state.scroll_position != scroll_position
1475 || last_activity_state.cursor_point != cursor_point
1476 } else {
1477 false
1478 };
1479
1480 let edited = if let Some(singleton_version) = &singleton_version
1481 && let Some(last_activity_state) = &self.last_activity_state
1482 && let Some(last_singleton_version) = &last_activity_state.singleton_version
1483 {
1484 singleton_version.changed_since(last_singleton_version)
1485 } else {
1486 false
1487 };
1488
1489 self.last_activity_state = Some(ActivityState {
1490 scroll_position,
1491 cursor_point,
1492 singleton_version,
1493 });
1494
1495 let prior_recent_editor = if self.recent_editors.len() > 1 {
1496 Some(&self.recent_editors[self.recent_editors.len() - 2])
1497 } else {
1498 None
1499 };
1500 let additional_time: Option<Duration> =
1501 instant_before_delay.map(|instant_before_delay| {
1502 now.duration_since(prior_recent_editor.map_or(
1503 instant_before_delay,
1504 |prior_recent_editor| {
1505 prior_recent_editor.last_active_at.max(instant_before_delay)
1506 },
1507 ))
1508 });
1509
1510 if let Some(additional_time) = additional_time {
1511 let recent_editor = self.recent_editors.back_mut().unwrap();
1512 if navigated {
1513 recent_editor.cumulative_time_navigating += additional_time;
1514 }
1515 if edited {
1516 recent_editor.cumulative_time_editing += additional_time;
1517 }
1518 }
1519 }
1520 }
1521
1522 fn recent_files(
1523 &mut self,
1524 now: &Instant,
1525 repository: &Repository,
1526 cx: &mut App,
1527 ) -> Vec<PredictEditsRecentFile> {
1528 let Ok(project) = self
1529 .workspace
1530 .read_with(cx, |workspace, _cx| workspace.project().clone())
1531 else {
1532 return Vec::new();
1533 };
1534 let mut results = Vec::with_capacity(self.recent_editors.len());
1535 for ix in (0..self.recent_editors.len()).rev() {
1536 let recent_editor = &self.recent_editors[ix];
1537 let keep_entry = recent_editor
1538 .editor
1539 .update(cx, |editor, cx| {
1540 maybe!({
1541 let cursor = editor.selections.newest::<MultiBufferPoint>(cx).head();
1542 let multibuffer = editor.buffer().read(cx);
1543 let (buffer, cursor_point, _) =
1544 multibuffer.point_to_buffer_point(cursor, cx)?;
1545 let file = buffer.read(cx).file()?;
1546 if !file_is_eligible_for_collection(file.as_ref()) {
1547 return None;
1548 }
1549 let project_path = ProjectPath {
1550 worktree_id: file.worktree_id(cx),
1551 path: file.path().clone(),
1552 };
1553 let entry = project.read(cx).entry_for_path(&project_path, cx)?;
1554 if !worktree_entry_is_eligible_for_collection(entry) {
1555 return None;
1556 }
1557 let Some(repo_path) =
1558 repository.project_path_to_repo_path(&project_path, cx)
1559 else {
1560 // entry not removed since later queries may involve other repositories
1561 return Some(());
1562 };
1563 // paths may not be valid UTF-8
1564 let repo_path_str = repo_path.to_str()?;
1565 if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
1566 return None;
1567 }
1568 let active_to_now_ms = now
1569 .duration_since(recent_editor.last_active_at)
1570 .as_millis()
1571 .try_into()
1572 .ok()?;
1573 let cumulative_time_editing_ms = recent_editor
1574 .cumulative_time_editing
1575 .as_millis()
1576 .try_into()
1577 .ok()?;
1578 let cumulative_time_navigating_ms = recent_editor
1579 .cumulative_time_navigating
1580 .as_millis()
1581 .try_into()
1582 .ok()?;
1583 results.push(PredictEditsRecentFile {
1584 path: repo_path_str.to_string(),
1585 cursor_point: to_cloud_llm_client_point(cursor_point),
1586 active_to_now_ms,
1587 activation_count: recent_editor.activation_count,
1588 cumulative_time_editing_ms,
1589 cumulative_time_navigating_ms,
1590 is_multibuffer: !multibuffer.is_singleton(),
1591 });
1592 Some(())
1593 })
1594 })
1595 .ok()
1596 .flatten();
1597 if keep_entry.is_none() {
1598 self.recent_editors.remove(ix);
1599 }
1600 }
1601 results
1602 }
1603 */
1604}
1605
1606fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1607 cloud_llm_client::Point {
1608 row: point.row,
1609 column: point.column,
1610 }
1611}
1612
1613fn file_is_eligible_for_collection(file: &dyn File) -> bool {
1614 file.is_local() && !file.is_private()
1615}
1616
1617fn worktree_entry_is_eligible_for_collection(entry: &worktree::Entry) -> bool {
1618 entry.is_file()
1619 && entry.is_created()
1620 && !entry.is_ignored
1621 && !entry.is_private
1622 && !entry.is_external
1623 && !entry.is_fifo
1624}
1625
1626pub struct PerformPredictEditsParams {
1627 pub client: Arc<Client>,
1628 pub llm_token: LlmApiToken,
1629 pub app_version: SemanticVersion,
1630 pub body: PredictEditsBody,
1631}
1632
1633#[derive(Error, Debug)]
1634#[error(
1635 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1636)]
1637pub struct ZedUpdateRequiredError {
1638 minimum_version: SemanticVersion,
1639}
1640
/// Returns the length, in UTF-8 bytes, of the longest run of leading characters that
/// `a` and `b` have in common.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1647
1648pub struct GatherContextOutput {
1649 pub body: PredictEditsBody,
1650 pub editable_range: Range<usize>,
1651}
1652
1653pub async fn gather_context(
1654 full_path_str: String,
1655 snapshot: BufferSnapshot,
1656 cursor_point: language::Point,
1657 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1658 can_collect_data: CanCollectData,
1659) -> Result<GatherContextOutput> {
1660 let input_excerpt = excerpt_for_cursor_position(
1661 cursor_point,
1662 &full_path_str,
1663 &snapshot,
1664 MAX_REWRITE_TOKENS,
1665 MAX_CONTEXT_TOKENS,
1666 );
1667 let input_events = make_events_prompt();
1668 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1669
1670 let body = PredictEditsBody {
1671 input_events,
1672 input_excerpt: input_excerpt.prompt,
1673 can_collect_data: can_collect_data.0,
1674 diagnostic_groups: None,
1675 git_info: None,
1676 };
1677
1678 Ok(GatherContextOutput {
1679 body,
1680 editable_range,
1681 })
1682}
1683
1684fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1685 let mut result = String::new();
1686 for event in events.iter().rev() {
1687 let event_string = event.to_prompt();
1688 let event_tokens = tokens_for_bytes(event_string.len());
1689 if event_tokens > remaining_tokens {
1690 break;
1691 }
1692
1693 if !result.is_empty() {
1694 result.insert_str(0, "\n\n");
1695 }
1696 result.insert_str(0, &event_string);
1697 remaining_tokens -= event_tokens;
1698 }
1699 result
1700}
1701
1702struct RegisteredBuffer {
1703 snapshot: BufferSnapshot,
1704 _subscriptions: [gpui::Subscription; 2],
1705}
1706
1707#[derive(Clone)]
1708pub enum Event {
1709 BufferChange {
1710 old_snapshot: BufferSnapshot,
1711 new_snapshot: BufferSnapshot,
1712 timestamp: Instant,
1713 },
1714}
1715
1716impl Event {
1717 fn to_prompt(&self) -> String {
1718 match self {
1719 Event::BufferChange {
1720 old_snapshot,
1721 new_snapshot,
1722 ..
1723 } => {
1724 let mut prompt = String::new();
1725
1726 let old_path = old_snapshot
1727 .file()
1728 .map(|f| f.path().as_ref())
1729 .unwrap_or(Path::new("untitled"));
1730 let new_path = new_snapshot
1731 .file()
1732 .map(|f| f.path().as_ref())
1733 .unwrap_or(Path::new("untitled"));
1734 if old_path != new_path {
1735 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1736 }
1737
1738 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1739 if !diff.is_empty() {
1740 write!(
1741 prompt,
1742 "User edited {:?}:\n```diff\n{}\n```",
1743 new_path, diff
1744 )
1745 .unwrap();
1746 }
1747
1748 prompt
1749 }
1750 }
1751 }
1752}
1753
1754#[derive(Debug, Clone)]
1755struct CurrentEditPrediction {
1756 buffer_id: EntityId,
1757 completion: EditPrediction,
1758}
1759
1760impl CurrentEditPrediction {
1761 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1762 if self.buffer_id != old_completion.buffer_id {
1763 return true;
1764 }
1765
1766 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1767 return true;
1768 };
1769 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1770 return false;
1771 };
1772
1773 if old_edits.len() == 1 && new_edits.len() == 1 {
1774 let (old_range, old_text) = &old_edits[0];
1775 let (new_range, new_text) = &new_edits[0];
1776 new_range == old_range && new_text.starts_with(old_text)
1777 } else {
1778 true
1779 }
1780 }
1781}
1782
1783struct PendingCompletion {
1784 id: usize,
1785 _task: Task<()>,
1786}
1787
/// The user's answer to the data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// True only for `Enabled`.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True for `Enabled` and `Disabled`, false for `NotAnswered`.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; `NotAnswered` toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` becomes `Enabled`, `false` becomes `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1827
1828pub struct ProviderDataCollection {
1829 /// When set to None, data collection is not possible in the provider buffer
1830 choice: Option<Entity<DataCollectionChoice>>,
1831 license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1832}
1833
/// Whether data may be collected for a request; wraps the flag that is sent as
/// `can_collect_data` in the request body.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(
    /// `true` when collection is permitted.
    pub bool,
);
1836
1837impl ProviderDataCollection {
1838 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1839 let choice_and_watcher = buffer.and_then(|buffer| {
1840 let file = buffer.read(cx).file()?;
1841
1842 if !file_is_eligible_for_collection(file.as_ref()) {
1843 return None;
1844 }
1845
1846 let zeta = zeta.read(cx);
1847 let choice = zeta.data_collection_choice.clone();
1848
1849 let license_detection_watcher = zeta
1850 .license_detection_watchers
1851 .get(&file.worktree_id(cx))
1852 .cloned()?;
1853
1854 Some((choice, license_detection_watcher))
1855 });
1856
1857 if let Some((choice, watcher)) = choice_and_watcher {
1858 ProviderDataCollection {
1859 choice: Some(choice),
1860 license_detection_watcher: Some(watcher),
1861 }
1862 } else {
1863 ProviderDataCollection {
1864 choice: None,
1865 license_detection_watcher: None,
1866 }
1867 }
1868 }
1869
1870 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1871 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1872 }
1873
1874 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1875 self.choice
1876 .as_ref()
1877 .is_some_and(|choice| choice.read(cx).is_enabled())
1878 }
1879
1880 fn is_project_open_source(&self) -> bool {
1881 self.license_detection_watcher
1882 .as_ref()
1883 .is_some_and(|watcher| watcher.is_project_open_source())
1884 }
1885
1886 pub fn toggle(&mut self, cx: &mut App) {
1887 if let Some(choice) = self.choice.as_mut() {
1888 let new_choice = choice.update(cx, |choice, _cx| {
1889 let new_choice = choice.toggle();
1890 *choice = new_choice;
1891 new_choice
1892 });
1893
1894 db::write_and_log(cx, move || {
1895 KEY_VALUE_STORE.write_kvp(
1896 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1897 new_choice.is_enabled().to_string(),
1898 )
1899 });
1900 }
1901 }
1902}
1903
1904async fn llm_token_retry(
1905 llm_token: &LlmApiToken,
1906 client: &Arc<Client>,
1907 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1908) -> Result<Response<AsyncBody>> {
1909 let mut did_retry = false;
1910 let http_client = client.http_client();
1911 let mut token = llm_token.acquire(client).await?;
1912 loop {
1913 let request = build_request(token.clone())?;
1914 let response = http_client.send(request).await?;
1915
1916 if !did_retry
1917 && !response.status().is_success()
1918 && response
1919 .headers()
1920 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1921 .is_some()
1922 {
1923 did_retry = true;
1924 token = llm_token.refresh(client).await?;
1925 continue;
1926 }
1927
1928 return Ok(response);
1929 }
1930}
1931
1932pub struct ZetaEditPredictionProvider {
1933 zeta: Entity<Zeta>,
1934 pending_completions: ArrayVec<PendingCompletion, 2>,
1935 next_pending_completion_id: usize,
1936 current_completion: Option<CurrentEditPrediction>,
1937 /// None if this is entirely disabled for this provider
1938 provider_data_collection: ProviderDataCollection,
1939 last_request_timestamp: Instant,
1940}
1941
1942impl ZetaEditPredictionProvider {
1943 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1944
1945 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1946 Self {
1947 zeta,
1948 pending_completions: ArrayVec::new(),
1949 next_pending_completion_id: 0,
1950 current_completion: None,
1951 provider_data_collection,
1952 last_request_timestamp: Instant::now(),
1953 }
1954 }
1955}
1956
1957impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
/// Stable identifier of this provider.
fn name() -> &'static str {
    const NAME: &str = "zed-predict";
    NAME
}
1961
/// Human-readable name of this provider.
fn display_name() -> &'static str {
    const DISPLAY_NAME: &str = "Zed's Edit Predictions";
    DISPLAY_NAME
}
1965
/// This provider always opts in to showing its completions in the menu.
fn show_completions_in_menu() -> bool {
    const SHOW_IN_MENU: bool = true;
    SHOW_IN_MENU
}
1969
/// This provider always opts in to showing the tab-accept marker.
fn show_tab_accept_marker() -> bool {
    const SHOW_MARKER: bool = true;
    SHOW_MARKER
}
1973
1974 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1975 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1976
1977 if self.provider_data_collection.is_data_collection_enabled(cx) {
1978 DataCollectionState::Enabled {
1979 is_project_open_source,
1980 }
1981 } else {
1982 DataCollectionState::Disabled {
1983 is_project_open_source,
1984 }
1985 }
1986 }
1987
1988 fn toggle_data_collection(&mut self, cx: &mut App) {
1989 self.provider_data_collection.toggle(cx);
1990 }
1991
1992 fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1993 self.zeta.read(cx).usage(cx)
1994 }
1995
1996 fn is_enabled(
1997 &self,
1998 _buffer: &Entity<Buffer>,
1999 _cursor_position: language::Anchor,
2000 _cx: &App,
2001 ) -> bool {
2002 true
2003 }
2004 fn is_refreshing(&self) -> bool {
2005 !self.pending_completions.is_empty()
2006 }
2007
2008 fn refresh(
2009 &mut self,
2010 project: Option<Entity<Project>>,
2011 buffer: Entity<Buffer>,
2012 position: language::Anchor,
2013 _debounce: bool,
2014 cx: &mut Context<Self>,
2015 ) {
2016 if self.zeta.read(cx).update_required {
2017 return;
2018 }
2019 // todo! Don't require a project
2020 let Some(project) = project else {
2021 return;
2022 };
2023
2024 if self
2025 .zeta
2026 .read(cx)
2027 .user_store
2028 .read_with(cx, |user_store, _cx| {
2029 user_store.account_too_young() || user_store.has_overdue_invoices()
2030 })
2031 {
2032 return;
2033 }
2034
2035 if let Some(current_completion) = self.current_completion.as_ref() {
2036 let snapshot = buffer.read(cx).snapshot();
2037 if current_completion
2038 .completion
2039 .interpolate(&snapshot)
2040 .is_some()
2041 {
2042 return;
2043 }
2044 }
2045
2046 let pending_completion_id = self.next_pending_completion_id;
2047 self.next_pending_completion_id += 1;
2048 let can_collect_data = self.provider_data_collection.can_collect_data(cx);
2049 let last_request_timestamp = self.last_request_timestamp;
2050
2051 let task = cx.spawn(async move |this, cx| {
2052 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
2053 .checked_duration_since(Instant::now())
2054 {
2055 cx.background_executor().timer(timeout).await;
2056 }
2057
2058 let completion_request = this.update(cx, |this, cx| {
2059 this.last_request_timestamp = Instant::now();
2060 this.zeta.update(cx, |zeta, cx| {
2061 zeta.request_completion(&project, &buffer, position, can_collect_data, cx)
2062 })
2063 });
2064
2065 let completion = match completion_request {
2066 Ok(completion_request) => {
2067 let completion_request = completion_request.await;
2068 completion_request.map(|c| {
2069 c.map(|completion| CurrentEditPrediction {
2070 buffer_id: buffer.entity_id(),
2071 completion,
2072 })
2073 })
2074 }
2075 Err(error) => Err(error),
2076 };
2077 let Some(new_completion) = completion
2078 .context("edit prediction failed")
2079 .log_err()
2080 .flatten()
2081 else {
2082 this.update(cx, |this, cx| {
2083 if this.pending_completions[0].id == pending_completion_id {
2084 this.pending_completions.remove(0);
2085 } else {
2086 this.pending_completions.clear();
2087 }
2088
2089 cx.notify();
2090 })
2091 .ok();
2092 return;
2093 };
2094
2095 this.update(cx, |this, cx| {
2096 if this.pending_completions[0].id == pending_completion_id {
2097 this.pending_completions.remove(0);
2098 } else {
2099 this.pending_completions.clear();
2100 }
2101
2102 if let Some(old_completion) = this.current_completion.as_ref() {
2103 let snapshot = buffer.read(cx).snapshot();
2104 if new_completion.should_replace_completion(old_completion, &snapshot) {
2105 this.zeta.update(cx, |zeta, cx| {
2106 zeta.completion_shown(&new_completion.completion, cx);
2107 });
2108 this.current_completion = Some(new_completion);
2109 }
2110 } else {
2111 this.zeta.update(cx, |zeta, cx| {
2112 zeta.completion_shown(&new_completion.completion, cx);
2113 });
2114 this.current_completion = Some(new_completion);
2115 }
2116
2117 cx.notify();
2118 })
2119 .ok();
2120 });
2121
2122 // We always maintain at most two pending completions. When we already
2123 // have two, we replace the newest one.
2124 if self.pending_completions.len() <= 1 {
2125 self.pending_completions.push(PendingCompletion {
2126 id: pending_completion_id,
2127 _task: task,
2128 });
2129 } else if self.pending_completions.len() == 2 {
2130 self.pending_completions.pop();
2131 self.pending_completions.push(PendingCompletion {
2132 id: pending_completion_id,
2133 _task: task,
2134 });
2135 }
2136 }
2137
2138 fn cycle(
2139 &mut self,
2140 _buffer: Entity<Buffer>,
2141 _cursor_position: language::Anchor,
2142 _direction: edit_prediction::Direction,
2143 _cx: &mut Context<Self>,
2144 ) {
2145 // Right now we don't support cycling.
2146 }
2147
2148 fn accept(&mut self, cx: &mut Context<Self>) {
2149 let completion_id = self
2150 .current_completion
2151 .as_ref()
2152 .map(|completion| completion.completion.id);
2153 if let Some(completion_id) = completion_id {
2154 self.zeta
2155 .update(cx, |zeta, cx| {
2156 zeta.accept_edit_prediction(completion_id, cx)
2157 })
2158 .detach();
2159 }
2160 self.pending_completions.clear();
2161 }
2162
2163 fn discard(&mut self, _cx: &mut Context<Self>) {
2164 self.pending_completions.clear();
2165 self.current_completion.take();
2166 }
2167
2168 fn suggest(
2169 &mut self,
2170 buffer: &Entity<Buffer>,
2171 cursor_position: language::Anchor,
2172 cx: &mut Context<Self>,
2173 ) -> Option<edit_prediction::EditPrediction> {
2174 let CurrentEditPrediction {
2175 buffer_id,
2176 completion,
2177 ..
2178 } = self.current_completion.as_mut()?;
2179
2180 // Invalidate previous completion if it was generated for a different buffer.
2181 if *buffer_id != buffer.entity_id() {
2182 self.current_completion.take();
2183 return None;
2184 }
2185
2186 let buffer = buffer.read(cx);
2187 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
2188 self.current_completion.take();
2189 return None;
2190 };
2191
2192 let cursor_row = cursor_position.to_point(buffer).row;
2193 let (closest_edit_ix, (closest_edit_range, _)) =
2194 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
2195 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
2196 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
2197 cmp::min(distance_from_start, distance_from_end)
2198 })?;
2199
2200 let mut edit_start_ix = closest_edit_ix;
2201 for (range, _) in edits[..edit_start_ix].iter().rev() {
2202 let distance_from_closest_edit =
2203 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
2204 if distance_from_closest_edit <= 1 {
2205 edit_start_ix -= 1;
2206 } else {
2207 break;
2208 }
2209 }
2210
2211 let mut edit_end_ix = closest_edit_ix + 1;
2212 for (range, _) in &edits[edit_end_ix..] {
2213 let distance_from_closest_edit =
2214 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
2215 if distance_from_closest_edit <= 1 {
2216 edit_end_ix += 1;
2217 } else {
2218 break;
2219 }
2220 }
2221
2222 Some(edit_prediction::EditPrediction {
2223 id: Some(completion.id.to_string().into()),
2224 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
2225 edit_preview: Some(completion.edit_preview.clone()),
2226 })
2227 }
2228}
2229
/// Converts a byte count into an estimated token count by dividing by a fixed
/// bytes-per-token guess, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    // Guess at how many string bytes make up one token when limiting model
    // input. The value is intentionally low to err on the side of
    // underestimating limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
2236
2237/* todo!
2238#[cfg(test)]
2239mod tests {
2240 use client::UserStore;
2241 use client::test::FakeServer;
2242 use clock::FakeSystemClock;
2243 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
2244 use gpui::TestAppContext;
2245 use http_client::FakeHttpClient;
2246 use indoc::indoc;
2247 use language::Point;
2248 use settings::SettingsStore;
2249
2250 use super::*;
2251
2252 #[gpui::test]
2253 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
2254 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
2255 let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
2256 to_completion_edits(
2257 [(2..5, "REM".to_string()), (9..11, "".to_string())],
2258 &buffer,
2259 cx,
2260 )
2261 .into()
2262 });
2263
2264 let edit_preview = cx
2265 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
2266 .await;
2267
2268 let completion = EditPrediction {
2269 edits,
2270 edit_preview,
2271 path: Path::new("").into(),
2272 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
2273 id: EditPredictionId(Uuid::new_v4()),
2274 excerpt_range: 0..0,
2275 cursor_offset: 0,
2276 input_events: "".into(),
2277 input_excerpt: "".into(),
2278 output_excerpt: "".into(),
2279 buffer_snapshotted_at: Instant::now(),
2280 response_received_at: Instant::now(),
2281 };
2282
2283 cx.update(|cx| {
2284 assert_eq!(
2285 from_completion_edits(
2286 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2287 &buffer,
2288 cx
2289 ),
2290 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2291 );
2292
2293 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2294 assert_eq!(
2295 from_completion_edits(
2296 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2297 &buffer,
2298 cx
2299 ),
2300 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
2301 );
2302
2303 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2304 assert_eq!(
2305 from_completion_edits(
2306 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2307 &buffer,
2308 cx
2309 ),
2310 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2311 );
2312
2313 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2314 assert_eq!(
2315 from_completion_edits(
2316 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2317 &buffer,
2318 cx
2319 ),
2320 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2321 );
2322
2323 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2324 assert_eq!(
2325 from_completion_edits(
2326 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2327 &buffer,
2328 cx
2329 ),
2330 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2331 );
2332
2333 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2334 assert_eq!(
2335 from_completion_edits(
2336 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2337 &buffer,
2338 cx
2339 ),
2340 vec![(9..11, "".to_string())]
2341 );
2342
2343 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2344 assert_eq!(
2345 from_completion_edits(
2346 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2347 &buffer,
2348 cx
2349 ),
2350 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2351 );
2352
2353 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2354 assert_eq!(
2355 from_completion_edits(
2356 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2357 &buffer,
2358 cx
2359 ),
2360 vec![(4..4, "M".to_string())]
2361 );
2362
2363 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2364 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2365 })
2366 }
2367
2368 #[gpui::test]
2369 async fn test_clean_up_diff(cx: &mut TestAppContext) {
2370 cx.update(|cx| {
2371 let settings_store = SettingsStore::test(cx);
2372 cx.set_global(settings_store);
2373 client::init_settings(cx);
2374 });
2375
2376 let edits = edits_for_prediction(
2377 indoc! {"
2378 fn main() {
2379 let word_1 = \"lorem\";
2380 let range = word.len()..word.len();
2381 }
2382 "},
2383 indoc! {"
2384 <|editable_region_start|>
2385 fn main() {
2386 let word_1 = \"lorem\";
2387 let range = word_1.len()..word_1.len();
2388 }
2389
2390 <|editable_region_end|>
2391 "},
2392 cx,
2393 )
2394 .await;
2395 assert_eq!(
2396 edits,
2397 [
2398 (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2399 (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2400 ]
2401 );
2402
2403 let edits = edits_for_prediction(
2404 indoc! {"
2405 fn main() {
2406 let story = \"the quick\"
2407 }
2408 "},
2409 indoc! {"
2410 <|editable_region_start|>
2411 fn main() {
2412 let story = \"the quick brown fox jumps over the lazy dog\";
2413 }
2414
2415 <|editable_region_end|>
2416 "},
2417 cx,
2418 )
2419 .await;
2420 assert_eq!(
2421 edits,
2422 [
2423 (
2424 Point::new(1, 26)..Point::new(1, 26),
2425 " brown fox jumps over the lazy dog".to_string()
2426 ),
2427 (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2428 ]
2429 );
2430 }
2431
2432 #[gpui::test]
2433 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2434 cx.update(|cx| {
2435 let settings_store = SettingsStore::test(cx);
2436 cx.set_global(settings_store);
2437 client::init_settings(cx);
2438 });
2439
2440 let buffer_content = "lorem\n";
2441 let completion_response = indoc! {"
2442 ```animals.js
2443 <|start_of_file|>
2444 <|editable_region_start|>
2445 lorem
2446 ipsum
2447 <|editable_region_end|>
2448 ```"};
2449
2450 let http_client = FakeHttpClient::create(move |req| async move {
2451 match (req.method(), req.uri().path()) {
2452 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2453 .status(200)
2454 .body(
2455 serde_json::to_string(&CreateLlmTokenResponse {
2456 token: LlmToken("the-llm-token".to_string()),
2457 })
2458 .unwrap()
2459 .into(),
2460 )
2461 .unwrap()),
2462 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2463 .status(200)
2464 .body(
2465 serde_json::to_string(&PredictEditsResponse {
2466 request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2467 .unwrap(),
2468 output_excerpt: completion_response.to_string(),
2469 })
2470 .unwrap()
2471 .into(),
2472 )
2473 .unwrap()),
2474 _ => Ok(http_client::Response::builder()
2475 .status(404)
2476 .body("Not Found".into())
2477 .unwrap()),
2478 }
2479 });
2480
2481 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2482 cx.update(|cx| {
2483 RefreshLlmTokenListener::register(client.clone(), cx);
2484 });
2485 // Construct the fake server to authenticate.
2486 let _server = FakeServer::for_client(42, &client, cx).await;
2487 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2488 let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2489
2490 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2491 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2492 let completion_task = zeta.update(cx, |zeta, cx| {
2493 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2494 });
2495
2496 let completion = completion_task.await.unwrap().unwrap();
2497 buffer.update(cx, |buffer, cx| {
2498 buffer.edit(completion.edits.iter().cloned(), None, cx)
2499 });
2500 assert_eq!(
2501 buffer.read_with(cx, |buffer, _| buffer.text()),
2502 "lorem\nipsum"
2503 );
2504 }
2505
2506 async fn edits_for_prediction(
2507 buffer_content: &str,
2508 completion_response: &str,
2509 cx: &mut TestAppContext,
2510 ) -> Vec<(Range<Point>, String)> {
2511 let completion_response = completion_response.to_string();
2512 let http_client = FakeHttpClient::create(move |req| {
2513 let completion = completion_response.clone();
2514 async move {
2515 match (req.method(), req.uri().path()) {
2516 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2517 .status(200)
2518 .body(
2519 serde_json::to_string(&CreateLlmTokenResponse {
2520 token: LlmToken("the-llm-token".to_string()),
2521 })
2522 .unwrap()
2523 .into(),
2524 )
2525 .unwrap()),
2526 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2527 .status(200)
2528 .body(
2529 serde_json::to_string(&PredictEditsResponse {
2530 request_id: Uuid::new_v4(),
2531 output_excerpt: completion,
2532 })
2533 .unwrap()
2534 .into(),
2535 )
2536 .unwrap()),
2537 _ => Ok(http_client::Response::builder()
2538 .status(404)
2539 .body("Not Found".into())
2540 .unwrap()),
2541 }
2542 }
2543 });
2544
2545 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2546 cx.update(|cx| {
2547 RefreshLlmTokenListener::register(client.clone(), cx);
2548 });
2549 // Construct the fake server to authenticate.
2550 let _server = FakeServer::for_client(42, &client, cx).await;
2551 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2552 let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2553
2554 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2555 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2556 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2557 let completion_task = zeta.update(cx, |zeta, cx| {
2558 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2559 });
2560
2561 let completion = completion_task.await.unwrap().unwrap();
2562 completion
2563 .edits
2564 .iter()
2565 .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2566 .collect::<Vec<_>>()
2567 }
2568
2569 fn to_completion_edits(
2570 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2571 buffer: &Entity<Buffer>,
2572 cx: &App,
2573 ) -> Vec<(Range<Anchor>, String)> {
2574 let buffer = buffer.read(cx);
2575 iterator
2576 .into_iter()
2577 .map(|(range, text)| {
2578 (
2579 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2580 text,
2581 )
2582 })
2583 .collect()
2584 }
2585
2586 fn from_completion_edits(
2587 editor_edits: &[(Range<Anchor>, String)],
2588 buffer: &Entity<Buffer>,
2589 cx: &App,
2590 ) -> Vec<(Range<usize>, String)> {
2591 let buffer = buffer.read(cx);
2592 editor_edits
2593 .iter()
2594 .map(|(range, text)| {
2595 (
2596 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2597 text.clone(),
2598 )
2599 })
2600 .collect()
2601 }
2602
2603 #[ctor::ctor]
2604 fn init_logger() {
2605 zlog::init_test();
2606 }
2607}
2608*/