1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13pub use init::*;
14use license_detection::LicenseDetectionWatcher;
15use project::git_store::Repository;
16pub use rate_completion_modal::*;
17
18use anyhow::{Context as _, Result, anyhow};
19use client::{Client, EditPredictionUsage, UserStore};
20use cloud_llm_client::{
21 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
22 PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
23 ZED_VERSION_HEADER_NAME,
24};
25use collections::{HashMap, HashSet, VecDeque};
26use futures::AsyncReadExt;
27use gpui::{
28 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
29 Subscription, Task, WeakEntity, actions,
30};
31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
32use input_excerpt::excerpt_for_cursor_position;
33use language::{
34 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
35};
36use language_model::{LlmApiToken, RefreshLlmTokenListener};
37use project::{Project, ProjectEntryId, ProjectPath};
38use release_channel::AppVersion;
39use settings::WorktreeId;
40use std::str::FromStr;
41use std::{
42 cmp,
43 fmt::Write,
44 future::Future,
45 mem,
46 ops::Range,
47 path::Path,
48 rc::Rc,
49 sync::Arc,
50 time::{Duration, Instant},
51};
52use telemetry_events::EditPredictionRating;
53use thiserror::Error;
54use util::ResultExt;
55use uuid::Uuid;
56use workspace::Workspace;
57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
58use worktree::Worktree;
59
/// Cursor-position marker; every occurrence is removed from the model output before it is
/// parsed (see `Zeta::parse_edits`).
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `Zeta::parse_edits` rejects model output containing more than one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in the model output; exactly one must be present.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region in the model output; exactly one must be present.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// A buffer change that continues the previous event (same buffer, same version) within this
/// interval is merged into that event by `Zeta::push_event` instead of being appended.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the user's data collection choice (`"true"` / `"false"`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Prompt-size budgets. `MAX_EVENT_TOKENS` is the budget passed to `prompt_for_events`; the others
// are not referenced in this part of the file and are presumably used by the excerpt/diagnostic
// gathering code — confirm.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track. When reached, the oldest half is dropped at once.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of edit predictions to store for feedback. When reached, the oldest shown
/// prediction is evicted.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
91
92#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
93pub struct EditPredictionId(Uuid);
94
95impl From<EditPredictionId> for gpui::ElementId {
96 fn from(value: EditPredictionId) -> Self {
97 gpui::ElementId::Uuid(value.0)
98 }
99}
100
101impl std::fmt::Display for EditPredictionId {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 write!(f, "{}", self.0)
104 }
105}
106
107struct ZedPredictUpsell;
108
109impl Dismissable for ZedPredictUpsell {
110 const KEY: &'static str = "dismissed-edit-predict-upsell";
111
112 fn dismissed() -> bool {
113 // To make this backwards compatible with older versions of Zed, we
114 // check if the user has seen the previous Edit Prediction Onboarding
115 // before, by checking the data collection choice which was written to
116 // the database once the user clicked on "Accept and Enable"
117 if KEY_VALUE_STORE
118 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
119 .log_err()
120 .is_some_and(|s| s.is_some())
121 {
122 return true;
123 }
124
125 KEY_VALUE_STORE
126 .read_kvp(Self::KEY)
127 .log_err()
128 .is_some_and(|s| s.is_some())
129 }
130}
131
132pub fn should_show_upsell_modal() -> bool {
133 !ZedPredictUpsell::dismissed()
134}
135
/// App-global handle to the shared [`Zeta`] entity, set by `Zeta::register` and read by
/// `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
140
/// A model prediction rebased onto a buffer snapshot. It also keeps the request inputs and the
/// raw model output, which are reported when the prediction is rated.
#[derive(Clone)]
pub struct EditPrediction {
    /// Request id returned by the server for this prediction.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region the model was asked to rewrite.
    excerpt_range: Range<usize>,
    /// Cursor offset in the buffer when the request was made.
    cursor_offset: usize,
    /// Predicted edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto.
    snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    edit_preview: EditPreview,
    // Prompt inputs and raw model output, sent with the rating telemetry event.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// When this prediction was assembled, i.e. after the response was parsed and interpolated.
    response_received_at: Instant,
}
157
158impl EditPrediction {
159 fn latency(&self) -> Duration {
160 self.response_received_at
161 .duration_since(self.buffer_snapshotted_at)
162 }
163
164 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
165 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
166 }
167}
168
/// Rebases model-predicted `current_edits` (anchored in `old_snapshot`) onto `new_snapshot`,
/// taking into account the edits the user made in between.
///
/// The user's edits since `old_snapshot.version` are walked in order:
/// - Model edits whose old range ends strictly before a user edit starts are kept unchanged.
/// - A user edit is tolerated only if it replaced exactly the same old range as the next model
///   edit and the text it inserted is a prefix of that model edit's text. In that case only the
///   remaining suffix is kept, as an insertion at the end of the replaced range.
/// - Any other user edit invalidates the prediction and `None` is returned.
///
/// Model edits after the last user edit are kept unchanged. Returns `None` when no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that ends strictly before this user edit begins.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is compatible only if it replaced the same range as the next model edit
        // and typed a prefix of the model's replacement text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Keep only the part the user has not typed yet, inserted after the range.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    // This model edit is now fully accounted for.
                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit conflicts with the prediction.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
214
215impl std::fmt::Debug for EditPrediction {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 f.debug_struct("EditPrediction")
218 .field("id", &self.id)
219 .field("path", &self.path)
220 .field("edits", &self.edits)
221 .finish_non_exhaustive()
222 }
223}
224
/// Edit prediction state shared across the app. It tracks buffer edit history, requests
/// predictions from the Zed LLM service, and keeps the predictions shown to the user for rating.
pub struct Zeta {
    /// Workspace used for notifications; invalid when `Zeta` was created without one.
    workspace: WeakEntity<Workspace>,
    client: Arc<Client>,
    /// Recent buffer-change events used as prompt context; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Last reported snapshot and subscriptions for each tracked buffer, keyed by entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions shown to the user, most recent first; capped at `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    /// The user's data collection choice, initially loaded from the key-value store.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM service token; refreshed in the background when `_llm_token_subscription` fires.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license detection watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently activated project entries with a timestamp (presumably when each became active;
    /// maintained by `push_recent_project_entry` — confirm).
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
241
242impl Zeta {
243 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
244 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
245 }
246
247 pub fn register(
248 workspace: Option<Entity<Workspace>>,
249 worktree: Option<Entity<Worktree>>,
250 client: Arc<Client>,
251 user_store: Entity<UserStore>,
252 cx: &mut App,
253 ) -> Entity<Self> {
254 let this = Self::global(cx).unwrap_or_else(|| {
255 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
256 cx.set_global(ZetaGlobal(entity.clone()));
257 entity
258 });
259
260 this.update(cx, move |this, cx| {
261 if let Some(worktree) = worktree {
262 let worktree_id = worktree.read(cx).id();
263 this.license_detection_watchers
264 .entry(worktree_id)
265 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
266 }
267 });
268
269 this
270 }
271
272 pub fn clear_history(&mut self) {
273 self.events.clear();
274 }
275
276 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
277 self.user_store.read(cx).edit_prediction_usage()
278 }
279
280 fn new(
281 workspace: Option<Entity<Workspace>>,
282 client: Arc<Client>,
283 user_store: Entity<UserStore>,
284 cx: &mut Context<Self>,
285 ) -> Self {
286 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
287
288 let data_collection_choice = Self::load_data_collection_choices();
289 let data_collection_choice = cx.new(|_| data_collection_choice);
290
291 if let Some(workspace) = &workspace {
292 cx.subscribe(
293 &workspace.read(cx).project().clone(),
294 |this, _workspace, event, _cx| match event {
295 project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
296 this.push_recent_project_entry(*project_entry_id)
297 }
298 _ => {}
299 },
300 )
301 .detach();
302 }
303
304 Self {
305 workspace: workspace.map_or_else(
306 || WeakEntity::new_invalid(),
307 |workspace| workspace.downgrade(),
308 ),
309 client,
310 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
311 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
312 rated_completions: HashSet::default(),
313 registered_buffers: HashMap::default(),
314 data_collection_choice,
315 llm_token: LlmApiToken::default(),
316 _llm_token_subscription: cx.subscribe(
317 &refresh_llm_token_listener,
318 |this, _listener, _event, cx| {
319 let client = this.client.clone();
320 let llm_token = this.llm_token.clone();
321 cx.spawn(async move |_this, _cx| {
322 llm_token.refresh(&client).await?;
323 anyhow::Ok(())
324 })
325 .detach_and_log_err(cx);
326 },
327 ),
328 update_required: false,
329 license_detection_watchers: HashMap::default(),
330 user_store,
331 recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
332 }
333 }
334
335 fn push_event(&mut self, event: Event) {
336 if let Some(Event::BufferChange {
337 new_snapshot: last_new_snapshot,
338 timestamp: last_timestamp,
339 ..
340 }) = self.events.back_mut()
341 {
342 // Coalesce edits for the same buffer when they happen one after the other.
343 let Event::BufferChange {
344 old_snapshot,
345 new_snapshot,
346 timestamp,
347 } = &event;
348
349 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
350 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
351 && old_snapshot.version == last_new_snapshot.version
352 {
353 *last_new_snapshot = new_snapshot.clone();
354 *last_timestamp = *timestamp;
355 return;
356 }
357 }
358
359 if self.events.len() >= MAX_EVENT_COUNT {
360 // These are halved instead of popping to improve prompt caching.
361 self.events.drain(..MAX_EVENT_COUNT / 2);
362 }
363
364 self.events.push_back(event);
365 }
366
367 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
368 let buffer_id = buffer.entity_id();
369 let weak_buffer = buffer.downgrade();
370
371 if let std::collections::hash_map::Entry::Vacant(entry) =
372 self.registered_buffers.entry(buffer_id)
373 {
374 let snapshot = buffer.read(cx).snapshot();
375
376 entry.insert(RegisteredBuffer {
377 snapshot,
378 _subscriptions: [
379 cx.subscribe(buffer, move |this, buffer, event, cx| {
380 this.handle_buffer_event(buffer, event, cx);
381 }),
382 cx.observe_release(buffer, move |this, _buffer, _cx| {
383 this.registered_buffers.remove(&weak_buffer.entity_id());
384 }),
385 ],
386 });
387 };
388 }
389
390 fn handle_buffer_event(
391 &mut self,
392 buffer: Entity<Buffer>,
393 event: &language::BufferEvent,
394 cx: &mut Context<Self>,
395 ) {
396 if let language::BufferEvent::Edited = event {
397 self.report_changes_for_buffer(&buffer, cx);
398 }
399 }
400
401 fn request_completion_impl<F, R>(
402 &mut self,
403 workspace: Option<Entity<Workspace>>,
404 project: Option<&Entity<Project>>,
405 buffer: &Entity<Buffer>,
406 cursor: language::Anchor,
407 can_collect_data: CanCollectData,
408 cx: &mut Context<Self>,
409 perform_predict_edits: F,
410 ) -> Task<Result<Option<EditPrediction>>>
411 where
412 F: FnOnce(PerformPredictEditsParams) -> R + 'static,
413 R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
414 + Send
415 + 'static,
416 {
417 let buffer = buffer.clone();
418 let buffer_snapshotted_at = Instant::now();
419 let snapshot = self.report_changes_for_buffer(&buffer, cx);
420 let zeta = cx.entity();
421 let events = self.events.clone();
422 let client = self.client.clone();
423 let llm_token = self.llm_token.clone();
424 let app_version = AppVersion::global(cx);
425
426 let git_info = if matches!(can_collect_data, CanCollectData(true)) {
427 self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx)
428 } else {
429 None
430 };
431
432 let full_path: Arc<Path> = snapshot
433 .file()
434 .map(|f| Arc::from(f.full_path(cx).as_path()))
435 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
436 let full_path_str = full_path.to_string_lossy().to_string();
437 let cursor_point = cursor.to_point(&snapshot);
438 let cursor_offset = cursor_point.to_offset(&snapshot);
439 let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
440 let gather_task = gather_context(
441 project,
442 full_path_str,
443 &snapshot,
444 cursor_point,
445 make_events_prompt,
446 can_collect_data,
447 git_info,
448 cx,
449 );
450
451 cx.spawn(async move |this, cx| {
452 let GatherContextOutput {
453 body,
454 editable_range,
455 } = gather_task.await?;
456 let done_gathering_context_at = Instant::now();
457
458 log::debug!(
459 "Events:\n{}\nExcerpt:\n{:?}",
460 body.input_events,
461 body.input_excerpt
462 );
463
464 let input_outline = body.outline.clone().unwrap_or_default();
465 let input_events = body.input_events.clone();
466 let input_excerpt = body.input_excerpt.clone();
467
468 let response = perform_predict_edits(PerformPredictEditsParams {
469 client,
470 llm_token,
471 app_version,
472 body,
473 })
474 .await;
475 let (response, usage) = match response {
476 Ok(response) => response,
477 Err(err) => {
478 if err.is::<ZedUpdateRequiredError>() {
479 cx.update(|cx| {
480 zeta.update(cx, |zeta, _cx| {
481 zeta.update_required = true;
482 });
483
484 if let Some(workspace) = workspace {
485 workspace.update(cx, |workspace, cx| {
486 workspace.show_notification(
487 NotificationId::unique::<ZedUpdateRequiredError>(),
488 cx,
489 |cx| {
490 cx.new(|cx| {
491 ErrorMessagePrompt::new(err.to_string(), cx)
492 .with_link_button(
493 "Update Zed",
494 "https://zed.dev/releases",
495 )
496 })
497 },
498 );
499 });
500 }
501 })
502 .ok();
503 }
504
505 return Err(err);
506 }
507 };
508
509 let received_response_at = Instant::now();
510 log::debug!("completion response: {}", &response.output_excerpt);
511
512 if let Some(usage) = usage {
513 this.update(cx, |this, cx| {
514 this.user_store.update(cx, |user_store, cx| {
515 user_store.update_edit_prediction_usage(usage, cx);
516 });
517 })
518 .ok();
519 }
520
521 let edit_prediction = Self::process_completion_response(
522 response,
523 buffer,
524 &snapshot,
525 editable_range,
526 cursor_offset,
527 full_path,
528 input_outline,
529 input_events,
530 input_excerpt,
531 buffer_snapshotted_at,
532 cx,
533 )
534 .await;
535
536 let finished_at = Instant::now();
537
538 // record latency for ~1% of requests
539 if rand::random::<u8>() <= 2 {
540 telemetry::event!(
541 "Edit Prediction Request",
542 context_latency = done_gathering_context_at
543 .duration_since(buffer_snapshotted_at)
544 .as_millis(),
545 request_latency = received_response_at
546 .duration_since(done_gathering_context_at)
547 .as_millis(),
548 process_latency = finished_at.duration_since(received_response_at).as_millis()
549 );
550 }
551
552 edit_prediction
553 })
554 }
555
    /// Generates several example completions of various states to fill the Zeta completion
    /// (rating) modal. Only compiled for tests and `test-support` builds.
    ///
    /// Each canned response goes through `fake_completion` against an in-memory buffer. After
    /// all tasks finish, the edits of the entries at indices 2 and 3 of `shown_completions` are
    /// cleared.
    ///
    /// NOTE(review): nothing here, and nothing in `request_completion_impl`, pushes the
    /// resulting predictions into `shown_completions`. The `get_mut(..).unwrap()` calls below
    /// therefore rely on `shown_completions` being populated elsewhere and panic otherwise —
    /// confirm. Responses 3 and 4 repeat the buffer text unchanged, so they presumably resolve
    /// to `None`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // The prediction results themselves are discarded here.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Clear the edits of two entries to show the "no edits" state in the modal.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
708
709 #[cfg(any(test, feature = "test-support"))]
710 pub fn fake_completion(
711 &mut self,
712 project: Option<&Entity<Project>>,
713 buffer: &Entity<Buffer>,
714 position: language::Anchor,
715 response: PredictEditsResponse,
716 cx: &mut Context<Self>,
717 ) -> Task<Result<Option<EditPrediction>>> {
718 use std::future::ready;
719
720 self.request_completion_impl(
721 None,
722 project,
723 buffer,
724 position,
725 CanCollectData(false),
726 cx,
727 |_params| ready(Ok((response, None))),
728 )
729 }
730
731 pub fn request_completion(
732 &mut self,
733 project: Option<&Entity<Project>>,
734 buffer: &Entity<Buffer>,
735 position: language::Anchor,
736 can_collect_data: CanCollectData,
737 cx: &mut Context<Self>,
738 ) -> Task<Result<Option<EditPrediction>>> {
739 self.request_completion_impl(
740 self.workspace.upgrade(),
741 project,
742 buffer,
743 position,
744 can_collect_data,
745 cx,
746 Self::perform_predict_edits,
747 )
748 }
749
    /// Sends a prediction request to the Zed LLM service and decodes the response.
    ///
    /// The request is POSTed to the URL in `ZED_PREDICT_EDITS_URL` when that environment
    /// variable is set, and otherwise to the `/predict_edits/v2` LLM endpoint. If a non-success
    /// response carries `EXPIRED_LLM_TOKEN_HEADER_NAME`, the token is refreshed and the request
    /// is retried once.
    ///
    /// Returns the parsed [`PredictEditsResponse`] together with the usage parsed from the
    /// response headers, or `None` when those headers cannot be parsed.
    ///
    /// # Errors
    /// - [`ZedUpdateRequiredError`] when the server's minimum required version is newer than
    ///   `app_version`. This is checked before the status code.
    /// - An error with the status and body for any other non-success response.
    /// - Token, URL, serialization, transport or JSON decoding failures.
    pub fn perform_predict_edits(
        params: PerformPredictEditsParams,
    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
        async move {
            let PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
                ..
            } = params;

            let http_client = client.http_client();
            let mut token = llm_token.acquire(&client).await?;
            // Guards against retrying more than once on an expired token.
            let mut did_retry = false;

            loop {
                let request_builder = http_client::Request::builder().method(Method::POST);
                // The environment override presumably exists for development/testing — confirm.
                let request_builder =
                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                        request_builder.uri(predict_edits_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/v2", &[])?
                                .as_ref(),
                        )
                    };
                let request = request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(serde_json::to_string(&body)?.into())?;

                let mut response = http_client.send(request).await?;

                if let Some(minimum_required_version) = response
                    .headers()
                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                {
                    anyhow::ensure!(
                        app_version >= minimum_required_version,
                        ZedUpdateRequiredError {
                            minimum_version: minimum_required_version
                        }
                    );
                }

                if response.status().is_success() {
                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Ok((serde_json::from_str(&body)?, usage));
                } else if !did_retry
                    && response
                        .headers()
                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                        .is_some()
                {
                    // Expired token: refresh it and go around the loop once more.
                    did_retry = true;
                    token = llm_token.refresh(&client).await?;
                } else {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "error predicting edits.\nStatus: {:?}\nBody: {}",
                        response.status(),
                        body
                    );
                }
            }
        }
    }
825
826 fn accept_edit_prediction(
827 &mut self,
828 request_id: EditPredictionId,
829 cx: &mut Context<Self>,
830 ) -> Task<Result<()>> {
831 let client = self.client.clone();
832 let llm_token = self.llm_token.clone();
833 let app_version = AppVersion::global(cx);
834 cx.spawn(async move |this, cx| {
835 let http_client = client.http_client();
836 let mut response = llm_token_retry(&llm_token, &client, |token| {
837 let request_builder = http_client::Request::builder().method(Method::POST);
838 let request_builder =
839 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
840 request_builder.uri(accept_prediction_url)
841 } else {
842 request_builder.uri(
843 http_client
844 .build_zed_llm_url("/predict_edits/accept", &[])?
845 .as_ref(),
846 )
847 };
848 Ok(request_builder
849 .header("Content-Type", "application/json")
850 .header("Authorization", format!("Bearer {}", token))
851 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
852 .body(
853 serde_json::to_string(&AcceptEditPredictionBody {
854 request_id: request_id.0,
855 })?
856 .into(),
857 )?)
858 })
859 .await?;
860
861 if let Some(minimum_required_version) = response
862 .headers()
863 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
864 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
865 && app_version < minimum_required_version
866 {
867 return Err(anyhow!(ZedUpdateRequiredError {
868 minimum_version: minimum_required_version
869 }));
870 }
871
872 if response.status().is_success() {
873 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
874 this.update(cx, |this, cx| {
875 this.user_store.update(cx, |user_store, cx| {
876 user_store.update_edit_prediction_usage(usage, cx);
877 });
878 })?;
879 }
880
881 Ok(())
882 } else {
883 let mut body = String::new();
884 response.body_mut().read_to_string(&mut body).await?;
885 Err(anyhow!(
886 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
887 response.status(),
888 body
889 ))
890 }
891 })
892 }
893
    /// Turns a raw prediction response into an [`EditPrediction`] for `buffer`.
    ///
    /// Steps:
    /// 1. Parse the model output into edits against `snapshot` (the snapshot the request was
    ///    built from), on the background executor.
    /// 2. Rebase the edits onto the buffer's current snapshot with [`interpolate`].
    /// 3. Compute an edit preview for the rebased edits.
    ///
    /// Resolves to `Ok(None)` when interpolation yields no edits, either because the user's
    /// changes conflicted with the prediction or because there were no edits to begin with.
    ///
    /// # Errors
    /// Propagates parse errors from `parse_edits`, and errors from `read_with` (presumably when
    /// the app or buffer is no longer available — confirm).
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse and diff the model output on the background executor.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Rebase onto the buffer's current contents, since it may have changed while the
            // request was in flight.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
955
956 fn parse_edits(
957 output_excerpt: Arc<str>,
958 editable_range: Range<usize>,
959 snapshot: &BufferSnapshot,
960 ) -> Result<Vec<(Range<Anchor>, String)>> {
961 let content = output_excerpt.replace(CURSOR_MARKER, "");
962
963 let start_markers = content
964 .match_indices(EDITABLE_REGION_START_MARKER)
965 .collect::<Vec<_>>();
966 anyhow::ensure!(
967 start_markers.len() == 1,
968 "expected exactly one start marker, found {}",
969 start_markers.len()
970 );
971
972 let end_markers = content
973 .match_indices(EDITABLE_REGION_END_MARKER)
974 .collect::<Vec<_>>();
975 anyhow::ensure!(
976 end_markers.len() == 1,
977 "expected exactly one end marker, found {}",
978 end_markers.len()
979 );
980
981 let sof_markers = content
982 .match_indices(START_OF_FILE_MARKER)
983 .collect::<Vec<_>>();
984 anyhow::ensure!(
985 sof_markers.len() <= 1,
986 "expected at most one start-of-file marker, found {}",
987 sof_markers.len()
988 );
989
990 let codefence_start = start_markers[0].0;
991 let content = &content[codefence_start..];
992
993 let newline_ix = content.find('\n').context("could not find newline")?;
994 let content = &content[newline_ix + 1..];
995
996 let codefence_end = content
997 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
998 .context("could not find end marker")?;
999 let new_text = &content[..codefence_end];
1000
1001 let old_text = snapshot
1002 .text_for_range(editable_range.clone())
1003 .collect::<String>();
1004
1005 Ok(Self::compute_edits(
1006 old_text,
1007 new_text,
1008 editable_range.start,
1009 snapshot,
1010 ))
1011 }
1012
    /// Diffs `old_text` against `new_text` and converts each hunk into an anchored edit in
    /// `snapshot`.
    ///
    /// - `offset` is the buffer offset at which `old_text` starts; every hunk range is shifted
    ///   by it.
    /// - Each hunk is trimmed of any prefix and suffix that the buffer text and the replacement
    ///   share, so the resulting edit covers only the changed part.
    /// - Empty ranges become a single `anchor_after` insertion point. Non-empty ranges use
    ///   `anchor_after` at the start and `anchor_before` at the end.
    pub fn compute_edits(
        old_text: String,
        new_text: &str,
        offset: usize,
        snapshot: &BufferSnapshot,
    ) -> Vec<(Range<Anchor>, String)> {
        text_diff(&old_text, new_text)
            .into_iter()
            .map(|(mut old_range, new_text)| {
                old_range.start += offset;
                old_range.end += offset;

                // NOTE(review): `common_prefix` is assumed to return a byte length, since it is
                // used both as a buffer offset and as a `str` index — confirm.
                let prefix_len = common_prefix(
                    snapshot.chars_for_range(old_range.clone()),
                    new_text.chars(),
                );
                old_range.start += prefix_len;

                // Compare from the end, excluding the already-trimmed prefix of `new_text`.
                let suffix_len = common_prefix(
                    snapshot.reversed_chars_for_range(old_range.clone()),
                    new_text[prefix_len..].chars().rev(),
                );
                old_range.end = old_range.end.saturating_sub(suffix_len);

                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
                let range = if old_range.is_empty() {
                    let anchor = snapshot.anchor_after(old_range.start);
                    anchor..anchor
                } else {
                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
                };
                (range, new_text)
            })
            .collect()
    }
1048
1049 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1050 self.rated_completions.contains(&completion_id)
1051 }
1052
1053 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1054 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1055 let completion = self.shown_completions.pop_back().unwrap();
1056 self.rated_completions.remove(&completion.id);
1057 }
1058 self.shown_completions.push_front(completion.clone());
1059 cx.notify();
1060 }
1061
    /// Records the user's `rating` and `feedback` for `completion` and reports them via telemetry.
    ///
    /// The event includes the prediction's prompt inputs and raw model output, and pending
    /// telemetry is flushed right away. The prediction's id is remembered as rated.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1082
1083 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1084 self.shown_completions.iter()
1085 }
1086
1087 pub fn shown_completions_len(&self) -> usize {
1088 self.shown_completions.len()
1089 }
1090
1091 fn report_changes_for_buffer(
1092 &mut self,
1093 buffer: &Entity<Buffer>,
1094 cx: &mut Context<Self>,
1095 ) -> BufferSnapshot {
1096 self.register_buffer(buffer, cx);
1097
1098 let registered_buffer = self
1099 .registered_buffers
1100 .get_mut(&buffer.entity_id())
1101 .unwrap();
1102 let new_snapshot = buffer.read(cx).snapshot();
1103
1104 if new_snapshot.version != registered_buffer.snapshot.version {
1105 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1106 self.push_event(Event::BufferChange {
1107 old_snapshot,
1108 new_snapshot: new_snapshot.clone(),
1109 timestamp: Instant::now(),
1110 });
1111 }
1112
1113 new_snapshot
1114 }
1115
1116 fn load_data_collection_choices() -> DataCollectionChoice {
1117 let choice = KEY_VALUE_STORE
1118 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1119 .log_err()
1120 .flatten();
1121
1122 match choice.as_deref() {
1123 Some("true") => DataCollectionChoice::Enabled,
1124 Some("false") => DataCollectionChoice::Disabled,
1125 Some(_) => {
1126 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1127 DataCollectionChoice::NotAnswered
1128 }
1129 None => DataCollectionChoice::NotAnswered,
1130 }
1131 }
1132
1133 fn gather_git_info(
1134 &mut self,
1135 project: Option<&Entity<Project>>,
1136 buffer_snapshotted_at: &Instant,
1137 snapshot: &BufferSnapshot,
1138 cx: &Context<Self>,
1139 ) -> Option<PredictEditsGitInfo> {
1140 let project = project?.read(cx);
1141 let file = snapshot.file()?;
1142 let project_path = ProjectPath::from_file(file.as_ref(), cx);
1143 let entry = project.entry_for_path(&project_path, cx)?;
1144 if !worktree_entry_eligible_for_collection(&entry) {
1145 return None;
1146 }
1147
1148 let git_store = project.git_store().read(cx);
1149 let (repository, _repo_path) =
1150 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1151
1152 let repository = repository.read(cx);
1153 let head_sha = repository
1154 .head_commit
1155 .as_ref()
1156 .map(|head_commit| head_commit.sha.to_string());
1157 let remote_origin_url = repository.remote_origin_url.clone();
1158 let remote_upstream_url = repository.remote_upstream_url.clone();
1159 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1160 return None;
1161 }
1162
1163 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1164
1165 Some(PredictEditsGitInfo {
1166 head_sha,
1167 remote_origin_url,
1168 remote_upstream_url,
1169 recent_files: Some(recent_files),
1170 })
1171 }
1172
1173 fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1174 let now = Instant::now();
1175 if let Some(existing_ix) = self
1176 .recent_project_entries
1177 .iter()
1178 .rposition(|(id, _)| *id == project_entry_id)
1179 {
1180 self.recent_project_entries.remove(existing_ix);
1181 }
1182 if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1183 self.recent_project_entries.pop_front();
1184 }
1185 self.recent_project_entries
1186 .push_back((project_entry_id, now));
1187 }
1188
1189 fn recent_files(
1190 &mut self,
1191 now: &Instant,
1192 repository: &Repository,
1193 cx: &Context<Self>,
1194 ) -> Vec<PredictEditsRecentFile> {
1195 let Ok(project) = self
1196 .workspace
1197 .read_with(cx, |workspace, _cx| workspace.project().clone())
1198 else {
1199 return Vec::new();
1200 };
1201 let mut results = Vec::new();
1202 for ix in (0..self.recent_project_entries.len()).rev() {
1203 let (entry_id, last_active_at) = &self.recent_project_entries[ix];
1204 if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
1205 && let worktree = worktree.read(cx)
1206 && let Some(entry) = worktree.entry_for_id(*entry_id)
1207 && worktree_entry_eligible_for_collection(entry)
1208 {
1209 let project_path = ProjectPath {
1210 worktree_id: worktree.id(),
1211 path: entry.path.clone(),
1212 };
1213 let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
1214 else {
1215 // entry not removed since queries involving other repositories might occur later
1216 continue;
1217 };
1218 let Some(repo_path_str) = repo_path.to_str() else {
1219 // paths may not be valid UTF-8
1220 self.recent_project_entries.remove(ix);
1221 continue;
1222 };
1223 if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
1224 self.recent_project_entries.remove(ix);
1225 continue;
1226 }
1227 let Ok(active_to_now_ms) =
1228 now.duration_since(*last_active_at).as_millis().try_into()
1229 else {
1230 self.recent_project_entries.remove(ix);
1231 continue;
1232 };
1233 results.push(PredictEditsRecentFile {
1234 repo_path: repo_path_str.to_string(),
1235 active_to_now_ms,
1236 });
1237 } else {
1238 self.recent_project_entries.remove(ix);
1239 }
1240 }
1241 results
1242 }
1243}
1244
1245fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1246 entry.is_file()
1247 && entry.is_created()
1248 && !entry.is_ignored
1249 && !entry.is_private
1250 && !entry.is_external
1251 && !entry.is_fifo
1252}
1253
/// Inputs needed to issue a predict-edits request.
pub struct PerformPredictEditsParams {
    pub client: Arc<Client>,
    pub llm_token: LlmApiToken,
    /// Version of the running Zed client.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1260
/// Error signalling that this Zed version is older than the minimum required for edit
/// predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest version that may continue using edit predictions.
    minimum_version: SemanticVersion,
}
1268
/// Returns the length in UTF-8 bytes of the longest common prefix of two `char` streams.
///
/// Pass reversed iterators (`.chars().rev()`) to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1275
/// Output of `gather_context`: the request body plus the editable region it describes.
pub struct GatherContextOutput {
    pub body: PredictEditsBody,
    /// Offset range of the excerpt's editable region within the buffer snapshot.
    pub editable_range: Range<usize>,
}
1280
1281pub fn gather_context(
1282 project: Option<&Entity<Project>>,
1283 full_path_str: String,
1284 snapshot: &BufferSnapshot,
1285 cursor_point: language::Point,
1286 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1287 can_collect_data: CanCollectData,
1288 git_info: Option<PredictEditsGitInfo>,
1289 cx: &App,
1290) -> Task<Result<GatherContextOutput>> {
1291 let local_lsp_store =
1292 project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1293 let diagnostic_groups: Vec<(String, serde_json::Value)> =
1294 if matches!(can_collect_data, CanCollectData(true))
1295 && let Some(local_lsp_store) = local_lsp_store
1296 {
1297 snapshot
1298 .diagnostic_groups(None)
1299 .into_iter()
1300 .filter_map(|(language_server_id, diagnostic_group)| {
1301 let language_server =
1302 local_lsp_store.running_language_server_for_id(language_server_id)?;
1303 let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1304 let language_server_name = language_server.name().to_string();
1305 let serialized = serde_json::to_value(diagnostic_group).unwrap();
1306 Some((language_server_name, serialized))
1307 })
1308 .collect::<Vec<_>>()
1309 } else {
1310 Vec::new()
1311 };
1312
1313 cx.background_spawn({
1314 let snapshot = snapshot.clone();
1315 async move {
1316 let diagnostic_groups = if diagnostic_groups.is_empty()
1317 || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1318 {
1319 None
1320 } else {
1321 Some(diagnostic_groups)
1322 };
1323
1324 let input_excerpt = excerpt_for_cursor_position(
1325 cursor_point,
1326 &full_path_str,
1327 &snapshot,
1328 MAX_REWRITE_TOKENS,
1329 MAX_CONTEXT_TOKENS,
1330 );
1331 let input_events = make_events_prompt();
1332 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1333
1334 let body = PredictEditsBody {
1335 input_events,
1336 input_excerpt: input_excerpt.prompt,
1337 can_collect_data: can_collect_data.0,
1338 diagnostic_groups,
1339 git_info,
1340 outline: None,
1341 speculated_output: None,
1342 };
1343
1344 Ok(GatherContextOutput {
1345 body,
1346 editable_range,
1347 })
1348 }
1349 })
1350}
1351
1352fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1353 let mut result = String::new();
1354 for event in events.iter().rev() {
1355 let event_string = event.to_prompt();
1356 let event_tokens = tokens_for_bytes(event_string.len());
1357 if event_tokens > remaining_tokens {
1358 break;
1359 }
1360
1361 if !result.is_empty() {
1362 result.insert_str(0, "\n\n");
1363 }
1364 result.insert_str(0, &event_string);
1365 remaining_tokens -= event_tokens;
1366 }
1367 result
1368}
1369
/// Tracking state for a buffer registered with Zeta.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; compared by version to detect new edits.
    snapshot: BufferSnapshot,
    /// Subscriptions created at registration, kept only so they are not dropped.
    _subscriptions: [gpui::Subscription; 2],
}
1374
/// A recorded user action that can be described in the prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// The buffer went from `old_snapshot` to `new_snapshot`; recorded at `timestamp`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
1383
1384impl Event {
1385 fn to_prompt(&self) -> String {
1386 match self {
1387 Event::BufferChange {
1388 old_snapshot,
1389 new_snapshot,
1390 ..
1391 } => {
1392 let mut prompt = String::new();
1393
1394 let old_path = old_snapshot
1395 .file()
1396 .map(|f| f.path().as_ref())
1397 .unwrap_or(Path::new("untitled"));
1398 let new_path = new_snapshot
1399 .file()
1400 .map(|f| f.path().as_ref())
1401 .unwrap_or(Path::new("untitled"));
1402 if old_path != new_path {
1403 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1404 }
1405
1406 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1407 if !diff.is_empty() {
1408 write!(
1409 prompt,
1410 "User edited {:?}:\n```diff\n{}\n```",
1411 new_path, diff
1412 )
1413 .unwrap();
1414 }
1415
1416 prompt
1417 }
1418 }
1419 }
1420}
1421
/// The prediction currently offered by the provider, tagged with the buffer it targets.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1427
1428impl CurrentEditPrediction {
1429 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1430 if self.buffer_id != old_completion.buffer_id {
1431 return true;
1432 }
1433
1434 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1435 return true;
1436 };
1437 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1438 return false;
1439 };
1440
1441 if old_edits.len() == 1 && new_edits.len() == 1 {
1442 let (old_range, old_text) = &old_edits[0];
1443 let (new_range, new_text) = &new_edits[0];
1444 new_range == old_range && new_text.starts_with(old_text)
1445 } else {
1446 true
1447 }
1448 }
1449}
1450
/// An in-flight prediction request issued by `refresh`.
struct PendingCompletion {
    /// Sequence number used to match a finished task to its pending entry.
    id: usize,
    /// The request task; held by the provider while the request is pending.
    _task: Task<()>,
}
1455
/// The user's answer to the edit-prediction data collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::NotAnswered | Self::Disabled => false,
            Self::Enabled => true,
        }
    }

    /// True once the user has made either choice.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Disabled | Self::Enabled => true,
        }
    }

    /// Returns the opposite choice. An unanswered prompt toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Disabled | Self::NotAnswered => Self::Enabled,
            Self::Enabled => Self::Disabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1495
/// Data-collection state for the buffer an edit prediction provider serves.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher for the buffer's worktree, used to check whether the project is open source.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1501
/// Whether data from a prediction request may be collected (a newtype over `bool`).
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1504
1505impl ProviderDataCollection {
1506 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1507 let choice_and_watcher = buffer.and_then(|buffer| {
1508 let file = buffer.read(cx).file()?;
1509
1510 if !file.is_local() || file.is_private() {
1511 return None;
1512 }
1513
1514 let zeta = zeta.read(cx);
1515 let choice = zeta.data_collection_choice.clone();
1516
1517 let license_detection_watcher = zeta
1518 .license_detection_watchers
1519 .get(&file.worktree_id(cx))
1520 .cloned()?;
1521
1522 Some((choice, license_detection_watcher))
1523 });
1524
1525 if let Some((choice, watcher)) = choice_and_watcher {
1526 ProviderDataCollection {
1527 choice: Some(choice),
1528 license_detection_watcher: Some(watcher),
1529 }
1530 } else {
1531 ProviderDataCollection {
1532 choice: None,
1533 license_detection_watcher: None,
1534 }
1535 }
1536 }
1537
1538 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1539 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1540 }
1541
1542 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1543 self.choice
1544 .as_ref()
1545 .is_some_and(|choice| choice.read(cx).is_enabled())
1546 }
1547
1548 fn is_project_open_source(&self) -> bool {
1549 self.license_detection_watcher
1550 .as_ref()
1551 .is_some_and(|watcher| watcher.is_project_open_source())
1552 }
1553
1554 pub fn toggle(&mut self, cx: &mut App) {
1555 if let Some(choice) = self.choice.as_mut() {
1556 let new_choice = choice.update(cx, |choice, _cx| {
1557 let new_choice = choice.toggle();
1558 *choice = new_choice;
1559 new_choice
1560 });
1561
1562 db::write_and_log(cx, move || {
1563 KEY_VALUE_STORE.write_kvp(
1564 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1565 new_choice.is_enabled().to_string(),
1566 )
1567 });
1568 }
1569 }
1570}
1571
1572async fn llm_token_retry(
1573 llm_token: &LlmApiToken,
1574 client: &Arc<Client>,
1575 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1576) -> Result<Response<AsyncBody>> {
1577 let mut did_retry = false;
1578 let http_client = client.http_client();
1579 let mut token = llm_token.acquire(client).await?;
1580 loop {
1581 let request = build_request(token.clone())?;
1582 let response = http_client.send(request).await?;
1583
1584 if !did_retry
1585 && !response.status().is_success()
1586 && response
1587 .headers()
1588 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1589 .is_some()
1590 {
1591 did_retry = true;
1592 token = llm_token.refresh(client).await?;
1593 continue;
1594 }
1595
1596 return Ok(response);
1597 }
1598}
1599
/// Edit prediction provider backed by Zeta.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests; at most two are kept (see `refresh`).
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id assigned to the next request, used to match finished tasks to pending entries.
    next_pending_completion_id: usize,
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider's buffer. Its inner `choice` is `None` when
    /// data collection is not possible for that buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
1609
1610impl ZetaEditPredictionProvider {
1611 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1612
1613 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1614 Self {
1615 zeta,
1616 pending_completions: ArrayVec::new(),
1617 next_pending_completion_id: 0,
1618 current_completion: None,
1619 provider_data_collection,
1620 last_request_timestamp: Instant::now(),
1621 }
1622 }
1623}
1624
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled, together with whether the project was
    /// detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Flips the persisted data-collection choice.
    ///
    /// A no-op when collection isn't possible for this buffer.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// True while at least one prediction request is in flight.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Requests a new prediction for `position` in `buffer`.
    ///
    /// No request is made in these cases:
    /// - Zeta reports that an update is required;
    /// - the account is too young or has overdue invoices;
    /// - the current prediction still interpolates onto the buffer.
    ///
    /// Requests are spaced at least `THROTTLE_TIMEOUT` apart, and at most two are kept in
    /// flight.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // The current prediction still applies to this snapshot; keep showing it.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out whatever remains of THROTTLE_TIMEOUT since the last request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(
                        project.as_ref(),
                        &buffer,
                        position,
                        can_collect_data,
                        cx,
                    )
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged; a failed or empty response leaves the current prediction as-is.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this was the oldest pending request, drop just its entry. Otherwise
                    // clear every pending entry, including the older one still outstanding.
                    // NOTE(review): `[0]` assumes this task's entry is still pending (tasks
                    // removed from the list are presumably cancelled on drop); confirm.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same pending-list cleanup as the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports the current prediction as accepted and drops all pending requests.
    ///
    /// The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the group of predicted edits nearest the cursor.
    ///
    /// The group is the edit whose start or end row is closest to the cursor row, extended
    /// with adjacent edits no more than one row away from that closest edit. The current
    /// prediction is discarded if it targets another buffer or no longer interpolates onto
    /// the buffer's snapshot.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards over preceding edits that end within one row of the closest edit.
        // The row subtraction relies on `edits` being ordered (earlier edits end no later
        // than the closest edit starts).
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards over following edits that start within one row of the closest edit.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1899
/// Estimates how many model tokens `bytes` bytes of text occupy (rounding down).
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed string bytes per token when budgeting model input. Deliberately low, so limits
    /// err on the side of underestimating how much text fits.
    const APPROX_BYTES_PER_TOKEN: usize = 3;
    let whole_tokens = bytes / APPROX_BYTES_PER_TOKEN;
    whole_tokens
}
1906
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    /// Checks that `EditPrediction::interpolate` re-maps predicted edits as the buffer changes.
    ///
    /// - Partially typed edits shrink to their remaining text.
    /// - Fully typed edits disappear.
    /// - Undo or backspace restores them.
    /// - A divergent edit invalidates the prediction (`None`).
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Typing the last predicted character completes the first edit entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Backspacing it brings the remaining insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Checks that model output is diffed down to minimal insertions rather than line rewrites.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// Checks that a prediction whose editable region ends at the end of the buffer applies
    /// cleanly there.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake backend: issues an LLM token and answers every prediction request with
        // `completion_response`.
        let http_client = FakeHttpClient::create(move |req| async move {
            match (req.method(), req.uri().path()) {
                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&CreateLlmTokenResponse {
                            token: LlmToken("the-llm-token".to_string()),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                                .unwrap(),
                            output_excerpt: completion_response.to_string(),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                _ => Ok(http_client::Response::builder()
                    .status(404)
                    .body("Not Found".into())
                    .unwrap()),
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Requests a prediction for `buffer_content` from a `Zeta` whose fake backend always
    /// answers with `completion_response`.
    ///
    /// The cursor is placed at row 1, column 0. Returns the predicted edits as point ranges in
    /// the original snapshot.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |req| {
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor ranges in `buffer`'s current state.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset ranges in `buffer`'s current state.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    // Runs before the tests to install the test logger.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}