1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13use editor::Editor;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16use project::git_store::Repository;
17pub use rate_completion_modal::*;
18
19use anyhow::{Context as _, Result, anyhow};
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
23 PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
24 PredictEditsResponse, ZED_VERSION_HEADER_NAME,
25};
26use collections::{HashMap, HashSet, VecDeque};
27use futures::AsyncReadExt;
28use gpui::{
29 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
30 SharedString, Subscription, Task, WeakEntity, actions,
31};
32use http_client::{AsyncBody, HttpClient, Method, Request, Response};
33use input_excerpt::excerpt_for_cursor_position;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
36};
37use language_model::{LlmApiToken, RefreshLlmTokenListener};
38use multi_buffer::MultiBufferPoint;
39use project::{Project, ProjectPath};
40use release_channel::AppVersion;
41use settings::WorktreeId;
42use std::collections::hash_map;
43use std::mem;
44use std::str::FromStr;
45use std::{
46 cmp,
47 fmt::Write,
48 future::Future,
49 ops::Range,
50 path::Path,
51 rc::Rc,
52 sync::Arc,
53 time::{Duration, Instant},
54};
55use telemetry_events::EditPredictionRating;
56use thiserror::Error;
57use util::{ResultExt, maybe};
58use uuid::Uuid;
59use workspace::Workspace;
60use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
61use worktree::Worktree;
62
/// Marker for the user's cursor position. It is stripped from the model output before edits are
/// parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker for the start of the file. The model output may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker that opens the editable region. The model output must contain it exactly once.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker that closes the editable region. The model output must contain it exactly once.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that happen within this interval of each other are
/// coalesced into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key under which the data collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): presumably the token budgets for the context and rewritable parts of the input
// excerpt; neither is used in this part of the file, so confirm against the excerpt code.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget passed to `prompt_for_events` when building the events prompt.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum time between recent file entries.
const MIN_TIME_BETWEEN_RECENT_FILES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of JSON bytes for diagnostics in additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

/// Interval between polls tracking time editing files.
const ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(10);

/// Interval between polls of whether data collection is enabled, when it is disabled.
const DISABLED_ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(60 * 5);

actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
105
/// Identifier of an edit prediction. It wraps the request id returned with the prediction
/// response.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
108
109impl From<EditPredictionId> for gpui::ElementId {
110 fn from(value: EditPredictionId) -> Self {
111 gpui::ElementId::Uuid(value.0)
112 }
113}
114
115impl std::fmt::Display for EditPredictionId {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 write!(f, "{}", self.0)
118 }
119}
120
121struct ZedPredictUpsell;
122
123impl Dismissable for ZedPredictUpsell {
124 const KEY: &'static str = "dismissed-edit-predict-upsell";
125
126 fn dismissed() -> bool {
127 // To make this backwards compatible with older versions of Zed, we
128 // check if the user has seen the previous Edit Prediction Onboarding
129 // before, by checking the data collection choice which was written to
130 // the database once the user clicked on "Accept and Enable"
131 if KEY_VALUE_STORE
132 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
133 .log_err()
134 .is_some_and(|s| s.is_some())
135 {
136 return true;
137 }
138
139 KEY_VALUE_STORE
140 .read_kvp(Self::KEY)
141 .log_err()
142 .is_some_and(|s| s.is_some())
143 }
144}
145
146pub fn should_show_upsell_modal() -> bool {
147 !ZedPredictUpsell::dismissed()
148}
149
/// App-wide global holding the shared `Zeta` entity (set by `Zeta::register`, read by
/// `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
154
/// A processed edit prediction together with the inputs and timings of the request that
/// produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Id built from the request id of the prediction response.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Editable range (as offsets) that the prediction was requested for.
    excerpt_range: Range<usize>,
    /// Offset of the cursor in the snapshot taken when the request was made.
    cursor_offset: usize,
    /// Predicted edits, after interpolation against the buffer state seen when the response
    /// was processed.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot taken when the response was processed.
    snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    edit_preview: EditPreview,
    /// Events prompt that was sent with the request.
    input_events: Arc<str>,
    /// Input excerpt that was sent with the request.
    input_excerpt: Arc<str>,
    /// Raw output excerpt returned in the response.
    output_excerpt: Arc<str>,
    /// When the request started and the buffer was snapshotted.
    buffer_snapshotted_at: Instant,
    /// Stamped when the `EditPrediction` is constructed, i.e. after the response has been
    /// parsed and the edit preview computed.
    response_received_at: Instant,
}
170
171impl EditPrediction {
172 fn latency(&self) -> Duration {
173 self.response_received_at
174 .duration_since(self.buffer_snapshotted_at)
175 }
176
177 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
178 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
179 }
180}
181
182fn interpolate(
183 old_snapshot: &BufferSnapshot,
184 new_snapshot: &BufferSnapshot,
185 current_edits: Arc<[(Range<Anchor>, String)]>,
186) -> Option<Vec<(Range<Anchor>, String)>> {
187 let mut edits = Vec::new();
188
189 let mut model_edits = current_edits.iter().peekable();
190 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
191 while let Some((model_old_range, _)) = model_edits.peek() {
192 let model_old_range = model_old_range.to_offset(old_snapshot);
193 if model_old_range.end < user_edit.old.start {
194 let (model_old_range, model_new_text) = model_edits.next().unwrap();
195 edits.push((model_old_range.clone(), model_new_text.clone()));
196 } else {
197 break;
198 }
199 }
200
201 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
202 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
203 if user_edit.old == model_old_offset_range {
204 let user_new_text = new_snapshot
205 .text_for_range(user_edit.new.clone())
206 .collect::<String>();
207
208 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
209 if !model_suffix.is_empty() {
210 let anchor = old_snapshot.anchor_after(user_edit.old.end);
211 edits.push((anchor..anchor, model_suffix.to_string()));
212 }
213
214 model_edits.next();
215 continue;
216 }
217 }
218 }
219
220 return None;
221 }
222
223 edits.extend(model_edits.cloned());
224
225 if edits.is_empty() { None } else { Some(edits) }
226}
227
228impl std::fmt::Debug for EditPrediction {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 f.debug_struct("EditPrediction")
231 .field("id", &self.id)
232 .field("path", &self.path)
233 .field("edits", &self.edits)
234 .finish_non_exhaustive()
235 }
236}
237
/// Shared edit prediction state: the client and token used for requests, stored predictions,
/// and per-project tracking data.
pub struct Zeta {
    /// Client used for HTTP requests and for acquiring/refreshing the LLM token.
    client: Arc<Client>,
    /// Stored predictions; allocated with capacity `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions recorded as rated.
    rated_completions: HashSet<EditPredictionId>,
    /// The data collection choice, loaded via `load_data_collection_choices` on creation.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM API token used to authorize prediction requests.
    llm_token: LlmApiToken,
    /// Keeps alive the subscription that refreshes `llm_token` when the refresh listener emits
    /// an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// User store that receives edit prediction usage updates.
    user_store: Entity<UserStore>,
    /// One license detection watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
}
251
/// State that `Zeta` tracks separately for each project.
struct ZetaProject {
    /// Recent buffer change events; cleared by `Zeta::clear_history`.
    events: VecDeque<Event>,
    /// Buffers registered for change tracking, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    // NOTE(review): not used in this part of the file; presumably the most recently active
    // editors for activity tracking — confirm.
    recent_editors: VecDeque<RecentEditor>,
    // NOTE(review): presumably the editor state seen by the previous activity poll — confirm.
    last_activity_state: Option<ActivityState>,
    /// Activity polling task; currently always initialized to `None`.
    _activity_poll_task: Option<Task<Result<()>>>,
}
259
/// Usage statistics for a recently used editor.
///
/// NOTE(review): nothing in this part of the file reads or writes these fields; the
/// description is based on the field names and types only — confirm against the activity
/// tracking code.
struct RecentEditor {
    editor: WeakEntity<Editor>,
    last_active_at: Instant,
    activation_count: u32,
    cumulative_time_editing: Duration,
    cumulative_time_navigating: Duration,
}
267
/// Snapshot of an editor's scroll position, cursor point and (optional) buffer version.
///
/// NOTE(review): presumably compared between activity polls to tell editing from navigating;
/// the comparison is not visible in this part of the file — confirm.
#[derive(Debug)]
struct ActivityState {
    scroll_position: gpui::Point<f32>,
    cursor_point: MultiBufferPoint,
    singleton_version: Option<clock::Global>,
}
274
275impl Zeta {
276 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
277 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
278 }
279
280 pub fn register(
281 worktree: Option<Entity<Worktree>>,
282 client: Arc<Client>,
283 user_store: Entity<UserStore>,
284 cx: &mut App,
285 ) -> Entity<Self> {
286 let this = Self::global(cx).unwrap_or_else(|| {
287 let entity = cx.new(|cx| Self::new(client, user_store, cx));
288 cx.set_global(ZetaGlobal(entity.clone()));
289 entity
290 });
291
292 this.update(cx, move |this, cx| {
293 if let Some(worktree) = worktree {
294 let worktree_id = worktree.read(cx).id();
295 this.license_detection_watchers
296 .entry(worktree_id)
297 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
298 }
299 });
300
301 this
302 }
303
304 pub fn clear_history(&mut self) {
305 for zeta_project in self.projects.values_mut() {
306 zeta_project.events.clear();
307 }
308 }
309
310 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
311 self.user_store.read(cx).edit_prediction_usage()
312 }
313
314 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
315 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
316
317 let data_collection_choice = Self::load_data_collection_choices();
318 let data_collection_choice = cx.new(|_| data_collection_choice);
319
320 /* todo!
321 let mut activity_poll_task = None;
322
323 if let Some(workspace) = &workspace {
324 let project = workspace.read(cx).project().clone();
325 cx.subscribe(&project, |this, _project, event, cx| match event {
326 project::Event::ActiveEntryChanged(entry_id) => {
327 this.handle_active_project_entry_changed(cx)
328 }
329 _ => {}
330 })
331 .detach();
332
333 // TODO: ideally this would attend to window focus when tracking time, and pause the
334 // loop for efficiency when not focused.
335 activity_poll_task = Some(cx.spawn(async move |this, cx| {
336 let mut instant_before_delay = None;
337 loop {
338 let data_collection_is_enabled = this.read_with(cx, |this, cx| {
339 this.data_collection_choice.read(cx).is_enabled()
340 })?;
341 let interval = if data_collection_is_enabled {
342 ACTIVITY_POLL_INTERVAL
343 } else {
344 instant_before_delay = None;
345 DISABLED_ACTIVITY_POLL_INTERVAL
346 };
347 cx.background_executor().timer(interval).await;
348 this.update(cx, |this, cx| {
349 let now = Instant::now();
350 this.handle_activity_poll(instant_before_delay, now, cx);
351 instant_before_delay = Some(now);
352 })?
353 }
354 }));
355 }
356 */
357
358 Self {
359 client,
360 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
361 rated_completions: HashSet::default(),
362 data_collection_choice,
363 llm_token: LlmApiToken::default(),
364 _llm_token_subscription: cx.subscribe(
365 &refresh_llm_token_listener,
366 |this, _listener, _event, cx| {
367 let client = this.client.clone();
368 let llm_token = this.llm_token.clone();
369 cx.spawn(async move |_this, _cx| {
370 llm_token.refresh(&client).await?;
371 anyhow::Ok(())
372 })
373 .detach_and_log_err(cx);
374 },
375 ),
376 update_required: false,
377 license_detection_watchers: HashMap::default(),
378 user_store,
379 projects: HashMap::default(),
380 }
381 }
382
383 fn get_mut_or_init_zeta_project(
384 &mut self,
385 project: &Entity<Project>,
386 cx: &mut Context<Self>,
387 ) -> &mut ZetaProject {
388 let project_id = project.entity_id();
389 match self.projects.entry(project_id) {
390 hash_map::Entry::Occupied(entry) => entry.into_mut(),
391 hash_map::Entry::Vacant(entry) => {
392 cx.observe_release(project, move |this, _, _cx| {
393 this.projects.remove(&project_id);
394 });
395 entry.insert(ZetaProject {
396 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
397 registered_buffers: HashMap::default(),
398 recent_editors: VecDeque::new(),
399 last_activity_state: None,
400 // todo!
401 _activity_poll_task: None,
402 })
403 }
404 }
405 }
406
407 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
408 let events = &mut zeta_project.events;
409
410 if let Some(Event::BufferChange {
411 new_snapshot: last_new_snapshot,
412 timestamp: last_timestamp,
413 ..
414 }) = events.back_mut()
415 {
416 // Coalesce edits for the same buffer when they happen one after the other.
417 let Event::BufferChange {
418 old_snapshot,
419 new_snapshot,
420 timestamp,
421 } = &event;
422
423 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
424 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
425 && old_snapshot.version == last_new_snapshot.version
426 {
427 *last_new_snapshot = new_snapshot.clone();
428 *last_timestamp = *timestamp;
429 return;
430 }
431 }
432
433 if events.len() >= MAX_EVENT_COUNT {
434 // These are halved instead of popping to improve prompt caching.
435 events.drain(..MAX_EVENT_COUNT / 2);
436 }
437
438 events.push_back(event);
439 }
440
441 pub fn register_buffer(
442 &mut self,
443 buffer: &Entity<Buffer>,
444 project: &Entity<Project>,
445 cx: &mut Context<Self>,
446 ) {
447 let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
448 Self::register_buffer_impl(zeta_project, buffer, project, cx);
449 }
450
451 fn register_buffer_impl<'a>(
452 zeta_project: &'a mut ZetaProject,
453 buffer: &Entity<Buffer>,
454 project: &Entity<Project>,
455 cx: &mut Context<Self>,
456 ) -> &'a mut RegisteredBuffer {
457 let buffer_id = buffer.entity_id();
458 match zeta_project.registered_buffers.entry(buffer_id) {
459 hash_map::Entry::Occupied(entry) => entry.into_mut(),
460 hash_map::Entry::Vacant(entry) => {
461 let snapshot = buffer.read(cx).snapshot();
462 let project_entity_id = project.entity_id();
463 entry.insert(RegisteredBuffer {
464 snapshot,
465 _subscriptions: [
466 cx.subscribe(buffer, {
467 let project = project.downgrade();
468 move |this, buffer, event, cx| {
469 if let language::BufferEvent::Edited = event
470 && let Some(project) = project.upgrade()
471 {
472 this.report_changes_for_buffer(&buffer, &project, cx);
473 }
474 }
475 }),
476 cx.observe_release(buffer, move |this, _buffer, _cx| {
477 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
478 else {
479 return;
480 };
481 zeta_project.registered_buffers.remove(&buffer_id);
482 }),
483 ],
484 })
485 }
486 }
487 }
488
    /// Requests an edit prediction for `buffer` at `cursor`.
    ///
    /// The steps are: report pending buffer changes and snapshot the buffer, gather the request
    /// context on the background executor, obtain a response through `perform_predict_edits`,
    /// and turn the response into an `EditPrediction` via `process_completion_response`.
    ///
    /// `perform_predict_edits` is injected so callers can choose how the response is obtained
    /// (`request_completion` passes `Self::perform_predict_edits`; `fake_completion` passes a
    /// canned response).
    ///
    /// If the request fails with `ZedUpdateRequiredError`, `update_required` is set and a
    /// notification linking to the releases page is shown before the error is returned.
    fn request_completion_impl<F, R>(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency measurements below.
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
        let zeta = cx.entity();
        // Clone the events so the prompt can be built off the main thread.
        let events = self
            .get_mut_or_init_zeta_project(project, cx)
            .events
            .clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Buffers without a file are reported under the path "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = cx.background_spawn(gather_context(
            full_path_str,
            snapshot.clone(),
            cursor_point,
            make_events_prompt,
            can_collect_data,
        ));

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            // Placeholder while the additional-context gathering below is disabled.
            let additional_context_task: Option<Task<PredictEditsAdditionalContext>> = None;
            /* todo!
            let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
                && let Some(file) = snapshot.file()
                && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
            {
                // This is async to reduce latency of the edit predictions request. The downside is
                // that it will see a slightly later state than was used when gathering context.
                let snapshot = snapshot.clone();
                let this = this.clone();
                Some(cx.spawn(async move |cx| {
                    if let Ok(Some(task)) = this.update(cx, |this, cx| {
                        this.gather_additional_context(
                            cursor_point,
                            cursor_offset,
                            snapshot,
                            &buffer_snapshotted_at,
                            project_path,
                            &project,
                            cx,
                        )
                    }) {
                        Some(task.await)
                    } else {
                        None
                    }
                }))
            } else {
                None
            };
            */

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the inputs; `body` is moved into the request below.
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // An outdated client gets flagged and told to update; the error is still
                    // returned to the caller.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage reporting is best-effort: a failed update is ignored.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            // Only needed by the disabled additional-context telemetry below.
            let request_id = response.request_id.clone();
            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            /* todo!
            if let Some(additional_context_task) = additional_context_task {
                cx.background_spawn(async move {
                    if let Some(additional_context) = additional_context_task.await {
                        telemetry::event!(
                            "Edit Prediction Additional Context",
                            request_id = request_id,
                            additional_context = additional_context
                        );
                    }
                })
                .detach();
            }
            */

            edit_prediction
        })
    }
679
    // Generates several example completions of various states to fill the Zeta completion modal
    //
    // NOTE: currently unimplemented: calling this hits `todo!()` and panics. The previous
    // implementation is kept below for reference; it passes `None` as the first argument of
    // `fake_completion`, which now takes a `&Entity<Project>`, so it cannot be restored as-is.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        /*
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            for task in completion_tasks {
                task.await.unwrap();
            }

            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
        */
        todo!()
    }
834
835 #[cfg(any(test, feature = "test-support"))]
836 pub fn fake_completion(
837 &mut self,
838 project: &Entity<Project>,
839 buffer: &Entity<Buffer>,
840 position: language::Anchor,
841 response: PredictEditsResponse,
842 cx: &mut Context<Self>,
843 ) -> Task<Result<Option<EditPrediction>>> {
844 use std::future::ready;
845
846 self.request_completion_impl(
847 project,
848 buffer,
849 position,
850 CanCollectData(false),
851 cx,
852 |_params| ready(Ok((response, None))),
853 )
854 }
855
856 pub fn request_completion(
857 &mut self,
858 project: &Entity<Project>,
859 buffer: &Entity<Buffer>,
860 position: language::Anchor,
861 can_collect_data: CanCollectData,
862 cx: &mut Context<Self>,
863 ) -> Task<Result<Option<EditPrediction>>> {
864 self.request_completion_impl(
865 project,
866 buffer,
867 position,
868 can_collect_data,
869 cx,
870 Self::perform_predict_edits,
871 )
872 }
873
874 pub fn perform_predict_edits(
875 params: PerformPredictEditsParams,
876 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
877 async move {
878 let PerformPredictEditsParams {
879 client,
880 llm_token,
881 app_version,
882 body,
883 ..
884 } = params;
885
886 let http_client = client.http_client();
887 let mut token = llm_token.acquire(&client).await?;
888 let mut did_retry = false;
889
890 loop {
891 let request_builder = http_client::Request::builder().method(Method::POST);
892 let request_builder =
893 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
894 request_builder.uri(predict_edits_url)
895 } else {
896 request_builder.uri(
897 http_client
898 .build_zed_llm_url("/predict_edits/v2", &[])?
899 .as_ref(),
900 )
901 };
902 let request = request_builder
903 .header("Content-Type", "application/json")
904 .header("Authorization", format!("Bearer {}", token))
905 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
906 .body(serde_json::to_string(&body)?.into())?;
907
908 let mut response = http_client.send(request).await?;
909
910 if let Some(minimum_required_version) = response
911 .headers()
912 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
913 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
914 {
915 anyhow::ensure!(
916 app_version >= minimum_required_version,
917 ZedUpdateRequiredError {
918 minimum_version: minimum_required_version
919 }
920 );
921 }
922
923 if response.status().is_success() {
924 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
925
926 let mut body = String::new();
927 response.body_mut().read_to_string(&mut body).await?;
928 return Ok((serde_json::from_str(&body)?, usage));
929 } else if !did_retry
930 && response
931 .headers()
932 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
933 .is_some()
934 {
935 did_retry = true;
936 token = llm_token.refresh(&client).await?;
937 } else {
938 let mut body = String::new();
939 response.body_mut().read_to_string(&mut body).await?;
940 anyhow::bail!(
941 "error predicting edits.\nStatus: {:?}\nBody: {}",
942 response.status(),
943 body
944 );
945 }
946 }
947 }
948 }
949
950 fn accept_edit_prediction(
951 &mut self,
952 request_id: EditPredictionId,
953 cx: &mut Context<Self>,
954 ) -> Task<Result<()>> {
955 let client = self.client.clone();
956 let llm_token = self.llm_token.clone();
957 let app_version = AppVersion::global(cx);
958 cx.spawn(async move |this, cx| {
959 let http_client = client.http_client();
960 let mut response = llm_token_retry(&llm_token, &client, |token| {
961 let request_builder = http_client::Request::builder().method(Method::POST);
962 let request_builder =
963 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
964 request_builder.uri(accept_prediction_url)
965 } else {
966 request_builder.uri(
967 http_client
968 .build_zed_llm_url("/predict_edits/accept", &[])?
969 .as_ref(),
970 )
971 };
972 Ok(request_builder
973 .header("Content-Type", "application/json")
974 .header("Authorization", format!("Bearer {}", token))
975 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
976 .body(
977 serde_json::to_string(&AcceptEditPredictionBody {
978 request_id: request_id.0,
979 })?
980 .into(),
981 )?)
982 })
983 .await?;
984
985 if let Some(minimum_required_version) = response
986 .headers()
987 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
988 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
989 && app_version < minimum_required_version
990 {
991 return Err(anyhow!(ZedUpdateRequiredError {
992 minimum_version: minimum_required_version
993 }));
994 }
995
996 if response.status().is_success() {
997 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
998 this.update(cx, |this, cx| {
999 this.user_store.update(cx, |user_store, cx| {
1000 user_store.update_edit_prediction_usage(usage, cx);
1001 });
1002 })?;
1003 }
1004
1005 Ok(())
1006 } else {
1007 let mut body = String::new();
1008 response.body_mut().read_to_string(&mut body).await?;
1009 Err(anyhow!(
1010 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
1011 response.status(),
1012 body
1013 ))
1014 }
1015 })
1016 }
1017
1018 fn process_completion_response(
1019 prediction_response: PredictEditsResponse,
1020 buffer: Entity<Buffer>,
1021 snapshot: &BufferSnapshot,
1022 editable_range: Range<usize>,
1023 cursor_offset: usize,
1024 path: Arc<Path>,
1025 input_events: String,
1026 input_excerpt: String,
1027 buffer_snapshotted_at: Instant,
1028 cx: &AsyncApp,
1029 ) -> Task<Result<Option<EditPrediction>>> {
1030 let snapshot = snapshot.clone();
1031 let request_id = prediction_response.request_id;
1032 let output_excerpt = prediction_response.output_excerpt;
1033 cx.spawn(async move |cx| {
1034 let output_excerpt: Arc<str> = output_excerpt.into();
1035
1036 let edits: Arc<[(Range<Anchor>, String)]> = cx
1037 .background_spawn({
1038 let output_excerpt = output_excerpt.clone();
1039 let editable_range = editable_range.clone();
1040 let snapshot = snapshot.clone();
1041 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
1042 })
1043 .await?
1044 .into();
1045
1046 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
1047 let edits = edits.clone();
1048 |buffer, cx| {
1049 let new_snapshot = buffer.snapshot();
1050 let edits: Arc<[(Range<Anchor>, String)]> =
1051 interpolate(&snapshot, &new_snapshot, edits)?.into();
1052 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
1053 }
1054 })?
1055 else {
1056 return anyhow::Ok(None);
1057 };
1058
1059 let edit_preview = edit_preview.await;
1060
1061 Ok(Some(EditPrediction {
1062 id: EditPredictionId(request_id),
1063 path,
1064 excerpt_range: editable_range,
1065 cursor_offset,
1066 edits,
1067 edit_preview,
1068 snapshot,
1069 input_events: input_events.into(),
1070 input_excerpt: input_excerpt.into(),
1071 output_excerpt,
1072 buffer_snapshotted_at,
1073 response_received_at: Instant::now(),
1074 }))
1075 })
1076 }
1077
1078 fn parse_edits(
1079 output_excerpt: Arc<str>,
1080 editable_range: Range<usize>,
1081 snapshot: &BufferSnapshot,
1082 ) -> Result<Vec<(Range<Anchor>, String)>> {
1083 let content = output_excerpt.replace(CURSOR_MARKER, "");
1084
1085 let start_markers = content
1086 .match_indices(EDITABLE_REGION_START_MARKER)
1087 .collect::<Vec<_>>();
1088 anyhow::ensure!(
1089 start_markers.len() == 1,
1090 "expected exactly one start marker, found {}",
1091 start_markers.len()
1092 );
1093
1094 let end_markers = content
1095 .match_indices(EDITABLE_REGION_END_MARKER)
1096 .collect::<Vec<_>>();
1097 anyhow::ensure!(
1098 end_markers.len() == 1,
1099 "expected exactly one end marker, found {}",
1100 end_markers.len()
1101 );
1102
1103 let sof_markers = content
1104 .match_indices(START_OF_FILE_MARKER)
1105 .collect::<Vec<_>>();
1106 anyhow::ensure!(
1107 sof_markers.len() <= 1,
1108 "expected at most one start-of-file marker, found {}",
1109 sof_markers.len()
1110 );
1111
1112 let codefence_start = start_markers[0].0;
1113 let content = &content[codefence_start..];
1114
1115 let newline_ix = content.find('\n').context("could not find newline")?;
1116 let content = &content[newline_ix + 1..];
1117
1118 let codefence_end = content
1119 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1120 .context("could not find end marker")?;
1121 let new_text = &content[..codefence_end];
1122
1123 let old_text = snapshot
1124 .text_for_range(editable_range.clone())
1125 .collect::<String>();
1126
1127 Ok(Self::compute_edits(
1128 old_text,
1129 new_text,
1130 editable_range.start,
1131 snapshot,
1132 ))
1133 }
1134
1135 pub fn compute_edits(
1136 old_text: String,
1137 new_text: &str,
1138 offset: usize,
1139 snapshot: &BufferSnapshot,
1140 ) -> Vec<(Range<Anchor>, String)> {
1141 text_diff(&old_text, new_text)
1142 .into_iter()
1143 .map(|(mut old_range, new_text)| {
1144 old_range.start += offset;
1145 old_range.end += offset;
1146
1147 let prefix_len = common_prefix(
1148 snapshot.chars_for_range(old_range.clone()),
1149 new_text.chars(),
1150 );
1151 old_range.start += prefix_len;
1152
1153 let suffix_len = common_prefix(
1154 snapshot.reversed_chars_for_range(old_range.clone()),
1155 new_text[prefix_len..].chars().rev(),
1156 );
1157 old_range.end = old_range.end.saturating_sub(suffix_len);
1158
1159 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1160 let range = if old_range.is_empty() {
1161 let anchor = snapshot.anchor_after(old_range.start);
1162 anchor..anchor
1163 } else {
1164 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1165 };
1166 (range, new_text)
1167 })
1168 .collect()
1169 }
1170
    /// Returns `true` when `completion_id` is present in `rated_completions`.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
1174
1175 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1176 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1177 let completion = self.shown_completions.pop_back().unwrap();
1178 self.rated_completions.remove(&completion.id);
1179 }
1180 self.shown_completions.push_front(completion.clone());
1181 cx.notify();
1182 }
1183
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// The "Edit Prediction Rated" event carries the rating, the prediction's
    /// input events, input excerpt and output excerpt, and the free-form
    /// `feedback`. A telemetry flush is requested right away (detached, not
    /// awaited) and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1203
    /// Iterates over the recorded shown completions, front to back
    /// (`completion_shown` pushes new entries at the front).
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
1207
    /// Returns the number of entries currently held in `shown_completions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1211
1212 fn report_changes_for_buffer(
1213 &mut self,
1214 buffer: &Entity<Buffer>,
1215 project: &Entity<Project>,
1216 cx: &mut Context<Self>,
1217 ) -> BufferSnapshot {
1218 let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
1219 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
1220
1221 let new_snapshot = buffer.read(cx).snapshot();
1222 if new_snapshot.version != registered_buffer.snapshot.version {
1223 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1224 Self::push_event(
1225 zeta_project,
1226 Event::BufferChange {
1227 old_snapshot,
1228 new_snapshot: new_snapshot.clone(),
1229 timestamp: Instant::now(),
1230 },
1231 );
1232 }
1233
1234 new_snapshot
1235 }
1236
1237 fn load_data_collection_choices() -> DataCollectionChoice {
1238 let choice = KEY_VALUE_STORE
1239 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1240 .log_err()
1241 .flatten();
1242
1243 match choice.as_deref() {
1244 Some("true") => DataCollectionChoice::Enabled,
1245 Some("false") => DataCollectionChoice::Disabled,
1246 Some(_) => {
1247 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1248 DataCollectionChoice::NotAnswered
1249 }
1250 None => DataCollectionChoice::NotAnswered,
1251 }
1252 }
1253
1254 /*
1255 fn gather_additional_context(
1256 &mut self,
1257 cursor_point: language::Point,
1258 cursor_offset: usize,
1259 snapshot: BufferSnapshot,
1260 buffer_snapshotted_at: &Instant,
1261 project_path: ProjectPath,
1262 project: &WeakEntity<Project>,
1263 cx: &mut Context<Self>,
1264 ) -> Option<Task<PredictEditsAdditionalContext>> {
1265 let project = project.upgrade()?.read(cx);
1266 let entry = project.entry_for_path(&project_path, cx)?;
1267 if !worktree_entry_is_eligible_for_collection(&entry) {
1268 return None;
1269 }
1270
1271 let git_store = project.git_store().read(cx);
1272 let (repository, repo_path) =
1273 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1274 let repo_path_string = repo_path.to_str()?.to_string();
1275
1276 let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
1277 snapshot
1278 .diagnostics
1279 .iter()
1280 .filter_map(|(language_server_id, diagnostics)| {
1281 let language_server =
1282 local_lsp_store.running_language_server_for_id(*language_server_id)?;
1283 Some((
1284 *language_server_id,
1285 language_server.name(),
1286 diagnostics.clone(),
1287 ))
1288 })
1289 .collect()
1290 } else {
1291 Vec::new()
1292 };
1293
1294 repository.update(cx, |repository, cx| {
1295 let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1296 let remote_origin_url = repository.remote_origin_url.clone();
1297 let remote_upstream_url = repository.remote_upstream_url.clone();
1298 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1299
1300 // group, resolve, and select diagnostics on a background thread
1301 Some(cx.background_spawn(async move {
1302 let mut diagnostic_groups_with_name = Vec::new();
1303 for (language_server_id, language_server_name, diagnostics) in
1304 diagnostics.into_iter()
1305 {
1306 let mut groups = Vec::new();
1307 diagnostics.groups(language_server_id, &mut groups, &snapshot);
1308 diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
1309 (
1310 language_server_name.clone(),
1311 group.resolve::<usize>(&snapshot),
1312 )
1313 }));
1314 }
1315
1316 // sort by proximity to cursor
1317 diagnostic_groups_with_name.sort_by_key(|(_, group)| {
1318 let range = &group.entries[group.primary_ix].range;
1319 if range.start >= cursor_offset {
1320 range.start - cursor_offset
1321 } else if cursor_offset >= range.end {
1322 cursor_offset - range.end
1323 } else {
1324 (cursor_offset - range.start).min(range.end - cursor_offset)
1325 }
1326 });
1327
1328 let mut diagnostic_groups = Vec::new();
1329 let mut diagnostic_groups_truncated = false;
1330 let mut diagnostics_byte_count = 0;
1331 for (name, group) in diagnostic_groups_with_name {
1332 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1333 diagnostics_byte_count += name.0.len() + raw_value.get().len();
1334 if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
1335 diagnostic_groups_truncated = true;
1336 break;
1337 }
1338 diagnostic_groups.push((name.to_string(), raw_value));
1339 }
1340
1341 PredictEditsAdditionalContext {
1342 input_path: repo_path_string,
1343 cursor_point: to_cloud_llm_client_point(cursor_point),
1344 cursor_offset: cursor_offset,
1345 git_info: PredictEditsGitInfo {
1346 head_sha: Some(head_sha),
1347 remote_origin_url,
1348 remote_upstream_url,
1349 },
1350 diagnostic_groups,
1351 diagnostic_groups_truncated,
1352 recent_files,
1353 }
1354 }))
1355 })
1356 }
1357
1358 fn handle_active_project_entry_changed(&mut self, cx: &mut Context<Self>) {
1359 if !self.data_collection_choice.read(cx).is_enabled() {
1360 self.recent_editors.clear();
1361 self.last_activity_state = None;
1362 return;
1363 }
1364 if let Some(active_editor) = self
1365 .workspace
1366 .read_with(cx, |workspace, cx| {
1367 workspace
1368 .active_item(cx)
1369 .and_then(|item| item.act_as::<Editor>(cx))
1370 })
1371 .ok()
1372 .flatten()
1373 {
1374 let now = Instant::now();
1375 let editor = active_editor.downgrade();
1376 let existing_recent_editor = if let Some(existing_ix) = self
1377 .recent_editors
1378 .iter()
1379 .rposition(|recent| &recent.editor == &editor)
1380 {
1381 if existing_ix + 1 != self.recent_editors.len() {
1382 self.last_activity_state = None;
1383 }
1384 self.recent_editors.remove(existing_ix)
1385 } else {
1386 None
1387 };
1388 let new_recent = RecentEditor {
1389 editor: active_editor.downgrade(),
1390 last_active_at: now,
1391 activation_count: existing_recent_editor
1392 .as_ref()
1393 .map_or(0, |recent| recent.activation_count + 1),
1394 cumulative_time_navigating: existing_recent_editor
1395 .as_ref()
1396 .map_or(Duration::ZERO, |recent| recent.cumulative_time_navigating),
1397 cumulative_time_editing: existing_recent_editor
1398 .map_or(Duration::ZERO, |recent| recent.cumulative_time_editing),
1399 };
1400 // filter out rapid changes in active item, particularly since this can happen rapidly when
1401 // a workspace is loaded.
1402 if let Some(previous_recent) = self.recent_editors.back_mut()
1403 && previous_recent.activation_count == 1
1404 && now.duration_since(previous_recent.last_active_at)
1405 < MIN_TIME_BETWEEN_RECENT_FILES
1406 {
1407 *previous_recent = new_recent;
1408 return;
1409 }
1410 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1411 self.recent_editors.pop_front();
1412 }
1413 self.recent_editors.push_back(new_recent);
1414 }
1415 }
1416
1417 fn handle_activity_poll(
1418 &mut self,
1419 instant_before_delay: Option<Instant>,
1420 now: Instant,
1421 cx: &mut Context<Self>,
1422 ) {
1423 if !self.data_collection_choice.read(cx).is_enabled() {
1424 self.last_activity_state = None;
1425 return;
1426 }
1427 if let Some(recent_editor) = self.recent_editors.back()
1428 && let Some(editor) = recent_editor.editor.upgrade()
1429 {
1430 let (scroll_position, cursor_point, singleton_version) =
1431 editor.update(cx, |editor, cx| {
1432 let scroll_position = editor.scroll_position(cx);
1433 let cursor_point = editor.selections.newest(cx).head();
1434 let singleton_version = editor
1435 .buffer()
1436 .read(cx)
1437 .as_singleton()
1438 .map(|singleton_buffer| singleton_buffer.read(cx).version());
1439 (scroll_position, cursor_point, singleton_version)
1440 });
1441
1442 let navigated = if let Some(last_activity_state) = &self.last_activity_state {
1443 last_activity_state.scroll_position != scroll_position
1444 || last_activity_state.cursor_point != cursor_point
1445 } else {
1446 false
1447 };
1448
1449 let edited = if let Some(singleton_version) = &singleton_version
1450 && let Some(last_activity_state) = &self.last_activity_state
1451 && let Some(last_singleton_version) = &last_activity_state.singleton_version
1452 {
1453 singleton_version.changed_since(last_singleton_version)
1454 } else {
1455 false
1456 };
1457
1458 self.last_activity_state = Some(ActivityState {
1459 scroll_position,
1460 cursor_point,
1461 singleton_version,
1462 });
1463
1464 let prior_recent_editor = if self.recent_editors.len() > 1 {
1465 Some(&self.recent_editors[self.recent_editors.len() - 2])
1466 } else {
1467 None
1468 };
1469 let additional_time: Option<Duration> =
1470 instant_before_delay.map(|instant_before_delay| {
1471 now.duration_since(prior_recent_editor.map_or(
1472 instant_before_delay,
1473 |prior_recent_editor| {
1474 prior_recent_editor.last_active_at.max(instant_before_delay)
1475 },
1476 ))
1477 });
1478
1479 if let Some(additional_time) = additional_time {
1480 let recent_editor = self.recent_editors.back_mut().unwrap();
1481 if navigated {
1482 recent_editor.cumulative_time_navigating += additional_time;
1483 }
1484 if edited {
1485 recent_editor.cumulative_time_editing += additional_time;
1486 }
1487 }
1488 }
1489 }
1490
1491 fn recent_files(
1492 &mut self,
1493 now: &Instant,
1494 repository: &Repository,
1495 cx: &mut App,
1496 ) -> Vec<PredictEditsRecentFile> {
1497 let Ok(project) = self
1498 .workspace
1499 .read_with(cx, |workspace, _cx| workspace.project().clone())
1500 else {
1501 return Vec::new();
1502 };
1503 let mut results = Vec::with_capacity(self.recent_editors.len());
1504 for ix in (0..self.recent_editors.len()).rev() {
1505 let recent_editor = &self.recent_editors[ix];
1506 let keep_entry = recent_editor
1507 .editor
1508 .update(cx, |editor, cx| {
1509 maybe!({
1510 let cursor = editor.selections.newest::<MultiBufferPoint>(cx).head();
1511 let multibuffer = editor.buffer().read(cx);
1512 let (buffer, cursor_point, _) =
1513 multibuffer.point_to_buffer_point(cursor, cx)?;
1514 let file = buffer.read(cx).file()?;
1515 if !file_is_eligible_for_collection(file.as_ref()) {
1516 return None;
1517 }
1518 let project_path = ProjectPath {
1519 worktree_id: file.worktree_id(cx),
1520 path: file.path().clone(),
1521 };
1522 let entry = project.read(cx).entry_for_path(&project_path, cx)?;
1523 if !worktree_entry_is_eligible_for_collection(entry) {
1524 return None;
1525 }
1526 let Some(repo_path) =
1527 repository.project_path_to_repo_path(&project_path, cx)
1528 else {
1529 // entry not removed since later queries may involve other repositories
1530 return Some(());
1531 };
1532 // paths may not be valid UTF-8
1533 let repo_path_str = repo_path.to_str()?;
1534 if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
1535 return None;
1536 }
1537 let active_to_now_ms = now
1538 .duration_since(recent_editor.last_active_at)
1539 .as_millis()
1540 .try_into()
1541 .ok()?;
1542 let cumulative_time_editing_ms = recent_editor
1543 .cumulative_time_editing
1544 .as_millis()
1545 .try_into()
1546 .ok()?;
1547 let cumulative_time_navigating_ms = recent_editor
1548 .cumulative_time_navigating
1549 .as_millis()
1550 .try_into()
1551 .ok()?;
1552 results.push(PredictEditsRecentFile {
1553 path: repo_path_str.to_string(),
1554 cursor_point: to_cloud_llm_client_point(cursor_point),
1555 active_to_now_ms,
1556 activation_count: recent_editor.activation_count,
1557 cumulative_time_editing_ms,
1558 cumulative_time_navigating_ms,
1559 is_multibuffer: !multibuffer.is_singleton(),
1560 });
1561 Some(())
1562 })
1563 })
1564 .ok()
1565 .flatten();
1566 if keep_entry.is_none() {
1567 self.recent_editors.remove(ix);
1568 }
1569 }
1570 results
1571 }
1572 */
1573}
1574
1575fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1576 cloud_llm_client::Point {
1577 row: point.row,
1578 column: point.column,
1579 }
1580}
1581
1582fn file_is_eligible_for_collection(file: &dyn File) -> bool {
1583 file.is_local() && !file.is_private()
1584}
1585
1586fn worktree_entry_is_eligible_for_collection(entry: &worktree::Entry) -> bool {
1587 entry.is_file()
1588 && entry.is_created()
1589 && !entry.is_ignored
1590 && !entry.is_private
1591 && !entry.is_external
1592 && !entry.is_fifo
1593}
1594
/// Bundle of values for a predict-edits request.
// NOTE(review): the consumer of this struct is outside this part of the file;
// the field descriptions below are inferred from names and types — confirm.
pub struct PerformPredictEditsParams {
    /// Client handle (presumably used to reach the HTTP client).
    pub client: Arc<Client>,
    /// LLM API token handle.
    pub llm_token: LlmApiToken,
    /// Version of the running application.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1601
/// Error whose message tells the user to update Zed to at least
/// `minimum_version` to keep using edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Version interpolated into the error message.
    minimum_version: SemanticVersion,
}
1609
/// Returns the UTF-8 byte length of the longest run of equal characters at the
/// start of both iterators.
///
/// Comparison stops at the first differing pair or when either iterator ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1616
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// Request body built from the input events and the input excerpt.
    pub body: PredictEditsBody,
    /// Editable range of the excerpt, converted to buffer offsets.
    pub editable_range: Range<usize>,
}
1621
1622pub async fn gather_context(
1623 full_path_str: String,
1624 snapshot: BufferSnapshot,
1625 cursor_point: language::Point,
1626 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1627 can_collect_data: CanCollectData,
1628) -> Result<GatherContextOutput> {
1629 let input_excerpt = excerpt_for_cursor_position(
1630 cursor_point,
1631 &full_path_str,
1632 &snapshot,
1633 MAX_REWRITE_TOKENS,
1634 MAX_CONTEXT_TOKENS,
1635 );
1636 let input_events = make_events_prompt();
1637 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1638
1639 let body = PredictEditsBody {
1640 input_events,
1641 input_excerpt: input_excerpt.prompt,
1642 can_collect_data: can_collect_data.0,
1643 diagnostic_groups: None,
1644 git_info: None,
1645 };
1646
1647 Ok(GatherContextOutput {
1648 body,
1649 editable_range,
1650 })
1651}
1652
1653fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1654 let mut result = String::new();
1655 for event in events.iter().rev() {
1656 let event_string = event.to_prompt();
1657 let event_tokens = tokens_for_bytes(event_string.len());
1658 if event_tokens > remaining_tokens {
1659 break;
1660 }
1661
1662 if !result.is_empty() {
1663 result.insert_str(0, "\n\n");
1664 }
1665 result.insert_str(0, &event_string);
1666 remaining_tokens -= event_tokens;
1667 }
1668 result
1669}
1670
/// Bookkeeping for a buffer that has been registered for change reporting.
struct RegisteredBuffer {
    /// Snapshot last recorded for the buffer; compared by version and replaced
    /// when a change is reported.
    snapshot: BufferSnapshot,
    /// Subscriptions held for the lifetime of this entry (never read directly).
    _subscriptions: [gpui::Subscription; 2],
}
1675
/// An event that can be rendered into the prompt via `to_prompt`.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1684
1685impl Event {
1686 fn to_prompt(&self) -> String {
1687 match self {
1688 Event::BufferChange {
1689 old_snapshot,
1690 new_snapshot,
1691 ..
1692 } => {
1693 let mut prompt = String::new();
1694
1695 let old_path = old_snapshot
1696 .file()
1697 .map(|f| f.path().as_ref())
1698 .unwrap_or(Path::new("untitled"));
1699 let new_path = new_snapshot
1700 .file()
1701 .map(|f| f.path().as_ref())
1702 .unwrap_or(Path::new("untitled"));
1703 if old_path != new_path {
1704 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1705 }
1706
1707 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1708 if !diff.is_empty() {
1709 write!(
1710 prompt,
1711 "User edited {:?}:\n```diff\n{}\n```",
1712 new_path, diff
1713 )
1714 .unwrap();
1715 }
1716
1717 prompt
1718 }
1719 }
1720 }
1721}
1722
/// A prediction together with the id of the buffer it was produced for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction belongs to.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
}
1728
1729impl CurrentEditPrediction {
1730 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1731 if self.buffer_id != old_completion.buffer_id {
1732 return true;
1733 }
1734
1735 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1736 return true;
1737 };
1738 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1739 return false;
1740 };
1741
1742 if old_edits.len() == 1 && new_edits.len() == 1 {
1743 let (old_range, old_text) = &old_edits[0];
1744 let (new_range, new_text) = &new_edits[0];
1745 new_range == old_range && new_text.starts_with(old_text)
1746 } else {
1747 true
1748 }
1749 }
1750}
1751
/// An in-flight completion request.
struct PendingCompletion {
    /// Identifier used to recognise this request among the pending ones.
    id: usize,
    /// Task driving the request; held so it lives as long as this entry.
    _task: Task<()>,
}
1756
/// The user's choice regarding edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been made yet.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// `true` for every variant except [`DataCollectionChoice::NotAnswered`].
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the flipped choice; an unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1796
/// Data collection state attached to an edit prediction provider.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher queried to decide whether the project is open source; `None`
    /// whenever `choice` is `None` (both are set together by `new`).
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1802
/// Newtype flag stating whether data may be collected for a request.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1805
1806impl ProviderDataCollection {
1807 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1808 let choice_and_watcher = buffer.and_then(|buffer| {
1809 let file = buffer.read(cx).file()?;
1810
1811 if !file_is_eligible_for_collection(file.as_ref()) {
1812 return None;
1813 }
1814
1815 let zeta = zeta.read(cx);
1816 let choice = zeta.data_collection_choice.clone();
1817
1818 let license_detection_watcher = zeta
1819 .license_detection_watchers
1820 .get(&file.worktree_id(cx))
1821 .cloned()?;
1822
1823 Some((choice, license_detection_watcher))
1824 });
1825
1826 if let Some((choice, watcher)) = choice_and_watcher {
1827 ProviderDataCollection {
1828 choice: Some(choice),
1829 license_detection_watcher: Some(watcher),
1830 }
1831 } else {
1832 ProviderDataCollection {
1833 choice: None,
1834 license_detection_watcher: None,
1835 }
1836 }
1837 }
1838
1839 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1840 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1841 }
1842
1843 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1844 self.choice
1845 .as_ref()
1846 .is_some_and(|choice| choice.read(cx).is_enabled())
1847 }
1848
1849 fn is_project_open_source(&self) -> bool {
1850 self.license_detection_watcher
1851 .as_ref()
1852 .is_some_and(|watcher| watcher.is_project_open_source())
1853 }
1854
1855 pub fn toggle(&mut self, cx: &mut App) {
1856 if let Some(choice) = self.choice.as_mut() {
1857 let new_choice = choice.update(cx, |choice, _cx| {
1858 let new_choice = choice.toggle();
1859 *choice = new_choice;
1860 new_choice
1861 });
1862
1863 db::write_and_log(cx, move || {
1864 KEY_VALUE_STORE.write_kvp(
1865 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1866 new_choice.is_enabled().to_string(),
1867 )
1868 });
1869 }
1870 }
1871}
1872
1873async fn llm_token_retry(
1874 llm_token: &LlmApiToken,
1875 client: &Arc<Client>,
1876 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1877) -> Result<Response<AsyncBody>> {
1878 let mut did_retry = false;
1879 let http_client = client.http_client();
1880 let mut token = llm_token.acquire(client).await?;
1881 loop {
1882 let request = build_request(token.clone())?;
1883 let response = http_client.send(request).await?;
1884
1885 if !did_retry
1886 && !response.status().is_success()
1887 && response
1888 .headers()
1889 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1890 .is_some()
1891 {
1892 did_retry = true;
1893 token = llm_token.refresh(client).await?;
1894 continue;
1895 }
1896
1897 return Ok(response);
1898 }
1899}
1900
/// Edit prediction provider that requests predictions from a shared [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    /// Entity used to request predictions and record shown/accepted ones.
    zeta: Entity<Zeta>,
    /// In-flight requests; the fixed capacity limits them to two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next pending request.
    next_pending_completion_id: usize,
    /// The prediction currently offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data collection state for this provider. Its inner `choice` is `None`
    /// when data collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used for throttling.
    last_request_timestamp: Instant,
}
1910
impl ZetaEditPredictionProvider {
    /// Minimum delay between the previous request timestamp and the start of
    /// the next completion request.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current prediction; the request
    /// timestamp starts at the moment of construction.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1925
1926impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Identifier of this provider: `"zed-predict"`.
    fn name() -> &'static str {
        "zed-predict"
    }
1930
    /// Human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }
1934
    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
1938
    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
1942
1943 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1944 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1945
1946 if self.provider_data_collection.is_data_collection_enabled(cx) {
1947 DataCollectionState::Enabled {
1948 is_project_open_source,
1949 }
1950 } else {
1951 DataCollectionState::Disabled {
1952 is_project_open_source,
1953 }
1954 }
1955 }
1956
    /// Delegates to [`ProviderDataCollection::toggle`].
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }
1960
    /// Returns the usage reported by the underlying `Zeta` entity.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
1964
    /// Always reports the provider as enabled; all arguments are ignored.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// `true` while at least one completion request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }
1976
1977 fn refresh(
1978 &mut self,
1979 project: Option<Entity<Project>>,
1980 buffer: Entity<Buffer>,
1981 position: language::Anchor,
1982 _debounce: bool,
1983 cx: &mut Context<Self>,
1984 ) {
1985 if self.zeta.read(cx).update_required {
1986 return;
1987 }
1988 // todo! Don't require a project
1989 let Some(project) = project else {
1990 return;
1991 };
1992
1993 if self
1994 .zeta
1995 .read(cx)
1996 .user_store
1997 .read_with(cx, |user_store, _cx| {
1998 user_store.account_too_young() || user_store.has_overdue_invoices()
1999 })
2000 {
2001 return;
2002 }
2003
2004 if let Some(current_completion) = self.current_completion.as_ref() {
2005 let snapshot = buffer.read(cx).snapshot();
2006 if current_completion
2007 .completion
2008 .interpolate(&snapshot)
2009 .is_some()
2010 {
2011 return;
2012 }
2013 }
2014
2015 let pending_completion_id = self.next_pending_completion_id;
2016 self.next_pending_completion_id += 1;
2017 let can_collect_data = self.provider_data_collection.can_collect_data(cx);
2018 let last_request_timestamp = self.last_request_timestamp;
2019
2020 let task = cx.spawn(async move |this, cx| {
2021 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
2022 .checked_duration_since(Instant::now())
2023 {
2024 cx.background_executor().timer(timeout).await;
2025 }
2026
2027 let completion_request = this.update(cx, |this, cx| {
2028 this.last_request_timestamp = Instant::now();
2029 this.zeta.update(cx, |zeta, cx| {
2030 zeta.request_completion(&project, &buffer, position, can_collect_data, cx)
2031 })
2032 });
2033
2034 let completion = match completion_request {
2035 Ok(completion_request) => {
2036 let completion_request = completion_request.await;
2037 completion_request.map(|c| {
2038 c.map(|completion| CurrentEditPrediction {
2039 buffer_id: buffer.entity_id(),
2040 completion,
2041 })
2042 })
2043 }
2044 Err(error) => Err(error),
2045 };
2046 let Some(new_completion) = completion
2047 .context("edit prediction failed")
2048 .log_err()
2049 .flatten()
2050 else {
2051 this.update(cx, |this, cx| {
2052 if this.pending_completions[0].id == pending_completion_id {
2053 this.pending_completions.remove(0);
2054 } else {
2055 this.pending_completions.clear();
2056 }
2057
2058 cx.notify();
2059 })
2060 .ok();
2061 return;
2062 };
2063
2064 this.update(cx, |this, cx| {
2065 if this.pending_completions[0].id == pending_completion_id {
2066 this.pending_completions.remove(0);
2067 } else {
2068 this.pending_completions.clear();
2069 }
2070
2071 if let Some(old_completion) = this.current_completion.as_ref() {
2072 let snapshot = buffer.read(cx).snapshot();
2073 if new_completion.should_replace_completion(old_completion, &snapshot) {
2074 this.zeta.update(cx, |zeta, cx| {
2075 zeta.completion_shown(&new_completion.completion, cx);
2076 });
2077 this.current_completion = Some(new_completion);
2078 }
2079 } else {
2080 this.zeta.update(cx, |zeta, cx| {
2081 zeta.completion_shown(&new_completion.completion, cx);
2082 });
2083 this.current_completion = Some(new_completion);
2084 }
2085
2086 cx.notify();
2087 })
2088 .ok();
2089 });
2090
2091 // We always maintain at most two pending completions. When we already
2092 // have two, we replace the newest one.
2093 if self.pending_completions.len() <= 1 {
2094 self.pending_completions.push(PendingCompletion {
2095 id: pending_completion_id,
2096 _task: task,
2097 });
2098 } else if self.pending_completions.len() == 2 {
2099 self.pending_completions.pop();
2100 self.pending_completions.push(PendingCompletion {
2101 id: pending_completion_id,
2102 _task: task,
2103 });
2104 }
2105 }
2106
    /// Intentionally a no-op: cycling is not supported, so all arguments are ignored.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }
2116
2117 fn accept(&mut self, cx: &mut Context<Self>) {
2118 let completion_id = self
2119 .current_completion
2120 .as_ref()
2121 .map(|completion| completion.completion.id);
2122 if let Some(completion_id) = completion_id {
2123 self.zeta
2124 .update(cx, |zeta, cx| {
2125 zeta.accept_edit_prediction(completion_id, cx)
2126 })
2127 .detach();
2128 }
2129 self.pending_completions.clear();
2130 }
2131
    /// Drops all pending requests and forgets the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }
2136
2137 fn suggest(
2138 &mut self,
2139 buffer: &Entity<Buffer>,
2140 cursor_position: language::Anchor,
2141 cx: &mut Context<Self>,
2142 ) -> Option<edit_prediction::EditPrediction> {
2143 let CurrentEditPrediction {
2144 buffer_id,
2145 completion,
2146 ..
2147 } = self.current_completion.as_mut()?;
2148
2149 // Invalidate previous completion if it was generated for a different buffer.
2150 if *buffer_id != buffer.entity_id() {
2151 self.current_completion.take();
2152 return None;
2153 }
2154
2155 let buffer = buffer.read(cx);
2156 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
2157 self.current_completion.take();
2158 return None;
2159 };
2160
2161 let cursor_row = cursor_position.to_point(buffer).row;
2162 let (closest_edit_ix, (closest_edit_range, _)) =
2163 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
2164 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
2165 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
2166 cmp::min(distance_from_start, distance_from_end)
2167 })?;
2168
2169 let mut edit_start_ix = closest_edit_ix;
2170 for (range, _) in edits[..edit_start_ix].iter().rev() {
2171 let distance_from_closest_edit =
2172 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
2173 if distance_from_closest_edit <= 1 {
2174 edit_start_ix -= 1;
2175 } else {
2176 break;
2177 }
2178 }
2179
2180 let mut edit_end_ix = closest_edit_ix + 1;
2181 for (range, _) in &edits[edit_end_ix..] {
2182 let distance_from_closest_edit =
2183 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
2184 if distance_from_closest_edit <= 1 {
2185 edit_end_ix += 1;
2186 } else {
2187 break;
2188 }
2189 }
2190
2191 Some(edit_prediction::EditPrediction {
2192 id: Some(completion.id.to_string().into()),
2193 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
2194 edit_preview: Some(completion.edit_preview.clone()),
2195 })
2196 }
2197}
2198
/// Estimates how many tokens a string of `bytes` bytes occupies, for the
/// purposes of limiting model input.
///
/// The estimate is the byte count divided (rounding down) by a typical
/// bytes-per-token guess. That guess is intentionally low to err on the side
/// of underestimating limits.
fn tokens_for_bytes(bytes: usize) -> usize {
    const ESTIMATED_BYTES_PER_TOKEN: usize = 3;
    bytes / ESTIMATED_BYTES_PER_TOKEN
}
2205
2206/* todo!
2207#[cfg(test)]
2208mod tests {
2209 use client::UserStore;
2210 use client::test::FakeServer;
2211 use clock::FakeSystemClock;
2212 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
2213 use gpui::TestAppContext;
2214 use http_client::FakeHttpClient;
2215 use indoc::indoc;
2216 use language::Point;
2217 use settings::SettingsStore;
2218
2219 use super::*;
2220
2221 #[gpui::test]
2222 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
2223 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
2224 let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
2225 to_completion_edits(
2226 [(2..5, "REM".to_string()), (9..11, "".to_string())],
2227 &buffer,
2228 cx,
2229 )
2230 .into()
2231 });
2232
2233 let edit_preview = cx
2234 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
2235 .await;
2236
2237 let completion = EditPrediction {
2238 edits,
2239 edit_preview,
2240 path: Path::new("").into(),
2241 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
2242 id: EditPredictionId(Uuid::new_v4()),
2243 excerpt_range: 0..0,
2244 cursor_offset: 0,
2245 input_events: "".into(),
2246 input_excerpt: "".into(),
2247 output_excerpt: "".into(),
2248 buffer_snapshotted_at: Instant::now(),
2249 response_received_at: Instant::now(),
2250 };
2251
2252 cx.update(|cx| {
2253 assert_eq!(
2254 from_completion_edits(
2255 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2256 &buffer,
2257 cx
2258 ),
2259 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2260 );
2261
2262 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2263 assert_eq!(
2264 from_completion_edits(
2265 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2266 &buffer,
2267 cx
2268 ),
2269 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
2270 );
2271
2272 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2273 assert_eq!(
2274 from_completion_edits(
2275 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2276 &buffer,
2277 cx
2278 ),
2279 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2280 );
2281
2282 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2283 assert_eq!(
2284 from_completion_edits(
2285 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2286 &buffer,
2287 cx
2288 ),
2289 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2290 );
2291
2292 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2293 assert_eq!(
2294 from_completion_edits(
2295 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2296 &buffer,
2297 cx
2298 ),
2299 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2300 );
2301
2302 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2303 assert_eq!(
2304 from_completion_edits(
2305 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2306 &buffer,
2307 cx
2308 ),
2309 vec![(9..11, "".to_string())]
2310 );
2311
2312 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2313 assert_eq!(
2314 from_completion_edits(
2315 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2316 &buffer,
2317 cx
2318 ),
2319 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2320 );
2321
2322 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2323 assert_eq!(
2324 from_completion_edits(
2325 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2326 &buffer,
2327 cx
2328 ),
2329 vec![(4..4, "M".to_string())]
2330 );
2331
2332 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2333 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2334 })
2335 }
2336
2337 #[gpui::test]
2338 async fn test_clean_up_diff(cx: &mut TestAppContext) {
2339 cx.update(|cx| {
2340 let settings_store = SettingsStore::test(cx);
2341 cx.set_global(settings_store);
2342 client::init_settings(cx);
2343 });
2344
2345 let edits = edits_for_prediction(
2346 indoc! {"
2347 fn main() {
2348 let word_1 = \"lorem\";
2349 let range = word.len()..word.len();
2350 }
2351 "},
2352 indoc! {"
2353 <|editable_region_start|>
2354 fn main() {
2355 let word_1 = \"lorem\";
2356 let range = word_1.len()..word_1.len();
2357 }
2358
2359 <|editable_region_end|>
2360 "},
2361 cx,
2362 )
2363 .await;
2364 assert_eq!(
2365 edits,
2366 [
2367 (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2368 (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2369 ]
2370 );
2371
2372 let edits = edits_for_prediction(
2373 indoc! {"
2374 fn main() {
2375 let story = \"the quick\"
2376 }
2377 "},
2378 indoc! {"
2379 <|editable_region_start|>
2380 fn main() {
2381 let story = \"the quick brown fox jumps over the lazy dog\";
2382 }
2383
2384 <|editable_region_end|>
2385 "},
2386 cx,
2387 )
2388 .await;
2389 assert_eq!(
2390 edits,
2391 [
2392 (
2393 Point::new(1, 26)..Point::new(1, 26),
2394 " brown fox jumps over the lazy dog".to_string()
2395 ),
2396 (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2397 ]
2398 );
2399 }
2400
2401 #[gpui::test]
2402 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2403 cx.update(|cx| {
2404 let settings_store = SettingsStore::test(cx);
2405 cx.set_global(settings_store);
2406 client::init_settings(cx);
2407 });
2408
2409 let buffer_content = "lorem\n";
2410 let completion_response = indoc! {"
2411 ```animals.js
2412 <|start_of_file|>
2413 <|editable_region_start|>
2414 lorem
2415 ipsum
2416 <|editable_region_end|>
2417 ```"};
2418
2419 let http_client = FakeHttpClient::create(move |req| async move {
2420 match (req.method(), req.uri().path()) {
2421 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2422 .status(200)
2423 .body(
2424 serde_json::to_string(&CreateLlmTokenResponse {
2425 token: LlmToken("the-llm-token".to_string()),
2426 })
2427 .unwrap()
2428 .into(),
2429 )
2430 .unwrap()),
2431 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2432 .status(200)
2433 .body(
2434 serde_json::to_string(&PredictEditsResponse {
2435 request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2436 .unwrap(),
2437 output_excerpt: completion_response.to_string(),
2438 })
2439 .unwrap()
2440 .into(),
2441 )
2442 .unwrap()),
2443 _ => Ok(http_client::Response::builder()
2444 .status(404)
2445 .body("Not Found".into())
2446 .unwrap()),
2447 }
2448 });
2449
2450 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2451 cx.update(|cx| {
2452 RefreshLlmTokenListener::register(client.clone(), cx);
2453 });
2454 // Construct the fake server to authenticate.
2455 let _server = FakeServer::for_client(42, &client, cx).await;
2456 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2457 let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2458
2459 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2460 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2461 let completion_task = zeta.update(cx, |zeta, cx| {
2462 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2463 });
2464
2465 let completion = completion_task.await.unwrap().unwrap();
2466 buffer.update(cx, |buffer, cx| {
2467 buffer.edit(completion.edits.iter().cloned(), None, cx)
2468 });
2469 assert_eq!(
2470 buffer.read_with(cx, |buffer, _| buffer.text()),
2471 "lorem\nipsum"
2472 );
2473 }
2474
2475 async fn edits_for_prediction(
2476 buffer_content: &str,
2477 completion_response: &str,
2478 cx: &mut TestAppContext,
2479 ) -> Vec<(Range<Point>, String)> {
2480 let completion_response = completion_response.to_string();
2481 let http_client = FakeHttpClient::create(move |req| {
2482 let completion = completion_response.clone();
2483 async move {
2484 match (req.method(), req.uri().path()) {
2485 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2486 .status(200)
2487 .body(
2488 serde_json::to_string(&CreateLlmTokenResponse {
2489 token: LlmToken("the-llm-token".to_string()),
2490 })
2491 .unwrap()
2492 .into(),
2493 )
2494 .unwrap()),
2495 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2496 .status(200)
2497 .body(
2498 serde_json::to_string(&PredictEditsResponse {
2499 request_id: Uuid::new_v4(),
2500 output_excerpt: completion,
2501 })
2502 .unwrap()
2503 .into(),
2504 )
2505 .unwrap()),
2506 _ => Ok(http_client::Response::builder()
2507 .status(404)
2508 .body("Not Found".into())
2509 .unwrap()),
2510 }
2511 }
2512 });
2513
2514 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2515 cx.update(|cx| {
2516 RefreshLlmTokenListener::register(client.clone(), cx);
2517 });
2518 // Construct the fake server to authenticate.
2519 let _server = FakeServer::for_client(42, &client, cx).await;
2520 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2521 let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2522
2523 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2524 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2525 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2526 let completion_task = zeta.update(cx, |zeta, cx| {
2527 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2528 });
2529
2530 let completion = completion_task.await.unwrap().unwrap();
2531 completion
2532 .edits
2533 .iter()
2534 .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2535 .collect::<Vec<_>>()
2536 }
2537
2538 fn to_completion_edits(
2539 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2540 buffer: &Entity<Buffer>,
2541 cx: &App,
2542 ) -> Vec<(Range<Anchor>, String)> {
2543 let buffer = buffer.read(cx);
2544 iterator
2545 .into_iter()
2546 .map(|(range, text)| {
2547 (
2548 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2549 text,
2550 )
2551 })
2552 .collect()
2553 }
2554
2555 fn from_completion_edits(
2556 editor_edits: &[(Range<Anchor>, String)],
2557 buffer: &Entity<Buffer>,
2558 cx: &App,
2559 ) -> Vec<(Range<usize>, String)> {
2560 let buffer = buffer.read(cx);
2561 editor_edits
2562 .iter()
2563 .map(|(range, text)| {
2564 (
2565 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2566 text.clone(),
2567 )
2568 })
2569 .collect()
2570 }
2571
2572 #[ctor::ctor]
2573 fn init_logger() {
2574 zlog::init_test();
2575 }
2576}
2577*/