1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13pub use init::*;
14use license_detection::LicenseDetectionWatcher;
15use project::git_store::Repository;
16pub use rate_completion_modal::*;
17
18use anyhow::{Context as _, Result, anyhow};
19use client::{Client, EditPredictionUsage, UserStore};
20use cloud_llm_client::{
21 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
22 PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
23 ZED_VERSION_HEADER_NAME,
24};
25use collections::{HashMap, HashSet, VecDeque};
26use futures::AsyncReadExt;
27use gpui::{
28 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
29 Subscription, Task, WeakEntity, actions,
30};
31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
32use input_excerpt::excerpt_for_cursor_position;
33use language::{
34 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
35};
36use language_model::{LlmApiToken, RefreshLlmTokenListener};
37use project::{Project, ProjectEntryId, ProjectPath};
38use release_channel::AppVersion;
39use settings::WorktreeId;
40use std::str::FromStr;
41use std::{
42 cmp,
43 fmt::Write,
44 future::Future,
45 mem,
46 ops::Range,
47 path::Path,
48 rc::Rc,
49 sync::Arc,
50 time::{Duration, Instant},
51};
52use telemetry_events::EditPredictionRating;
53use thiserror::Error;
54use util::ResultExt;
55use uuid::Uuid;
56use workspace::Workspace;
57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
58use worktree::Worktree;
59
/// Marker for the user's cursor position; `Zeta::parse_edits` strips it from model output.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; model output may contain at most one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region; model output must contain exactly one.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region; model output must contain exactly one.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer arriving within this interval are
/// coalesced into a single change event (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the persisted data-collection choice (`"true"` / `"false"`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Prompt budgets. `MAX_EVENT_TOKENS` bounds the events section of the prompt
// (passed to `prompt_for_events`). NOTE(review): the remaining budgets are
// presumably consumed by excerpt/context construction — confirm at their call sites.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track; once reached, the oldest half is dropped.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum time that must separate recent project entries for them to be kept.
const MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in the recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of shown edit predictions kept for feedback; the oldest is
/// evicted when a new one is shown.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
86
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
94
/// Identifier of an edit prediction: the request id carried by the prediction
/// response (see `Zeta::process_completion_response`).
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
97
98impl From<EditPredictionId> for gpui::ElementId {
99 fn from(value: EditPredictionId) -> Self {
100 gpui::ElementId::Uuid(value.0)
101 }
102}
103
104impl std::fmt::Display for EditPredictionId {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 write!(f, "{}", self.0)
107 }
108}
109
110struct ZedPredictUpsell;
111
112impl Dismissable for ZedPredictUpsell {
113 const KEY: &'static str = "dismissed-edit-predict-upsell";
114
115 fn dismissed() -> bool {
116 // To make this backwards compatible with older versions of Zed, we
117 // check if the user has seen the previous Edit Prediction Onboarding
118 // before, by checking the data collection choice which was written to
119 // the database once the user clicked on "Accept and Enable"
120 if KEY_VALUE_STORE
121 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
122 .log_err()
123 .is_some_and(|s| s.is_some())
124 {
125 return true;
126 }
127
128 KEY_VALUE_STORE
129 .read_kvp(Self::KEY)
130 .log_err()
131 .is_some_and(|s| s.is_some())
132 }
133}
134
135pub fn should_show_upsell_modal() -> bool {
136 !ZedPredictUpsell::dismissed()
137}
138
/// App-global handle to the shared [`Zeta`] entity (installed by `Zeta::register`,
/// read by `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
143
/// A model-produced edit prediction, together with the inputs that produced it
/// (retained so the prediction can be rated later).
#[derive(Clone)]
pub struct EditPrediction {
    /// Request id taken from the prediction response.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region, computed when the request was made.
    excerpt_range: Range<usize>,
    /// Cursor offset at the time of the request.
    cursor_offset: usize,
    /// Predicted edits, re-based onto `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` were re-based onto when the response was processed.
    snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    edit_preview: EditPreview,
    /// Outline section of the request (empty when none was sent).
    input_outline: Arc<str>,
    /// Events section of the request.
    input_events: Arc<str>,
    /// Excerpt section of the request.
    input_excerpt: Arc<str>,
    /// Raw model output.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// Set once the response has been fully processed (after parsing and preview).
    response_received_at: Instant,
}
160
161impl EditPrediction {
162 fn latency(&self) -> Duration {
163 self.response_received_at
164 .duration_since(self.buffer_snapshotted_at)
165 }
166
167 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
168 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
169 }
170}
171
/// Re-bases model edits computed against `old_snapshot` onto `new_snapshot`.
///
/// Walks the user's edits made since `old_snapshot` in order, alongside the
/// model's edits:
/// - model edits ending strictly before the current user edit are kept as-is;
/// - if the user edit replaced exactly the next model edit's range and typed a
///   prefix of the model's replacement text, that model edit is consumed and
///   only the remaining suffix (if any) is kept, as an insertion;
/// - any other user edit invalidates the whole prediction (`None`).
///
/// Model edits after the last user edit are appended unchanged. Returns `None`
/// when no edits remain.
///
/// NOTE(review): the peek-based merge assumes `current_edits` is sorted by
/// position — confirm that callers guarantee this.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is only compatible if it targets exactly the next model
        // edit's range and types a prefix of the model's text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        // Right-biased anchor at the end of the replaced range, so the
                        // remaining suffix should land after what the user typed.
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit conflicts with the prediction.
        return None;
    }

    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
217
218impl std::fmt::Debug for EditPrediction {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 f.debug_struct("EditPrediction")
221 .field("id", &self.id)
222 .field("path", &self.path)
223 .field("edits", &self.edits)
224 .finish_non_exhaustive()
225 }
226}
227
/// Client-side state for Zeta edit predictions: tracks buffer edits, issues
/// prediction requests, and keeps shown/rated predictions for feedback.
pub struct Zeta {
    /// Workspace used to surface notifications (e.g. "update required"); may be invalid.
    workspace: WeakEntity<Workspace>,
    client: Arc<Client>,
    /// Recent buffer-change events, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers watched for edits, keyed by entity id, with their last-reported snapshot.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions shown to the user, most recent first, capped at `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions that have been rated.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: Entity<DataCollectionChoice>,
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever a refresh is broadcast.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently active project entries with an associated timestamp
    /// (presumably when they became active — confirm in `push_recent_project_entry`).
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
244
245impl Zeta {
246 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
247 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
248 }
249
250 pub fn register(
251 workspace: Option<Entity<Workspace>>,
252 worktree: Option<Entity<Worktree>>,
253 client: Arc<Client>,
254 user_store: Entity<UserStore>,
255 cx: &mut App,
256 ) -> Entity<Self> {
257 let this = Self::global(cx).unwrap_or_else(|| {
258 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
259 cx.set_global(ZetaGlobal(entity.clone()));
260 entity
261 });
262
263 this.update(cx, move |this, cx| {
264 if let Some(worktree) = worktree {
265 let worktree_id = worktree.read(cx).id();
266 this.license_detection_watchers
267 .entry(worktree_id)
268 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
269 }
270 });
271
272 this
273 }
274
275 pub fn clear_history(&mut self) {
276 self.events.clear();
277 }
278
279 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
280 self.user_store.read(cx).edit_prediction_usage()
281 }
282
283 fn new(
284 workspace: Option<Entity<Workspace>>,
285 client: Arc<Client>,
286 user_store: Entity<UserStore>,
287 cx: &mut Context<Self>,
288 ) -> Self {
289 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
290
291 let data_collection_choice = Self::load_data_collection_choices();
292 let data_collection_choice = cx.new(|_| data_collection_choice);
293
294 if let Some(workspace) = &workspace {
295 cx.subscribe(
296 &workspace.read(cx).project().clone(),
297 |this, _workspace, event, _cx| match event {
298 project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
299 this.push_recent_project_entry(*project_entry_id)
300 }
301 _ => {}
302 },
303 )
304 .detach();
305 }
306
307 Self {
308 workspace: workspace.map_or_else(
309 || WeakEntity::new_invalid(),
310 |workspace| workspace.downgrade(),
311 ),
312 client,
313 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
314 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
315 rated_completions: HashSet::default(),
316 registered_buffers: HashMap::default(),
317 data_collection_choice,
318 llm_token: LlmApiToken::default(),
319 _llm_token_subscription: cx.subscribe(
320 &refresh_llm_token_listener,
321 |this, _listener, _event, cx| {
322 let client = this.client.clone();
323 let llm_token = this.llm_token.clone();
324 cx.spawn(async move |_this, _cx| {
325 llm_token.refresh(&client).await?;
326 anyhow::Ok(())
327 })
328 .detach_and_log_err(cx);
329 },
330 ),
331 update_required: false,
332 license_detection_watchers: HashMap::default(),
333 user_store,
334 recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
335 }
336 }
337
338 fn push_event(&mut self, event: Event) {
339 if let Some(Event::BufferChange {
340 new_snapshot: last_new_snapshot,
341 timestamp: last_timestamp,
342 ..
343 }) = self.events.back_mut()
344 {
345 // Coalesce edits for the same buffer when they happen one after the other.
346 let Event::BufferChange {
347 old_snapshot,
348 new_snapshot,
349 timestamp,
350 } = &event;
351
352 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
353 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
354 && old_snapshot.version == last_new_snapshot.version
355 {
356 *last_new_snapshot = new_snapshot.clone();
357 *last_timestamp = *timestamp;
358 return;
359 }
360 }
361
362 if self.events.len() >= MAX_EVENT_COUNT {
363 // These are halved instead of popping to improve prompt caching.
364 self.events.drain(..MAX_EVENT_COUNT / 2);
365 }
366
367 self.events.push_back(event);
368 }
369
370 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
371 let buffer_id = buffer.entity_id();
372 let weak_buffer = buffer.downgrade();
373
374 if let std::collections::hash_map::Entry::Vacant(entry) =
375 self.registered_buffers.entry(buffer_id)
376 {
377 let snapshot = buffer.read(cx).snapshot();
378
379 entry.insert(RegisteredBuffer {
380 snapshot,
381 _subscriptions: [
382 cx.subscribe(buffer, move |this, buffer, event, cx| {
383 this.handle_buffer_event(buffer, event, cx);
384 }),
385 cx.observe_release(buffer, move |this, _buffer, _cx| {
386 this.registered_buffers.remove(&weak_buffer.entity_id());
387 }),
388 ],
389 });
390 };
391 }
392
393 fn handle_buffer_event(
394 &mut self,
395 buffer: Entity<Buffer>,
396 event: &language::BufferEvent,
397 cx: &mut Context<Self>,
398 ) {
399 if let language::BufferEvent::Edited = event {
400 self.report_changes_for_buffer(&buffer, cx);
401 }
402 }
403
404 fn request_completion_impl<F, R>(
405 &mut self,
406 workspace: Option<Entity<Workspace>>,
407 project: Option<&Entity<Project>>,
408 buffer: &Entity<Buffer>,
409 cursor: language::Anchor,
410 can_collect_data: CanCollectData,
411 cx: &mut Context<Self>,
412 perform_predict_edits: F,
413 ) -> Task<Result<Option<EditPrediction>>>
414 where
415 F: FnOnce(PerformPredictEditsParams) -> R + 'static,
416 R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
417 + Send
418 + 'static,
419 {
420 let buffer = buffer.clone();
421 let buffer_snapshotted_at = Instant::now();
422 let snapshot = self.report_changes_for_buffer(&buffer, cx);
423 let zeta = cx.entity();
424 let events = self.events.clone();
425 let client = self.client.clone();
426 let llm_token = self.llm_token.clone();
427 let app_version = AppVersion::global(cx);
428
429 let cursor_point = cursor.to_point(&snapshot);
430 let cursor_offset = cursor_point.to_offset(&snapshot);
431 let git_info = if matches!(can_collect_data, CanCollectData(true)) {
432 self.gather_git_info(
433 cursor_point.clone(),
434 &buffer_snapshotted_at,
435 &snapshot,
436 project.clone(),
437 cx,
438 )
439 } else {
440 None
441 };
442
443 let full_path: Arc<Path> = snapshot
444 .file()
445 .map(|f| Arc::from(f.full_path(cx).as_path()))
446 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
447 let full_path_str = full_path.to_string_lossy().to_string();
448 let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
449 let gather_task = gather_context(
450 project,
451 full_path_str,
452 &snapshot,
453 cursor_point,
454 make_events_prompt,
455 can_collect_data,
456 git_info,
457 cx,
458 );
459
460 cx.spawn(async move |this, cx| {
461 let GatherContextOutput {
462 body,
463 editable_range,
464 } = gather_task.await?;
465 let done_gathering_context_at = Instant::now();
466
467 log::debug!(
468 "Events:\n{}\nExcerpt:\n{:?}",
469 body.input_events,
470 body.input_excerpt
471 );
472
473 let input_outline = body.outline.clone().unwrap_or_default();
474 let input_events = body.input_events.clone();
475 let input_excerpt = body.input_excerpt.clone();
476
477 let response = perform_predict_edits(PerformPredictEditsParams {
478 client,
479 llm_token,
480 app_version,
481 body,
482 })
483 .await;
484 let (response, usage) = match response {
485 Ok(response) => response,
486 Err(err) => {
487 if err.is::<ZedUpdateRequiredError>() {
488 cx.update(|cx| {
489 zeta.update(cx, |zeta, _cx| {
490 zeta.update_required = true;
491 });
492
493 if let Some(workspace) = workspace {
494 workspace.update(cx, |workspace, cx| {
495 workspace.show_notification(
496 NotificationId::unique::<ZedUpdateRequiredError>(),
497 cx,
498 |cx| {
499 cx.new(|cx| {
500 ErrorMessagePrompt::new(err.to_string(), cx)
501 .with_link_button(
502 "Update Zed",
503 "https://zed.dev/releases",
504 )
505 })
506 },
507 );
508 });
509 }
510 })
511 .ok();
512 }
513
514 return Err(err);
515 }
516 };
517
518 let received_response_at = Instant::now();
519 log::debug!("completion response: {}", &response.output_excerpt);
520
521 if let Some(usage) = usage {
522 this.update(cx, |this, cx| {
523 this.user_store.update(cx, |user_store, cx| {
524 user_store.update_edit_prediction_usage(usage, cx);
525 });
526 })
527 .ok();
528 }
529
530 let edit_prediction = Self::process_completion_response(
531 response,
532 buffer,
533 &snapshot,
534 editable_range,
535 cursor_offset,
536 full_path,
537 input_outline,
538 input_events,
539 input_excerpt,
540 buffer_snapshotted_at,
541 cx,
542 )
543 .await;
544
545 let finished_at = Instant::now();
546
547 // record latency for ~1% of requests
548 if rand::random::<u8>() <= 2 {
549 telemetry::event!(
550 "Edit Prediction Request",
551 context_latency = done_gathering_context_at
552 .duration_since(buffer_snapshotted_at)
553 .as_millis(),
554 request_latency = received_response_at
555 .duration_since(done_gathering_context_at)
556 .as_millis(),
557 process_latency = finished_at.duration_since(received_response_at).as_millis()
558 );
559 }
560
561 edit_prediction
562 })
563 }
564
565 // Generates several example completions of various states to fill the Zeta completion modal
566 #[cfg(any(test, feature = "test-support"))]
567 pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
568 use language::Point;
569
570 let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
571 And maybe a short line
572
573 Then a few lines
574
575 and then another
576 "#};
577
578 let project = None;
579 let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
580 let position = buffer.read(cx).anchor_before(Point::new(1, 0));
581
582 let completion_tasks = vec![
583 self.fake_completion(
584 project,
585 &buffer,
586 position,
587 PredictEditsResponse {
588 request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
589 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
590a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
591[here's an edit]
592And maybe a short line
593Then a few lines
594and then another
595{EDITABLE_REGION_END_MARKER}
596 ", ),
597 },
598 cx,
599 ),
600 self.fake_completion(
601 project,
602 &buffer,
603 position,
604 PredictEditsResponse {
605 request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
606 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
607a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
608And maybe a short line
609[and another edit]
610Then a few lines
611and then another
612{EDITABLE_REGION_END_MARKER}
613 "#),
614 },
615 cx,
616 ),
617 self.fake_completion(
618 project,
619 &buffer,
620 position,
621 PredictEditsResponse {
622 request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
623 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
624a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
625And maybe a short line
626
627Then a few lines
628
629and then another
630{EDITABLE_REGION_END_MARKER}
631 "#),
632 },
633 cx,
634 ),
635 self.fake_completion(
636 project,
637 &buffer,
638 position,
639 PredictEditsResponse {
640 request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
641 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
642a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
643And maybe a short line
644
645Then a few lines
646
647and then another
648{EDITABLE_REGION_END_MARKER}
649 "#),
650 },
651 cx,
652 ),
653 self.fake_completion(
654 project,
655 &buffer,
656 position,
657 PredictEditsResponse {
658 request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
659 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
660a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
661And maybe a short line
662Then a few lines
663[a third completion]
664and then another
665{EDITABLE_REGION_END_MARKER}
666 "#),
667 },
668 cx,
669 ),
670 self.fake_completion(
671 project,
672 &buffer,
673 position,
674 PredictEditsResponse {
675 request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
676 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
677a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
678And maybe a short line
679and then another
680[fourth completion example]
681{EDITABLE_REGION_END_MARKER}
682 "#),
683 },
684 cx,
685 ),
686 self.fake_completion(
687 project,
688 &buffer,
689 position,
690 PredictEditsResponse {
691 request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
692 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
693a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
694And maybe a short line
695Then a few lines
696and then another
697[fifth and final completion]
698{EDITABLE_REGION_END_MARKER}
699 "#),
700 },
701 cx,
702 ),
703 ];
704
705 cx.spawn(async move |zeta, cx| {
706 for task in completion_tasks {
707 task.await.unwrap();
708 }
709
710 zeta.update(cx, |zeta, _cx| {
711 zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
712 zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
713 })
714 .ok();
715 })
716 }
717
718 #[cfg(any(test, feature = "test-support"))]
719 pub fn fake_completion(
720 &mut self,
721 project: Option<&Entity<Project>>,
722 buffer: &Entity<Buffer>,
723 position: language::Anchor,
724 response: PredictEditsResponse,
725 cx: &mut Context<Self>,
726 ) -> Task<Result<Option<EditPrediction>>> {
727 use std::future::ready;
728
729 self.request_completion_impl(
730 None,
731 project,
732 buffer,
733 position,
734 CanCollectData(false),
735 cx,
736 |_params| ready(Ok((response, None))),
737 )
738 }
739
740 pub fn request_completion(
741 &mut self,
742 project: Option<&Entity<Project>>,
743 buffer: &Entity<Buffer>,
744 position: language::Anchor,
745 can_collect_data: CanCollectData,
746 cx: &mut Context<Self>,
747 ) -> Task<Result<Option<EditPrediction>>> {
748 self.request_completion_impl(
749 self.workspace.upgrade(),
750 project,
751 buffer,
752 position,
753 can_collect_data,
754 cx,
755 Self::perform_predict_edits,
756 )
757 }
758
759 pub fn perform_predict_edits(
760 params: PerformPredictEditsParams,
761 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
762 async move {
763 let PerformPredictEditsParams {
764 client,
765 llm_token,
766 app_version,
767 body,
768 ..
769 } = params;
770
771 let http_client = client.http_client();
772 let mut token = llm_token.acquire(&client).await?;
773 let mut did_retry = false;
774
775 loop {
776 let request_builder = http_client::Request::builder().method(Method::POST);
777 let request_builder =
778 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
779 request_builder.uri(predict_edits_url)
780 } else {
781 request_builder.uri(
782 http_client
783 .build_zed_llm_url("/predict_edits/v2", &[])?
784 .as_ref(),
785 )
786 };
787 let request = request_builder
788 .header("Content-Type", "application/json")
789 .header("Authorization", format!("Bearer {}", token))
790 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
791 .body(serde_json::to_string(&body)?.into())?;
792
793 let mut response = http_client.send(request).await?;
794
795 if let Some(minimum_required_version) = response
796 .headers()
797 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
798 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
799 {
800 anyhow::ensure!(
801 app_version >= minimum_required_version,
802 ZedUpdateRequiredError {
803 minimum_version: minimum_required_version
804 }
805 );
806 }
807
808 if response.status().is_success() {
809 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
810
811 let mut body = String::new();
812 response.body_mut().read_to_string(&mut body).await?;
813 return Ok((serde_json::from_str(&body)?, usage));
814 } else if !did_retry
815 && response
816 .headers()
817 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
818 .is_some()
819 {
820 did_retry = true;
821 token = llm_token.refresh(&client).await?;
822 } else {
823 let mut body = String::new();
824 response.body_mut().read_to_string(&mut body).await?;
825 anyhow::bail!(
826 "error predicting edits.\nStatus: {:?}\nBody: {}",
827 response.status(),
828 body
829 );
830 }
831 }
832 }
833 }
834
835 fn accept_edit_prediction(
836 &mut self,
837 request_id: EditPredictionId,
838 cx: &mut Context<Self>,
839 ) -> Task<Result<()>> {
840 let client = self.client.clone();
841 let llm_token = self.llm_token.clone();
842 let app_version = AppVersion::global(cx);
843 cx.spawn(async move |this, cx| {
844 let http_client = client.http_client();
845 let mut response = llm_token_retry(&llm_token, &client, |token| {
846 let request_builder = http_client::Request::builder().method(Method::POST);
847 let request_builder =
848 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
849 request_builder.uri(accept_prediction_url)
850 } else {
851 request_builder.uri(
852 http_client
853 .build_zed_llm_url("/predict_edits/accept", &[])?
854 .as_ref(),
855 )
856 };
857 Ok(request_builder
858 .header("Content-Type", "application/json")
859 .header("Authorization", format!("Bearer {}", token))
860 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
861 .body(
862 serde_json::to_string(&AcceptEditPredictionBody {
863 request_id: request_id.0,
864 })?
865 .into(),
866 )?)
867 })
868 .await?;
869
870 if let Some(minimum_required_version) = response
871 .headers()
872 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
873 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
874 && app_version < minimum_required_version
875 {
876 return Err(anyhow!(ZedUpdateRequiredError {
877 minimum_version: minimum_required_version
878 }));
879 }
880
881 if response.status().is_success() {
882 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
883 this.update(cx, |this, cx| {
884 this.user_store.update(cx, |user_store, cx| {
885 user_store.update_edit_prediction_usage(usage, cx);
886 });
887 })?;
888 }
889
890 Ok(())
891 } else {
892 let mut body = String::new();
893 response.body_mut().read_to_string(&mut body).await?;
894 Err(anyhow!(
895 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
896 response.status(),
897 body
898 ))
899 }
900 })
901 }
902
903 fn process_completion_response(
904 prediction_response: PredictEditsResponse,
905 buffer: Entity<Buffer>,
906 snapshot: &BufferSnapshot,
907 editable_range: Range<usize>,
908 cursor_offset: usize,
909 path: Arc<Path>,
910 input_outline: String,
911 input_events: String,
912 input_excerpt: String,
913 buffer_snapshotted_at: Instant,
914 cx: &AsyncApp,
915 ) -> Task<Result<Option<EditPrediction>>> {
916 let snapshot = snapshot.clone();
917 let request_id = prediction_response.request_id;
918 let output_excerpt = prediction_response.output_excerpt;
919 cx.spawn(async move |cx| {
920 let output_excerpt: Arc<str> = output_excerpt.into();
921
922 let edits: Arc<[(Range<Anchor>, String)]> = cx
923 .background_spawn({
924 let output_excerpt = output_excerpt.clone();
925 let editable_range = editable_range.clone();
926 let snapshot = snapshot.clone();
927 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
928 })
929 .await?
930 .into();
931
932 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
933 let edits = edits.clone();
934 |buffer, cx| {
935 let new_snapshot = buffer.snapshot();
936 let edits: Arc<[(Range<Anchor>, String)]> =
937 interpolate(&snapshot, &new_snapshot, edits)?.into();
938 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
939 }
940 })?
941 else {
942 return anyhow::Ok(None);
943 };
944
945 let edit_preview = edit_preview.await;
946
947 Ok(Some(EditPrediction {
948 id: EditPredictionId(request_id),
949 path,
950 excerpt_range: editable_range,
951 cursor_offset,
952 edits,
953 edit_preview,
954 snapshot,
955 input_outline: input_outline.into(),
956 input_events: input_events.into(),
957 input_excerpt: input_excerpt.into(),
958 output_excerpt,
959 buffer_snapshotted_at,
960 response_received_at: Instant::now(),
961 }))
962 })
963 }
964
965 fn parse_edits(
966 output_excerpt: Arc<str>,
967 editable_range: Range<usize>,
968 snapshot: &BufferSnapshot,
969 ) -> Result<Vec<(Range<Anchor>, String)>> {
970 let content = output_excerpt.replace(CURSOR_MARKER, "");
971
972 let start_markers = content
973 .match_indices(EDITABLE_REGION_START_MARKER)
974 .collect::<Vec<_>>();
975 anyhow::ensure!(
976 start_markers.len() == 1,
977 "expected exactly one start marker, found {}",
978 start_markers.len()
979 );
980
981 let end_markers = content
982 .match_indices(EDITABLE_REGION_END_MARKER)
983 .collect::<Vec<_>>();
984 anyhow::ensure!(
985 end_markers.len() == 1,
986 "expected exactly one end marker, found {}",
987 end_markers.len()
988 );
989
990 let sof_markers = content
991 .match_indices(START_OF_FILE_MARKER)
992 .collect::<Vec<_>>();
993 anyhow::ensure!(
994 sof_markers.len() <= 1,
995 "expected at most one start-of-file marker, found {}",
996 sof_markers.len()
997 );
998
999 let codefence_start = start_markers[0].0;
1000 let content = &content[codefence_start..];
1001
1002 let newline_ix = content.find('\n').context("could not find newline")?;
1003 let content = &content[newline_ix + 1..];
1004
1005 let codefence_end = content
1006 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1007 .context("could not find end marker")?;
1008 let new_text = &content[..codefence_end];
1009
1010 let old_text = snapshot
1011 .text_for_range(editable_range.clone())
1012 .collect::<String>();
1013
1014 Ok(Self::compute_edits(
1015 old_text,
1016 new_text,
1017 editable_range.start,
1018 snapshot,
1019 ))
1020 }
1021
1022 pub fn compute_edits(
1023 old_text: String,
1024 new_text: &str,
1025 offset: usize,
1026 snapshot: &BufferSnapshot,
1027 ) -> Vec<(Range<Anchor>, String)> {
1028 text_diff(&old_text, new_text)
1029 .into_iter()
1030 .map(|(mut old_range, new_text)| {
1031 old_range.start += offset;
1032 old_range.end += offset;
1033
1034 let prefix_len = common_prefix(
1035 snapshot.chars_for_range(old_range.clone()),
1036 new_text.chars(),
1037 );
1038 old_range.start += prefix_len;
1039
1040 let suffix_len = common_prefix(
1041 snapshot.reversed_chars_for_range(old_range.clone()),
1042 new_text[prefix_len..].chars().rev(),
1043 );
1044 old_range.end = old_range.end.saturating_sub(suffix_len);
1045
1046 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1047 let range = if old_range.is_empty() {
1048 let anchor = snapshot.anchor_after(old_range.start);
1049 anchor..anchor
1050 } else {
1051 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1052 };
1053 (range, new_text)
1054 })
1055 .collect()
1056 }
1057
1058 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1059 self.rated_completions.contains(&completion_id)
1060 }
1061
1062 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1063 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1064 let completion = self.shown_completions.pop_back().unwrap();
1065 self.rated_completions.remove(&completion.id);
1066 }
1067 self.shown_completions.push_front(completion.clone());
1068 cx.notify();
1069 }
1070
    /// Records a rating for a shown prediction and reports it via telemetry.
    ///
    /// Marks `completion` as rated, then emits an "Edit Prediction Rated" event.
    /// The event carries the `rating`, the prediction's inputs and output, and
    /// the user's `feedback`. Finally it asks the telemetry client to flush
    /// pending events.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush immediately rather than waiting for the next batch — presumably so
        // ratings aren't lost; TODO confirm intent.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1091
1092 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1093 self.shown_completions.iter()
1094 }
1095
1096 pub fn shown_completions_len(&self) -> usize {
1097 self.shown_completions.len()
1098 }
1099
1100 fn report_changes_for_buffer(
1101 &mut self,
1102 buffer: &Entity<Buffer>,
1103 cx: &mut Context<Self>,
1104 ) -> BufferSnapshot {
1105 self.register_buffer(buffer, cx);
1106
1107 let registered_buffer = self
1108 .registered_buffers
1109 .get_mut(&buffer.entity_id())
1110 .unwrap();
1111 let new_snapshot = buffer.read(cx).snapshot();
1112
1113 if new_snapshot.version != registered_buffer.snapshot.version {
1114 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1115 self.push_event(Event::BufferChange {
1116 old_snapshot,
1117 new_snapshot: new_snapshot.clone(),
1118 timestamp: Instant::now(),
1119 });
1120 }
1121
1122 new_snapshot
1123 }
1124
1125 fn load_data_collection_choices() -> DataCollectionChoice {
1126 let choice = KEY_VALUE_STORE
1127 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1128 .log_err()
1129 .flatten();
1130
1131 match choice.as_deref() {
1132 Some("true") => DataCollectionChoice::Enabled,
1133 Some("false") => DataCollectionChoice::Disabled,
1134 Some(_) => {
1135 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1136 DataCollectionChoice::NotAnswered
1137 }
1138 None => DataCollectionChoice::NotAnswered,
1139 }
1140 }
1141
1142 fn gather_git_info(
1143 &mut self,
1144 cursor_point: language::Point,
1145 buffer_snapshotted_at: &Instant,
1146 snapshot: &BufferSnapshot,
1147 project: Option<&Entity<Project>>,
1148 cx: &Context<Self>,
1149 ) -> Option<PredictEditsGitInfo> {
1150 let project = project?.read(cx);
1151 let file = snapshot.file()?;
1152 let project_path = ProjectPath::from_file(file.as_ref(), cx);
1153 let entry = project.entry_for_path(&project_path, cx)?;
1154 if !worktree_entry_eligible_for_collection(&entry) {
1155 return None;
1156 }
1157
1158 let git_store = project.git_store().read(cx);
1159 let (repository, repo_path) =
1160 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1161 let repo_path_str = repo_path.to_str()?;
1162
1163 let repository = repository.read(cx);
1164 let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1165 let remote_origin_url = repository.remote_origin_url.clone();
1166 let remote_upstream_url = repository.remote_upstream_url.clone();
1167 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1168
1169 Some(PredictEditsGitInfo {
1170 input_path: Some(repo_path_str.to_string()),
1171 cursor_point: Some(cloud_llm_client::Point {
1172 row: cursor_point.row,
1173 column: cursor_point.column,
1174 }),
1175 cursor_offset: Some(cursor_offset),
1176 head_sha: Some(head_sha),
1177 remote_origin_url,
1178 remote_upstream_url,
1179 recent_files: Some(recent_files),
1180 })
1181 }
1182
1183 fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1184 let now = Instant::now();
1185 if let Some(existing_ix) = self
1186 .recent_project_entries
1187 .iter()
1188 .rposition(|(id, _)| *id == project_entry_id)
1189 {
1190 self.recent_project_entries.remove(existing_ix);
1191 }
1192 // filter out rapid changes in active item, particularly since this can happen rapidly when
1193 // a workspace is loaded.
1194 if let Some(most_recent) = self.recent_project_entries.back_mut()
1195 && now.duration_since(most_recent.1) > MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES
1196 {
1197 most_recent.0 = project_entry_id;
1198 most_recent.1 = now;
1199 return;
1200 }
1201 if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1202 self.recent_project_entries.pop_front();
1203 }
1204 self.recent_project_entries
1205 .push_back((project_entry_id, now));
1206 }
1207
1208 fn recent_files(
1209 &mut self,
1210 now: &Instant,
1211 repository: &Repository,
1212 cx: &Context<Self>,
1213 ) -> Vec<PredictEditsRecentFile> {
1214 let Ok(project) = self
1215 .workspace
1216 .read_with(cx, |workspace, _cx| workspace.project().clone())
1217 else {
1218 return Vec::new();
1219 };
1220 let mut results = Vec::new();
1221 for ix in (0..self.recent_project_entries.len()).rev() {
1222 let (entry_id, last_active_at) = &self.recent_project_entries[ix];
1223 if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
1224 && let worktree = worktree.read(cx)
1225 && let Some(entry) = worktree.entry_for_id(*entry_id)
1226 && worktree_entry_eligible_for_collection(entry)
1227 {
1228 let project_path = ProjectPath {
1229 worktree_id: worktree.id(),
1230 path: entry.path.clone(),
1231 };
1232 let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
1233 else {
1234 // entry not removed since queries involving other repositories might occur later
1235 continue;
1236 };
1237 let Some(repo_path_str) = repo_path.to_str() else {
1238 // paths may not be valid UTF-8
1239 self.recent_project_entries.remove(ix);
1240 continue;
1241 };
1242 if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
1243 self.recent_project_entries.remove(ix);
1244 continue;
1245 }
1246 let Ok(active_to_now_ms) =
1247 now.duration_since(*last_active_at).as_millis().try_into()
1248 else {
1249 self.recent_project_entries.remove(ix);
1250 continue;
1251 };
1252 results.push(PredictEditsRecentFile {
1253 path: repo_path_str.to_string(),
1254 active_to_now_ms,
1255 });
1256 } else {
1257 self.recent_project_entries.remove(ix);
1258 }
1259 }
1260 results
1261 }
1262}
1263
1264fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1265 entry.is_file()
1266 && entry.is_created()
1267 && !entry.is_ignored
1268 && !entry.is_private
1269 && !entry.is_external
1270 && !entry.is_fifo
1271}
1272
/// Inputs needed to send a predict-edits request.
///
/// NOTE(review): the consumer of this struct is outside this chunk. The field
/// notes below are based on the field types and names; confirm them.
pub struct PerformPredictEditsParams {
    /// Client whose HTTP client is used for the request.
    pub client: Arc<Client>,
    /// LLM API token; presumably used to authorize the request.
    pub app_version: SemanticVersion,
    /// Presumably sent so the server can reject outdated builds
    /// (cf. `ZedUpdateRequiredError`). Confirm.
    pub llm_token: LlmApiToken,
    /// Request payload.
    pub body: PredictEditsBody,
}
1279
/// Error signalling that this Zed build is older than the minimum version
/// required to keep using edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The lowest Zed version that is still accepted.
    minimum_version: SemanticVersion,
}
1287
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two
/// character sequences.
///
/// Characters are compared pairwise until the first mismatch or until either
/// sequence ends. The byte lengths of the matching characters (counted from
/// `a`) are summed. Passing reversed iterators yields the common suffix length
/// instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1294
/// Result of `gather_context`: the request body plus the editable region it
/// describes.
pub struct GatherContextOutput {
    /// Body to send with the predict-edits request.
    pub body: PredictEditsBody,
    /// Byte-offset range, within the buffer snapshot, of the excerpt's
    /// editable region.
    pub editable_range: Range<usize>,
}
1299
1300pub fn gather_context(
1301 project: Option<&Entity<Project>>,
1302 full_path_str: String,
1303 snapshot: &BufferSnapshot,
1304 cursor_point: language::Point,
1305 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1306 can_collect_data: CanCollectData,
1307 git_info: Option<PredictEditsGitInfo>,
1308 cx: &App,
1309) -> Task<Result<GatherContextOutput>> {
1310 let local_lsp_store =
1311 project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1312 let diagnostic_groups: Vec<(String, serde_json::Value)> =
1313 if matches!(can_collect_data, CanCollectData(true))
1314 && let Some(local_lsp_store) = local_lsp_store
1315 {
1316 snapshot
1317 .diagnostic_groups(None)
1318 .into_iter()
1319 .filter_map(|(language_server_id, diagnostic_group)| {
1320 let language_server =
1321 local_lsp_store.running_language_server_for_id(language_server_id)?;
1322 let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1323 let language_server_name = language_server.name().to_string();
1324 let serialized = serde_json::to_value(diagnostic_group).unwrap();
1325 Some((language_server_name, serialized))
1326 })
1327 .collect::<Vec<_>>()
1328 } else {
1329 Vec::new()
1330 };
1331
1332 cx.background_spawn({
1333 let snapshot = snapshot.clone();
1334 async move {
1335 let diagnostic_groups = if diagnostic_groups.is_empty()
1336 || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1337 {
1338 None
1339 } else {
1340 Some(diagnostic_groups)
1341 };
1342
1343 let input_excerpt = excerpt_for_cursor_position(
1344 cursor_point,
1345 &full_path_str,
1346 &snapshot,
1347 MAX_REWRITE_TOKENS,
1348 MAX_CONTEXT_TOKENS,
1349 );
1350 let input_events = make_events_prompt();
1351 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1352
1353 let body = PredictEditsBody {
1354 input_events,
1355 input_excerpt: input_excerpt.prompt,
1356 can_collect_data: can_collect_data.0,
1357 diagnostic_groups,
1358 git_info,
1359 outline: None,
1360 speculated_output: None,
1361 };
1362
1363 Ok(GatherContextOutput {
1364 body,
1365 editable_range,
1366 })
1367 }
1368 })
1369}
1370
1371fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1372 let mut result = String::new();
1373 for event in events.iter().rev() {
1374 let event_string = event.to_prompt();
1375 let event_tokens = tokens_for_bytes(event_string.len());
1376 if event_tokens > remaining_tokens {
1377 break;
1378 }
1379
1380 if !result.is_empty() {
1381 result.insert_str(0, "\n\n");
1382 }
1383 result.insert_str(0, &event_string);
1384 remaining_tokens -= event_tokens;
1385 }
1386 result
1387}
1388
/// State kept for each buffer registered with Zeta.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change. `report_changes_for_buffer`
    /// compares its version against the live buffer to detect edits.
    snapshot: BufferSnapshot,
    /// Held only to keep these subscriptions alive; never read here.
    _subscriptions: [gpui::Subscription; 2],
}
1393
/// An entry in the recent-edit history that is rendered into prediction
/// prompts.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1402
1403impl Event {
1404 fn to_prompt(&self) -> String {
1405 match self {
1406 Event::BufferChange {
1407 old_snapshot,
1408 new_snapshot,
1409 ..
1410 } => {
1411 let mut prompt = String::new();
1412
1413 let old_path = old_snapshot
1414 .file()
1415 .map(|f| f.path().as_ref())
1416 .unwrap_or(Path::new("untitled"));
1417 let new_path = new_snapshot
1418 .file()
1419 .map(|f| f.path().as_ref())
1420 .unwrap_or(Path::new("untitled"));
1421 if old_path != new_path {
1422 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1423 }
1424
1425 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1426 if !diff.is_empty() {
1427 write!(
1428 prompt,
1429 "User edited {:?}:\n```diff\n{}\n```",
1430 new_path, diff
1431 )
1432 .unwrap();
1433 }
1434
1435 prompt
1436 }
1437 }
1438 }
1439}
1440
/// The prediction currently offered by the provider, together with the
/// buffer it was generated for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1446
1447impl CurrentEditPrediction {
1448 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1449 if self.buffer_id != old_completion.buffer_id {
1450 return true;
1451 }
1452
1453 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1454 return true;
1455 };
1456 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1457 return false;
1458 };
1459
1460 if old_edits.len() == 1 && new_edits.len() == 1 {
1461 let (old_range, old_text) = &old_edits[0];
1462 let (new_range, new_text) = &new_edits[0];
1463 new_range == old_range && new_text.starts_with(old_text)
1464 } else {
1465 true
1466 }
1467 }
1468}
1469
/// A prediction request that is still in flight.
struct PendingCompletion {
    /// Sequence number taken from `next_pending_completion_id`. It identifies
    /// which request finished.
    id: usize,
    /// Held to keep the request running. Presumably dropping it cancels the
    /// request (gpui `Task` semantics); confirm.
    _task: Task<()>,
}
1474
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True once the user has chosen either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// The choice after flipping it. An unanswered choice flips to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::NotAnswered | Self::Disabled => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1514
/// Data-collection state for the buffer an edit-prediction provider was
/// created for.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License watcher for the buffer's worktree. `ProviderDataCollection::new`
    /// sets it to `None` whenever it sets `choice` to `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1520
/// Whether data from the current buffer may be included with a prediction
/// request.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1523
1524impl ProviderDataCollection {
1525 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1526 let choice_and_watcher = buffer.and_then(|buffer| {
1527 let file = buffer.read(cx).file()?;
1528
1529 if !file.is_local() || file.is_private() {
1530 return None;
1531 }
1532
1533 let zeta = zeta.read(cx);
1534 let choice = zeta.data_collection_choice.clone();
1535
1536 let license_detection_watcher = zeta
1537 .license_detection_watchers
1538 .get(&file.worktree_id(cx))
1539 .cloned()?;
1540
1541 Some((choice, license_detection_watcher))
1542 });
1543
1544 if let Some((choice, watcher)) = choice_and_watcher {
1545 ProviderDataCollection {
1546 choice: Some(choice),
1547 license_detection_watcher: Some(watcher),
1548 }
1549 } else {
1550 ProviderDataCollection {
1551 choice: None,
1552 license_detection_watcher: None,
1553 }
1554 }
1555 }
1556
1557 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1558 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1559 }
1560
1561 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1562 self.choice
1563 .as_ref()
1564 .is_some_and(|choice| choice.read(cx).is_enabled())
1565 }
1566
1567 fn is_project_open_source(&self) -> bool {
1568 self.license_detection_watcher
1569 .as_ref()
1570 .is_some_and(|watcher| watcher.is_project_open_source())
1571 }
1572
1573 pub fn toggle(&mut self, cx: &mut App) {
1574 if let Some(choice) = self.choice.as_mut() {
1575 let new_choice = choice.update(cx, |choice, _cx| {
1576 let new_choice = choice.toggle();
1577 *choice = new_choice;
1578 new_choice
1579 });
1580
1581 db::write_and_log(cx, move || {
1582 KEY_VALUE_STORE.write_kvp(
1583 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1584 new_choice.is_enabled().to_string(),
1585 )
1586 });
1587 }
1588 }
1589}
1590
1591async fn llm_token_retry(
1592 llm_token: &LlmApiToken,
1593 client: &Arc<Client>,
1594 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1595) -> Result<Response<AsyncBody>> {
1596 let mut did_retry = false;
1597 let http_client = client.http_client();
1598 let mut token = llm_token.acquire(client).await?;
1599 loop {
1600 let request = build_request(token.clone())?;
1601 let response = http_client.send(request).await?;
1602
1603 if !did_retry
1604 && !response.status().is_success()
1605 && response
1606 .headers()
1607 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1608 .is_some()
1609 {
1610 did_retry = true;
1611 token = llm_token.refresh(client).await?;
1612 continue;
1613 }
1614
1615 return Ok(response);
1616 }
1617}
1618
/// Edit-prediction provider backed by [`Zeta`].
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight prediction requests, oldest first; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id to assign to the next pending request.
    next_pending_completion_id: usize,
    /// The prediction currently offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider's buffer. Collection is not
    /// possible when its `choice` is `None`.
    provider_data_collection: ProviderDataCollection,
    /// Set when a prediction request is started (and at construction); used to
    /// throttle requests.
    last_request_timestamp: Instant,
}
1628
1629impl ZetaEditPredictionProvider {
1630 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1631
1632 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1633 Self {
1634 zeta,
1635 pending_completions: ArrayVec::new(),
1636 next_pending_completion_id: 0,
1637 current_completion: None,
1638 provider_data_collection,
1639 last_request_timestamp: Instant::now(),
1640 }
1641 }
1642}
1643
/// Exposes Zeta predictions through the editor's `EditPredictionProvider`
/// interface.
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether the user opted into data collection, along with whether
    /// the project was detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// Always enabled; buffer and position are not consulted.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }

    /// True while at least one prediction request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Starts a prediction request for `position` in `buffer`.
    ///
    /// Nothing is requested in these cases:
    /// - Zeta reports that an update is required;
    /// - the user store reports the account as too young, or as having overdue
    ///   invoices;
    /// - the current prediction still interpolates onto the buffer's latest
    ///   snapshot.
    ///
    /// Each request first waits until `THROTTLE_TIMEOUT` has elapsed since
    /// `last_request_timestamp`, using the value it had when `refresh` was
    /// called. `_debounce` is ignored.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // The current prediction can still be shown as-is; no need to ask again.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: sleep out whatever remains of THROTTLE_TIMEOUT since the
            // previous request started.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(
                        project.as_ref(),
                        &buffer,
                        position,
                        can_collect_data,
                        cx,
                    )
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged. An error or an empty result just retires the
            // pending request.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this is the oldest pending request, remove only it;
                    // otherwise clear every pending request.
                    // NOTE(review): indexes [0] unconditionally, relying on the
                    // list being non-empty whenever this task completes.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same pending-list bookkeeping as the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Install the new prediction unless the existing one should be
                // kept; every installed prediction is reported as shown.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one (the oldest keeps running).
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current prediction (if any) to Zeta and drops
    /// all pending requests. The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the portion of the current prediction nearest the cursor.
    ///
    /// The selected edits are built as follows:
    /// 1. Start from the edit whose start or end row is closest to the
    ///    cursor's row.
    /// 2. Extend it with the contiguous neighbouring edits that lie within one
    ///    row of that edit.
    ///
    /// The current prediction is cleared, and `None` returned, when it was made
    /// for a different buffer or no longer interpolates onto this buffer.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards over earlier edits that end within one row of the
        // closest edit's start. Edits are assumed sorted and non-overlapping,
        // so this row subtraction is expected not to underflow.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards over later edits that start within one row of the
        // closest edit's end.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1918
/// Estimates how many model tokens `bytes` bytes of text occupy, rounding
/// down.
fn tokens_for_bytes(bytes: usize) -> usize {
    // Deliberately small bytes-per-token ratio. Over-estimating the token count
    // errs on the side of underestimating how much input fits in the limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1925
1926#[cfg(test)]
1927mod tests {
1928 use client::UserStore;
1929 use client::test::FakeServer;
1930 use clock::FakeSystemClock;
1931 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1932 use gpui::TestAppContext;
1933 use http_client::FakeHttpClient;
1934 use indoc::indoc;
1935 use language::Point;
1936 use settings::SettingsStore;
1937
1938 use super::*;
1939
1940 #[gpui::test]
1941 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1942 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1943 let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1944 to_completion_edits(
1945 [(2..5, "REM".to_string()), (9..11, "".to_string())],
1946 &buffer,
1947 cx,
1948 )
1949 .into()
1950 });
1951
1952 let edit_preview = cx
1953 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1954 .await;
1955
1956 let completion = EditPrediction {
1957 edits,
1958 edit_preview,
1959 path: Path::new("").into(),
1960 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1961 id: EditPredictionId(Uuid::new_v4()),
1962 excerpt_range: 0..0,
1963 cursor_offset: 0,
1964 input_outline: "".into(),
1965 input_events: "".into(),
1966 input_excerpt: "".into(),
1967 output_excerpt: "".into(),
1968 buffer_snapshotted_at: Instant::now(),
1969 response_received_at: Instant::now(),
1970 };
1971
1972 cx.update(|cx| {
1973 assert_eq!(
1974 from_completion_edits(
1975 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1976 &buffer,
1977 cx
1978 ),
1979 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1980 );
1981
1982 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1983 assert_eq!(
1984 from_completion_edits(
1985 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1986 &buffer,
1987 cx
1988 ),
1989 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1990 );
1991
1992 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1993 assert_eq!(
1994 from_completion_edits(
1995 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1996 &buffer,
1997 cx
1998 ),
1999 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2000 );
2001
2002 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2003 assert_eq!(
2004 from_completion_edits(
2005 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2006 &buffer,
2007 cx
2008 ),
2009 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2010 );
2011
2012 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2013 assert_eq!(
2014 from_completion_edits(
2015 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2016 &buffer,
2017 cx
2018 ),
2019 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2020 );
2021
2022 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2023 assert_eq!(
2024 from_completion_edits(
2025 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2026 &buffer,
2027 cx
2028 ),
2029 vec![(9..11, "".to_string())]
2030 );
2031
2032 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2033 assert_eq!(
2034 from_completion_edits(
2035 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2036 &buffer,
2037 cx
2038 ),
2039 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2040 );
2041
2042 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2043 assert_eq!(
2044 from_completion_edits(
2045 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2046 &buffer,
2047 cx
2048 ),
2049 vec![(4..4, "M".to_string())]
2050 );
2051
2052 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2053 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2054 })
2055 }
2056
2057 #[gpui::test]
2058 async fn test_clean_up_diff(cx: &mut TestAppContext) {
2059 cx.update(|cx| {
2060 let settings_store = SettingsStore::test(cx);
2061 cx.set_global(settings_store);
2062 client::init_settings(cx);
2063 });
2064
2065 let edits = edits_for_prediction(
2066 indoc! {"
2067 fn main() {
2068 let word_1 = \"lorem\";
2069 let range = word.len()..word.len();
2070 }
2071 "},
2072 indoc! {"
2073 <|editable_region_start|>
2074 fn main() {
2075 let word_1 = \"lorem\";
2076 let range = word_1.len()..word_1.len();
2077 }
2078
2079 <|editable_region_end|>
2080 "},
2081 cx,
2082 )
2083 .await;
2084 assert_eq!(
2085 edits,
2086 [
2087 (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2088 (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2089 ]
2090 );
2091
2092 let edits = edits_for_prediction(
2093 indoc! {"
2094 fn main() {
2095 let story = \"the quick\"
2096 }
2097 "},
2098 indoc! {"
2099 <|editable_region_start|>
2100 fn main() {
2101 let story = \"the quick brown fox jumps over the lazy dog\";
2102 }
2103
2104 <|editable_region_end|>
2105 "},
2106 cx,
2107 )
2108 .await;
2109 assert_eq!(
2110 edits,
2111 [
2112 (
2113 Point::new(1, 26)..Point::new(1, 26),
2114 " brown fox jumps over the lazy dog".to_string()
2115 ),
2116 (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2117 ]
2118 );
2119 }
2120
2121 #[gpui::test]
2122 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2123 cx.update(|cx| {
2124 let settings_store = SettingsStore::test(cx);
2125 cx.set_global(settings_store);
2126 client::init_settings(cx);
2127 });
2128
2129 let buffer_content = "lorem\n";
2130 let completion_response = indoc! {"
2131 ```animals.js
2132 <|start_of_file|>
2133 <|editable_region_start|>
2134 lorem
2135 ipsum
2136 <|editable_region_end|>
2137 ```"};
2138
2139 let http_client = FakeHttpClient::create(move |req| async move {
2140 match (req.method(), req.uri().path()) {
2141 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2142 .status(200)
2143 .body(
2144 serde_json::to_string(&CreateLlmTokenResponse {
2145 token: LlmToken("the-llm-token".to_string()),
2146 })
2147 .unwrap()
2148 .into(),
2149 )
2150 .unwrap()),
2151 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2152 .status(200)
2153 .body(
2154 serde_json::to_string(&PredictEditsResponse {
2155 request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2156 .unwrap(),
2157 output_excerpt: completion_response.to_string(),
2158 })
2159 .unwrap()
2160 .into(),
2161 )
2162 .unwrap()),
2163 _ => Ok(http_client::Response::builder()
2164 .status(404)
2165 .body("Not Found".into())
2166 .unwrap()),
2167 }
2168 });
2169
2170 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2171 cx.update(|cx| {
2172 RefreshLlmTokenListener::register(client.clone(), cx);
2173 });
2174 // Construct the fake server to authenticate.
2175 let _server = FakeServer::for_client(42, &client, cx).await;
2176 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2177 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2178
2179 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2180 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2181 let completion_task = zeta.update(cx, |zeta, cx| {
2182 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2183 });
2184
2185 let completion = completion_task.await.unwrap().unwrap();
2186 buffer.update(cx, |buffer, cx| {
2187 buffer.edit(completion.edits.iter().cloned(), None, cx)
2188 });
2189 assert_eq!(
2190 buffer.read_with(cx, |buffer, _| buffer.text()),
2191 "lorem\nipsum"
2192 );
2193 }
2194
2195 async fn edits_for_prediction(
2196 buffer_content: &str,
2197 completion_response: &str,
2198 cx: &mut TestAppContext,
2199 ) -> Vec<(Range<Point>, String)> {
2200 let completion_response = completion_response.to_string();
2201 let http_client = FakeHttpClient::create(move |req| {
2202 let completion = completion_response.clone();
2203 async move {
2204 match (req.method(), req.uri().path()) {
2205 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2206 .status(200)
2207 .body(
2208 serde_json::to_string(&CreateLlmTokenResponse {
2209 token: LlmToken("the-llm-token".to_string()),
2210 })
2211 .unwrap()
2212 .into(),
2213 )
2214 .unwrap()),
2215 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2216 .status(200)
2217 .body(
2218 serde_json::to_string(&PredictEditsResponse {
2219 request_id: Uuid::new_v4(),
2220 output_excerpt: completion,
2221 })
2222 .unwrap()
2223 .into(),
2224 )
2225 .unwrap()),
2226 _ => Ok(http_client::Response::builder()
2227 .status(404)
2228 .body("Not Found".into())
2229 .unwrap()),
2230 }
2231 }
2232 });
2233
2234 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2235 cx.update(|cx| {
2236 RefreshLlmTokenListener::register(client.clone(), cx);
2237 });
2238 // Construct the fake server to authenticate.
2239 let _server = FakeServer::for_client(42, &client, cx).await;
2240 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2241 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2242
2243 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2244 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2245 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2246 let completion_task = zeta.update(cx, |zeta, cx| {
2247 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2248 });
2249
2250 let completion = completion_task.await.unwrap().unwrap();
2251 completion
2252 .edits
2253 .iter()
2254 .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2255 .collect::<Vec<_>>()
2256 }
2257
2258 fn to_completion_edits(
2259 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2260 buffer: &Entity<Buffer>,
2261 cx: &App,
2262 ) -> Vec<(Range<Anchor>, String)> {
2263 let buffer = buffer.read(cx);
2264 iterator
2265 .into_iter()
2266 .map(|(range, text)| {
2267 (
2268 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2269 text,
2270 )
2271 })
2272 .collect()
2273 }
2274
2275 fn from_completion_edits(
2276 editor_edits: &[(Range<Anchor>, String)],
2277 buffer: &Entity<Buffer>,
2278 cx: &App,
2279 ) -> Vec<(Range<usize>, String)> {
2280 let buffer = buffer.read(cx);
2281 editor_edits
2282 .iter()
2283 .map(|(range, text)| {
2284 (
2285 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2286 text.clone(),
2287 )
2288 })
2289 .collect()
2290 }
2291
    /// Initializes logging for this test binary.
    ///
    /// Registered through `ctor`, so it runs once when the binary loads and
    /// before any test executes. It presumably installs a test-friendly zlog
    /// backend; confirm against `zlog::init_test`.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2296}