1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13pub use init::*;
14use license_detection::LicenseDetectionWatcher;
15use project::git_store::Repository;
16pub use rate_completion_modal::*;
17
18use anyhow::{Context as _, Result, anyhow};
19use client::{Client, EditPredictionUsage, UserStore};
20use cloud_llm_client::{
21 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
22 PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
23 ZED_VERSION_HEADER_NAME,
24};
25use collections::{HashMap, HashSet, VecDeque};
26use futures::AsyncReadExt;
27use gpui::{
28 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
29 Subscription, Task, WeakEntity, actions,
30};
31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
32use input_excerpt::excerpt_for_cursor_position;
33use language::{
34 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
35};
36use language_model::{LlmApiToken, RefreshLlmTokenListener};
37use project::{Project, ProjectEntryId, ProjectPath};
38use release_channel::AppVersion;
39use settings::WorktreeId;
40use std::str::FromStr;
41use std::{
42 cmp,
43 fmt::Write,
44 future::Future,
45 mem,
46 ops::Range,
47 path::Path,
48 rc::Rc,
49 sync::Arc,
50 time::{Duration, Instant},
51};
52use telemetry_events::EditPredictionRating;
53use thiserror::Error;
54use util::ResultExt;
55use uuid::Uuid;
56use workspace::Workspace;
57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
58use worktree::Worktree;
59
/// Cursor marker that may appear in model output; `Zeta::parse_edits` strips it before
/// extracting the editable region.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `Zeta::parse_edits` rejects output containing more than one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opening delimiter of the editable region; exactly one is required in model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closing delimiter of the editable region; exactly one is required in model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that arrive within this interval are merged into
/// a single event by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the persisted data-collection choice (`"true"` / `"false"`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Prompt token budgets. Only MAX_EVENT_TOKENS is referenced in this part of the file
// (passed to `prompt_for_events`); the others are presumably consumed by the excerpt and
// diagnostics helpers defined elsewhere — TODO confirm.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track. When reached, `Zeta::push_event` drops the oldest
/// half of the history.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of edit predictions to store for feedback. When reached,
/// `Zeta::completion_shown` evicts the oldest one.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
83
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
91
92#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
93pub struct EditPredictionId(Uuid);
94
95impl From<EditPredictionId> for gpui::ElementId {
96 fn from(value: EditPredictionId) -> Self {
97 gpui::ElementId::Uuid(value.0)
98 }
99}
100
101impl std::fmt::Display for EditPredictionId {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 write!(f, "{}", self.0)
104 }
105}
106
107struct ZedPredictUpsell;
108
109impl Dismissable for ZedPredictUpsell {
110 const KEY: &'static str = "dismissed-edit-predict-upsell";
111
112 fn dismissed() -> bool {
113 // To make this backwards compatible with older versions of Zed, we
114 // check if the user has seen the previous Edit Prediction Onboarding
115 // before, by checking the data collection choice which was written to
116 // the database once the user clicked on "Accept and Enable"
117 if KEY_VALUE_STORE
118 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
119 .log_err()
120 .is_some_and(|s| s.is_some())
121 {
122 return true;
123 }
124
125 KEY_VALUE_STORE
126 .read_kvp(Self::KEY)
127 .log_err()
128 .is_some_and(|s| s.is_some())
129 }
130}
131
132pub fn should_show_upsell_modal() -> bool {
133 !ZedPredictUpsell::dismissed()
134}
135
/// App-global holder of the single `Zeta` entity. `Zeta::register` installs it and
/// `Zeta::global` reads it back.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
140
/// A model prediction that has been parsed into buffer edits, rebased onto a buffer
/// snapshot, and bundled with the prompt inputs that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Request id returned by the prediction endpoint.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Byte range of the editable region the model was asked to rewrite.
    excerpt_range: Range<usize>,
    /// Byte offset of the cursor when the request was made.
    cursor_offset: usize,
    /// Predicted edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Snapshot the edits were interpolated onto while the response was processed. This is
    /// not necessarily the snapshot the request was built from.
    snapshot: BufferSnapshot,
    /// Preview of `edits` applied to the buffer.
    edit_preview: EditPreview,
    // Prompt inputs sent to the model; reported alongside ratings in `rate_completion`.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    /// Raw model output the edits were parsed from.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// Set by `process_completion_response` after parsing and preview computation finish.
    /// It is not the moment the HTTP response arrived.
    response_received_at: Instant,
}
157
158impl EditPrediction {
159 fn latency(&self) -> Duration {
160 self.response_received_at
161 .duration_since(self.buffer_snapshotted_at)
162 }
163
164 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
165 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
166 }
167}
168
/// Rebases `current_edits` (model edits anchored in `old_snapshot`) onto `new_snapshot`,
/// taking into account the edits the user made between the two snapshots.
///
/// User edits, in the order `edits_since` yields them, are walked alongside the model edits:
/// - model edits whose old range ends strictly before the next user edit starts are kept
///   unchanged;
/// - a user edit that replaces exactly the old range of the next model edit, and whose
///   inserted text is a prefix of the model's replacement, counts as the user typing the
///   prediction. Only the not-yet-typed suffix (if non-empty) is kept, as an insertion
///   anchored after the edited range;
/// - any other user edit invalidates the whole prediction and `None` is returned.
///
/// NOTE(review): "any other" includes user edits after every model edit has been consumed,
/// and a user edit that merely touches (starts exactly at the end of) a model edit.
///
/// Model edits remaining after the last user edit are kept as-is. Returns `None` if no
/// edits are left.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // Accept the user edit only if it is the start of typing the next model edit.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Keep only what the user has not typed yet.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit conflicts with the prediction (or no model edit is left to match).
        return None;
    }

    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
214
215impl std::fmt::Debug for EditPrediction {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 f.debug_struct("EditPrediction")
218 .field("id", &self.id)
219 .field("path", &self.path)
220 .field("edits", &self.edits)
221 .finish_non_exhaustive()
222 }
223}
224
/// Edit prediction engine. It tracks buffer edits, requests predictions from the server,
/// and keeps the predictions that were shown so they can be rated.
pub struct Zeta {
    /// Workspace used for notifications; an invalid weak handle when created without one.
    workspace: WeakEntity<Workspace>,
    /// Client used for HTTP requests, LLM token management, and telemetry.
    client: Arc<Client>,
    /// Recent buffer-change events, oldest first, bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by entity id. Each entry holds the last recorded
    /// snapshot and the buffer subscriptions.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions that were shown, most recent first, bounded by
    /// `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions that have been rated; pruned when a shown prediction is
    /// evicted.
    rated_completions: HashSet<EditPredictionId>,
    /// Data-collection choice, initially loaded from the key-value store.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM API token used to authorize prediction requests.
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license-detection watcher per worktree passed to `Zeta::register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently active project entries with timestamps; presumably maintained by
    /// `push_recent_project_entry` (defined elsewhere) — TODO confirm bounds.
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
241
242impl Zeta {
243 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
244 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
245 }
246
247 pub fn register(
248 workspace: Option<Entity<Workspace>>,
249 worktree: Option<Entity<Worktree>>,
250 client: Arc<Client>,
251 user_store: Entity<UserStore>,
252 cx: &mut App,
253 ) -> Entity<Self> {
254 let this = Self::global(cx).unwrap_or_else(|| {
255 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
256 cx.set_global(ZetaGlobal(entity.clone()));
257 entity
258 });
259
260 this.update(cx, move |this, cx| {
261 if let Some(worktree) = worktree {
262 let worktree_id = worktree.read(cx).id();
263 this.license_detection_watchers
264 .entry(worktree_id)
265 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
266 }
267 });
268
269 this
270 }
271
272 pub fn clear_history(&mut self) {
273 self.events.clear();
274 }
275
276 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
277 self.user_store.read(cx).edit_prediction_usage()
278 }
279
280 fn new(
281 workspace: Option<Entity<Workspace>>,
282 client: Arc<Client>,
283 user_store: Entity<UserStore>,
284 cx: &mut Context<Self>,
285 ) -> Self {
286 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
287
288 let data_collection_choice = Self::load_data_collection_choices();
289 let data_collection_choice = cx.new(|_| data_collection_choice);
290
291 if let Some(workspace) = &workspace {
292 cx.subscribe(
293 &workspace.read(cx).project().clone(),
294 |this, _workspace, event, _cx| match event {
295 project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
296 this.push_recent_project_entry(*project_entry_id)
297 }
298 _ => {}
299 },
300 )
301 .detach();
302 }
303
304 Self {
305 workspace: workspace.map_or_else(
306 || WeakEntity::new_invalid(),
307 |workspace| workspace.downgrade(),
308 ),
309 client,
310 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
311 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
312 rated_completions: HashSet::default(),
313 registered_buffers: HashMap::default(),
314 data_collection_choice,
315 llm_token: LlmApiToken::default(),
316 _llm_token_subscription: cx.subscribe(
317 &refresh_llm_token_listener,
318 |this, _listener, _event, cx| {
319 let client = this.client.clone();
320 let llm_token = this.llm_token.clone();
321 cx.spawn(async move |_this, _cx| {
322 llm_token.refresh(&client).await?;
323 anyhow::Ok(())
324 })
325 .detach_and_log_err(cx);
326 },
327 ),
328 update_required: false,
329 license_detection_watchers: HashMap::default(),
330 user_store,
331 recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
332 }
333 }
334
335 fn push_event(&mut self, event: Event) {
336 if let Some(Event::BufferChange {
337 new_snapshot: last_new_snapshot,
338 timestamp: last_timestamp,
339 ..
340 }) = self.events.back_mut()
341 {
342 // Coalesce edits for the same buffer when they happen one after the other.
343 let Event::BufferChange {
344 old_snapshot,
345 new_snapshot,
346 timestamp,
347 } = &event;
348
349 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
350 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
351 && old_snapshot.version == last_new_snapshot.version
352 {
353 *last_new_snapshot = new_snapshot.clone();
354 *last_timestamp = *timestamp;
355 return;
356 }
357 }
358
359 if self.events.len() >= MAX_EVENT_COUNT {
360 // These are halved instead of popping to improve prompt caching.
361 self.events.drain(..MAX_EVENT_COUNT / 2);
362 }
363
364 self.events.push_back(event);
365 }
366
367 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
368 let buffer_id = buffer.entity_id();
369 let weak_buffer = buffer.downgrade();
370
371 if let std::collections::hash_map::Entry::Vacant(entry) =
372 self.registered_buffers.entry(buffer_id)
373 {
374 let snapshot = buffer.read(cx).snapshot();
375
376 entry.insert(RegisteredBuffer {
377 snapshot,
378 _subscriptions: [
379 cx.subscribe(buffer, move |this, buffer, event, cx| {
380 this.handle_buffer_event(buffer, event, cx);
381 }),
382 cx.observe_release(buffer, move |this, _buffer, _cx| {
383 this.registered_buffers.remove(&weak_buffer.entity_id());
384 }),
385 ],
386 });
387 };
388 }
389
390 fn handle_buffer_event(
391 &mut self,
392 buffer: Entity<Buffer>,
393 event: &language::BufferEvent,
394 cx: &mut Context<Self>,
395 ) {
396 if let language::BufferEvent::Edited = event {
397 self.report_changes_for_buffer(&buffer, cx);
398 }
399 }
400
    /// Shared implementation of `request_completion` and `fake_completion`.
    ///
    /// Records pending changes for `buffer` and gathers prompt context around `cursor`. It
    /// then passes the request body to `perform_predict_edits`, which does the HTTP call in
    /// `request_completion` and returns a canned response in `fake_completion`. The
    /// response is turned into an `EditPrediction` via `process_completion_response`.
    ///
    /// - `workspace`: if present, used to show an "Update Zed" notification when the request
    ///   fails with `ZedUpdateRequiredError`.
    /// - `project`: forwarded to context gathering; also used for git info, but only when
    ///   `can_collect_data` is `CanCollectData(true)`.
    ///
    /// The task resolves to `Ok(None)` when the prediction yields no applicable edits, and
    /// to `Err` when context gathering, the request, or response processing fails.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency measurement reported in telemetry and `EditPrediction::latency`.
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        // Strong handle, used in the error path to flag that an update is required.
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Git metadata is only gathered when data collection is allowed.
        let git_info = if matches!(can_collect_data, CanCollectData(true)) {
            self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx)
        } else {
            None
        };

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the prompt inputs; `body` is moved into the request.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // When the server demands a newer Zed, remember that and tell the user.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Best effort: skipped if this entity has already been dropped.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests (3 out of 256 values of a random u8)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
555
    // Generates several example completions of various states to fill the Zeta completion modal
    //
    // Each fake completion goes through `fake_completion` (and so `request_completion_impl`)
    // against a local test buffer, using a canned model response. The third and fourth
    // canned outputs reproduce the buffer text unchanged, so they presumably produce no
    // edits and resolve to `Ok(None)` — TODO confirm.
    //
    // NOTE(review): `request_completion_impl` does not push results into
    // `shown_completions`, and the awaited results below are discarded. Unless
    // `shown_completions` already holds at least 4 entries, the `get_mut(2/3).unwrap()`
    // calls will panic — verify how this helper is expected to be used.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 And maybe a short line

 Then a few lines

 and then another
 "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Cursor at the start of the second line.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Run the fake requests one after another; results are not kept here.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Simulate shown completions that ended up with no edits.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
708
709 #[cfg(any(test, feature = "test-support"))]
710 pub fn fake_completion(
711 &mut self,
712 project: Option<&Entity<Project>>,
713 buffer: &Entity<Buffer>,
714 position: language::Anchor,
715 response: PredictEditsResponse,
716 cx: &mut Context<Self>,
717 ) -> Task<Result<Option<EditPrediction>>> {
718 use std::future::ready;
719
720 self.request_completion_impl(
721 None,
722 project,
723 buffer,
724 position,
725 CanCollectData(false),
726 cx,
727 |_params| ready(Ok((response, None))),
728 )
729 }
730
731 pub fn request_completion(
732 &mut self,
733 project: Option<&Entity<Project>>,
734 buffer: &Entity<Buffer>,
735 position: language::Anchor,
736 can_collect_data: CanCollectData,
737 cx: &mut Context<Self>,
738 ) -> Task<Result<Option<EditPrediction>>> {
739 self.request_completion_impl(
740 self.workspace.upgrade(),
741 project,
742 buffer,
743 position,
744 can_collect_data,
745 cx,
746 Self::perform_predict_edits,
747 )
748 }
749
750 pub fn perform_predict_edits(
751 params: PerformPredictEditsParams,
752 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
753 async move {
754 let PerformPredictEditsParams {
755 client,
756 llm_token,
757 app_version,
758 body,
759 ..
760 } = params;
761
762 let http_client = client.http_client();
763 let mut token = llm_token.acquire(&client).await?;
764 let mut did_retry = false;
765
766 loop {
767 let request_builder = http_client::Request::builder().method(Method::POST);
768 let request_builder =
769 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
770 request_builder.uri(predict_edits_url)
771 } else {
772 request_builder.uri(
773 http_client
774 .build_zed_llm_url("/predict_edits/v2", &[])?
775 .as_ref(),
776 )
777 };
778 let request = request_builder
779 .header("Content-Type", "application/json")
780 .header("Authorization", format!("Bearer {}", token))
781 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
782 .body(serde_json::to_string(&body)?.into())?;
783
784 let mut response = http_client.send(request).await?;
785
786 if let Some(minimum_required_version) = response
787 .headers()
788 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
789 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
790 {
791 anyhow::ensure!(
792 app_version >= minimum_required_version,
793 ZedUpdateRequiredError {
794 minimum_version: minimum_required_version
795 }
796 );
797 }
798
799 if response.status().is_success() {
800 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
801
802 let mut body = String::new();
803 response.body_mut().read_to_string(&mut body).await?;
804 return Ok((serde_json::from_str(&body)?, usage));
805 } else if !did_retry
806 && response
807 .headers()
808 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
809 .is_some()
810 {
811 did_retry = true;
812 token = llm_token.refresh(&client).await?;
813 } else {
814 let mut body = String::new();
815 response.body_mut().read_to_string(&mut body).await?;
816 anyhow::bail!(
817 "error predicting edits.\nStatus: {:?}\nBody: {}",
818 response.status(),
819 body
820 );
821 }
822 }
823 }
824 }
825
826 fn accept_edit_prediction(
827 &mut self,
828 request_id: EditPredictionId,
829 cx: &mut Context<Self>,
830 ) -> Task<Result<()>> {
831 let client = self.client.clone();
832 let llm_token = self.llm_token.clone();
833 let app_version = AppVersion::global(cx);
834 cx.spawn(async move |this, cx| {
835 let http_client = client.http_client();
836 let mut response = llm_token_retry(&llm_token, &client, |token| {
837 let request_builder = http_client::Request::builder().method(Method::POST);
838 let request_builder =
839 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
840 request_builder.uri(accept_prediction_url)
841 } else {
842 request_builder.uri(
843 http_client
844 .build_zed_llm_url("/predict_edits/accept", &[])?
845 .as_ref(),
846 )
847 };
848 Ok(request_builder
849 .header("Content-Type", "application/json")
850 .header("Authorization", format!("Bearer {}", token))
851 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
852 .body(
853 serde_json::to_string(&AcceptEditPredictionBody {
854 request_id: request_id.0,
855 })?
856 .into(),
857 )?)
858 })
859 .await?;
860
861 if let Some(minimum_required_version) = response
862 .headers()
863 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
864 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
865 && app_version < minimum_required_version
866 {
867 return Err(anyhow!(ZedUpdateRequiredError {
868 minimum_version: minimum_required_version
869 }));
870 }
871
872 if response.status().is_success() {
873 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
874 this.update(cx, |this, cx| {
875 this.user_store.update(cx, |user_store, cx| {
876 user_store.update_edit_prediction_usage(usage, cx);
877 });
878 })?;
879 }
880
881 Ok(())
882 } else {
883 let mut body = String::new();
884 response.body_mut().read_to_string(&mut body).await?;
885 Err(anyhow!(
886 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
887 response.status(),
888 body
889 ))
890 }
891 })
892 }
893
    /// Turns a raw prediction response into an `EditPrediction` for `buffer`.
    ///
    /// The model output is parsed, on a background task, into edits against `snapshot`, the
    /// snapshot the request was built from. Those edits are then interpolated onto the
    /// buffer's current snapshot to account for user edits made while the request was in
    /// flight, and an edit preview is computed.
    ///
    /// Resolves to `Ok(None)` when interpolation leaves no edits, either because the user's
    /// edits conflicted or because nothing remains to apply. Resolves to `Err` when parsing
    /// fails or `buffer.read_with` fails.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing and diffing run off the main thread.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Rebase onto the buffer's latest contents; bail out with `None` on conflict.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                // NOTE(review): taken after parsing and preview computation, not when the
                // HTTP response arrived, so `EditPrediction::latency` includes processing.
                response_received_at: Instant::now(),
            }))
        })
    }
955
956 fn parse_edits(
957 output_excerpt: Arc<str>,
958 editable_range: Range<usize>,
959 snapshot: &BufferSnapshot,
960 ) -> Result<Vec<(Range<Anchor>, String)>> {
961 let content = output_excerpt.replace(CURSOR_MARKER, "");
962
963 let start_markers = content
964 .match_indices(EDITABLE_REGION_START_MARKER)
965 .collect::<Vec<_>>();
966 anyhow::ensure!(
967 start_markers.len() == 1,
968 "expected exactly one start marker, found {}",
969 start_markers.len()
970 );
971
972 let end_markers = content
973 .match_indices(EDITABLE_REGION_END_MARKER)
974 .collect::<Vec<_>>();
975 anyhow::ensure!(
976 end_markers.len() == 1,
977 "expected exactly one end marker, found {}",
978 end_markers.len()
979 );
980
981 let sof_markers = content
982 .match_indices(START_OF_FILE_MARKER)
983 .collect::<Vec<_>>();
984 anyhow::ensure!(
985 sof_markers.len() <= 1,
986 "expected at most one start-of-file marker, found {}",
987 sof_markers.len()
988 );
989
990 let codefence_start = start_markers[0].0;
991 let content = &content[codefence_start..];
992
993 let newline_ix = content.find('\n').context("could not find newline")?;
994 let content = &content[newline_ix + 1..];
995
996 let codefence_end = content
997 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
998 .context("could not find end marker")?;
999 let new_text = &content[..codefence_end];
1000
1001 let old_text = snapshot
1002 .text_for_range(editable_range.clone())
1003 .collect::<String>();
1004
1005 Ok(Self::compute_edits(
1006 old_text,
1007 new_text,
1008 editable_range.start,
1009 snapshot,
1010 ))
1011 }
1012
1013 pub fn compute_edits(
1014 old_text: String,
1015 new_text: &str,
1016 offset: usize,
1017 snapshot: &BufferSnapshot,
1018 ) -> Vec<(Range<Anchor>, String)> {
1019 text_diff(&old_text, new_text)
1020 .into_iter()
1021 .map(|(mut old_range, new_text)| {
1022 old_range.start += offset;
1023 old_range.end += offset;
1024
1025 let prefix_len = common_prefix(
1026 snapshot.chars_for_range(old_range.clone()),
1027 new_text.chars(),
1028 );
1029 old_range.start += prefix_len;
1030
1031 let suffix_len = common_prefix(
1032 snapshot.reversed_chars_for_range(old_range.clone()),
1033 new_text[prefix_len..].chars().rev(),
1034 );
1035 old_range.end = old_range.end.saturating_sub(suffix_len);
1036
1037 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1038 let range = if old_range.is_empty() {
1039 let anchor = snapshot.anchor_after(old_range.start);
1040 anchor..anchor
1041 } else {
1042 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1043 };
1044 (range, new_text)
1045 })
1046 .collect()
1047 }
1048
1049 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1050 self.rated_completions.contains(&completion_id)
1051 }
1052
1053 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1054 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1055 let completion = self.shown_completions.pop_back().unwrap();
1056 self.rated_completions.remove(&completion.id);
1057 }
1058 self.shown_completions.push_front(completion.clone());
1059 cx.notify();
1060 }
1061
1062 pub fn rate_completion(
1063 &mut self,
1064 completion: &EditPrediction,
1065 rating: EditPredictionRating,
1066 feedback: String,
1067 cx: &mut Context<Self>,
1068 ) {
1069 self.rated_completions.insert(completion.id);
1070 telemetry::event!(
1071 "Edit Prediction Rated",
1072 rating,
1073 input_events = completion.input_events,
1074 input_excerpt = completion.input_excerpt,
1075 input_outline = completion.input_outline,
1076 output_excerpt = completion.output_excerpt,
1077 feedback
1078 );
1079 self.client.telemetry().flush_events().detach();
1080 cx.notify();
1081 }
1082
1083 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1084 self.shown_completions.iter()
1085 }
1086
1087 pub fn shown_completions_len(&self) -> usize {
1088 self.shown_completions.len()
1089 }
1090
1091 fn report_changes_for_buffer(
1092 &mut self,
1093 buffer: &Entity<Buffer>,
1094 cx: &mut Context<Self>,
1095 ) -> BufferSnapshot {
1096 self.register_buffer(buffer, cx);
1097
1098 let registered_buffer = self
1099 .registered_buffers
1100 .get_mut(&buffer.entity_id())
1101 .unwrap();
1102 let new_snapshot = buffer.read(cx).snapshot();
1103
1104 if new_snapshot.version != registered_buffer.snapshot.version {
1105 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1106 self.push_event(Event::BufferChange {
1107 old_snapshot,
1108 new_snapshot: new_snapshot.clone(),
1109 timestamp: Instant::now(),
1110 });
1111 }
1112
1113 new_snapshot
1114 }
1115
1116 fn load_data_collection_choices() -> DataCollectionChoice {
1117 let choice = KEY_VALUE_STORE
1118 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1119 .log_err()
1120 .flatten();
1121
1122 match choice.as_deref() {
1123 Some("true") => DataCollectionChoice::Enabled,
1124 Some("false") => DataCollectionChoice::Disabled,
1125 Some(_) => {
1126 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1127 DataCollectionChoice::NotAnswered
1128 }
1129 None => DataCollectionChoice::NotAnswered,
1130 }
1131 }
1132
1133 fn gather_git_info(
1134 &mut self,
1135 project: Option<&Entity<Project>>,
1136 buffer_snapshotted_at: &Instant,
1137 snapshot: &BufferSnapshot,
1138 cx: &Context<Self>,
1139 ) -> Option<PredictEditsGitInfo> {
1140 let project = project?.read(cx);
1141 let file = snapshot.file()?;
1142 let project_path = ProjectPath::from_file(file.as_ref(), cx);
1143 let entry = project.entry_for_path(&project_path, cx)?;
1144 if !worktree_entry_eligible_for_collection(&entry) {
1145 return None;
1146 }
1147
1148 let git_store = project.git_store().read(cx);
1149 let (repository, repo_path) =
1150 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1151 let repo_path_str = repo_path.to_str()?;
1152
1153 let repository = repository.read(cx);
1154 let head_sha = repository
1155 .head_commit
1156 .as_ref()
1157 .map(|head_commit| head_commit.sha.to_string());
1158 let remote_origin_url = repository.remote_origin_url.clone();
1159 let remote_upstream_url = repository.remote_upstream_url.clone();
1160 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1161 return None;
1162 }
1163
1164 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1165
1166 Some(PredictEditsGitInfo {
1167 input_path: Some(repo_path_str.to_string()),
1168 head_sha,
1169 remote_origin_url,
1170 remote_upstream_url,
1171 recent_files: Some(recent_files),
1172 })
1173 }
1174
1175 fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1176 let now = Instant::now();
1177 if let Some(existing_ix) = self
1178 .recent_project_entries
1179 .iter()
1180 .rposition(|(id, _)| *id == project_entry_id)
1181 {
1182 self.recent_project_entries.remove(existing_ix);
1183 }
1184 if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1185 self.recent_project_entries.pop_front();
1186 }
1187 self.recent_project_entries
1188 .push_back((project_entry_id, now));
1189 }
1190
    /// Lists recently active project files that belong to `repository`, most
    /// recently active first. Each file is paired with how many milliseconds before
    /// `now` it was last active.
    ///
    /// Returns an empty list if the workspace can't be read (e.g. it was dropped).
    ///
    /// While scanning, entries are pruned from `recent_project_entries` when any of
    /// these hold:
    /// - they can no longer be resolved, or are not eligible for collection;
    /// - their repo-relative path is not valid UTF-8;
    /// - their path exceeds `MAX_RECENT_FILE_PATH_LENGTH`;
    /// - their elapsed time does not convert into the `active_to_now_ms` field type.
    ///
    /// Entries that fall outside `repository` are skipped but kept.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &Context<Self>,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::new();
        // Walk from the back (newest, since entries are appended with `push_back`)
        // to the front. Iterating indices in reverse also keeps `remove(ix)` safe:
        // it only shifts entries that were already visited.
        for ix in (0..self.recent_project_entries.len()).rev() {
            let (entry_id, last_active_at) = &self.recent_project_entries[ix];
            if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
                && let worktree = worktree.read(cx)
                && let Some(entry) = worktree.entry_for_id(*entry_id)
                && worktree_entry_eligible_for_collection(entry)
            {
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: entry.path.clone(),
                };
                let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
                else {
                    // entry not removed since queries involving other repositories might occur later
                    continue;
                };
                let Some(repo_path_str) = repo_path.to_str() else {
                    // paths may not be valid UTF-8
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                    self.recent_project_entries.remove(ix);
                    continue;
                }
                // Elapsed time in ms (u128) must fit the field's integer type.
                let Ok(active_to_now_ms) =
                    now.duration_since(*last_active_at).as_millis().try_into()
                else {
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                results.push(PredictEditsRecentFile {
                    path: repo_path_str.to_string(),
                    active_to_now_ms,
                });
            } else {
                // The worktree or entry is gone, or the entry is not eligible.
                self.recent_project_entries.remove(ix);
            }
        }
        results
    }
1245}
1246
1247fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1248 entry.is_file()
1249 && entry.is_created()
1250 && !entry.is_ignored
1251 && !entry.is_private
1252 && !entry.is_external
1253 && !entry.is_fifo
1254}
1255
/// Everything needed to perform one predict-edits request.
///
/// NOTE(review): the consumer of this struct is not visible in this part of the
/// file; the field descriptions below are inferred from their types.
pub struct PerformPredictEditsParams {
    /// Client presumably used to send the request.
    pub client: Arc<Client>,
    /// LLM API token presumably used to authenticate the request.
    pub llm_token: LlmApiToken,
    /// Version of the running app (presumably reported to the server — confirm).
    pub app_version: SemanticVersion,
    /// The prediction request payload.
    pub body: PredictEditsBody,
}
1262
/// Error indicating that edit predictions require Zed `minimum_version` or newer.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version that may keep using edit predictions.
    minimum_version: SemanticVersion,
}
1270
/// Returns the UTF-8 byte length of the longest common prefix of two `char`
/// sequences, measured in the characters of `a`.
///
/// Passing reversed iterators (`s.chars().rev()`) measures a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1277
/// Result of [`gather_context`]: the request body plus the editable region.
pub struct GatherContextOutput {
    /// Request body built from the excerpt, events, diagnostics, and git info.
    pub body: PredictEditsBody,
    /// Editable region of the excerpt, as offsets into the buffer snapshot.
    pub editable_range: Range<usize>,
}
1282
/// Assembles the prediction request body for `cursor_point` in `snapshot`.
///
/// Diagnostics are gathered synchronously on the calling thread, and only when
/// `can_collect_data` is true and the project's LSP store is local. Diagnostic
/// groups whose language server is not running are skipped. The excerpt, the
/// events prompt (via `make_events_prompt`), and the body are then built on a
/// background task.
///
/// Diagnostics are omitted entirely, not truncated, when there are none or at
/// least `MAX_DIAGNOSTIC_GROUPS` of them. `outline` and `speculated_output` are
/// always `None`. The returned task currently always resolves to `Ok`.
pub fn gather_context(
    project: Option<&Entity<Project>>,
    full_path_str: String,
    snapshot: &BufferSnapshot,
    cursor_point: language::Point,
    make_events_prompt: impl FnOnce() -> String + Send + 'static,
    can_collect_data: CanCollectData,
    git_info: Option<PredictEditsGitInfo>,
    cx: &App,
) -> Task<Result<GatherContextOutput>> {
    let local_lsp_store =
        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
    // Pairs of (language server name, serialized diagnostic group).
    let diagnostic_groups: Vec<(String, serde_json::Value)> =
        if matches!(can_collect_data, CanCollectData(true))
            && let Some(local_lsp_store) = local_lsp_store
        {
            snapshot
                .diagnostic_groups(None)
                .into_iter()
                .filter_map(|(language_server_id, diagnostic_group)| {
                    let language_server =
                        local_lsp_store.running_language_server_for_id(language_server_id)?;
                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
                    let language_server_name = language_server.name().to_string();
                    // NOTE(review): panics if serialization fails; presumably infallible
                    // for resolved diagnostic groups — confirm.
                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
                    Some((language_server_name, serialized))
                })
                .collect::<Vec<_>>()
        } else {
            Vec::new()
        };

    cx.background_spawn({
        let snapshot = snapshot.clone();
        async move {
            // All-or-nothing: too many groups drops diagnostics rather than truncating.
            let diagnostic_groups = if diagnostic_groups.is_empty()
                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
            {
                None
            } else {
                Some(diagnostic_groups)
            };

            let input_excerpt = excerpt_for_cursor_position(
                cursor_point,
                &full_path_str,
                &snapshot,
                MAX_REWRITE_TOKENS,
                MAX_CONTEXT_TOKENS,
            );
            let input_events = make_events_prompt();
            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);

            let body = PredictEditsBody {
                input_events,
                input_excerpt: input_excerpt.prompt,
                can_collect_data: can_collect_data.0,
                diagnostic_groups,
                git_info,
                outline: None,
                speculated_output: None,
            };

            Ok(GatherContextOutput {
                body,
                editable_range,
            })
        }
    })
}
1353
1354fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1355 let mut result = String::new();
1356 for event in events.iter().rev() {
1357 let event_string = event.to_prompt();
1358 let event_tokens = tokens_for_bytes(event_string.len());
1359 if event_tokens > remaining_tokens {
1360 break;
1361 }
1362
1363 if !result.is_empty() {
1364 result.insert_str(0, "\n\n");
1365 }
1366 result.insert_str(0, &event_string);
1367 remaining_tokens -= event_tokens;
1368 }
1369 result
1370}
1371
/// Bookkeeping for a buffer registered for change tracking.
struct RegisteredBuffer {
    /// Last snapshot observed; compared by version to detect changes.
    snapshot: BufferSnapshot,
    /// Subscriptions held for as long as the buffer stays registered (presumably
    /// created when the buffer is registered — not shown here).
    _subscriptions: [gpui::Subscription; 2],
}
1376
/// A recorded user activity that can be rendered into the prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// A buffer's version changed between two observed snapshots.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was observed.
        timestamp: Instant,
    },
}
1385
1386impl Event {
1387 fn to_prompt(&self) -> String {
1388 match self {
1389 Event::BufferChange {
1390 old_snapshot,
1391 new_snapshot,
1392 ..
1393 } => {
1394 let mut prompt = String::new();
1395
1396 let old_path = old_snapshot
1397 .file()
1398 .map(|f| f.path().as_ref())
1399 .unwrap_or(Path::new("untitled"));
1400 let new_path = new_snapshot
1401 .file()
1402 .map(|f| f.path().as_ref())
1403 .unwrap_or(Path::new("untitled"));
1404 if old_path != new_path {
1405 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1406 }
1407
1408 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1409 if !diff.is_empty() {
1410 write!(
1411 prompt,
1412 "User edited {:?}:\n```diff\n{}\n```",
1413 new_path, diff
1414 )
1415 .unwrap();
1416 }
1417
1418 prompt
1419 }
1420 }
1421 }
1422}
1423
/// The prediction currently offered to the editor, tagged with the buffer it was
/// requested for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction applies to.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1429
1430impl CurrentEditPrediction {
1431 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1432 if self.buffer_id != old_completion.buffer_id {
1433 return true;
1434 }
1435
1436 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1437 return true;
1438 };
1439 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1440 return false;
1441 };
1442
1443 if old_edits.len() == 1 && new_edits.len() == 1 {
1444 let (old_range, old_text) = &old_edits[0];
1445 let (new_range, new_text) = &new_edits[0];
1446 new_range == old_range && new_text.starts_with(old_text)
1447 } else {
1448 true
1449 }
1450 }
1451}
1452
/// An in-flight prediction request.
struct PendingCompletion {
    /// Sequence number assigned when the request was started.
    id: usize,
    /// Task performing the request. It is held here so that removing the entry
    /// drops it (presumably cancelling the request).
    _task: Task<()>,
}
1457
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::NotAnswered | Self::Disabled => false,
            Self::Enabled => true,
        }
    }

    /// `true` once the user has made either choice.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// Flips the choice: `Enabled` becomes `Disabled`; anything else becomes `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(enabled: bool) -> Self {
        if enabled {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1497
/// Data-collection state for the buffer a provider was created for.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Reports whether the buffer's worktree was detected as open source. It is
    /// `None` exactly when `choice` is `None` (both are set together in `new`).
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1503
/// Whether data from a request may be collected: the user opted in and the project
/// was detected as open source.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1506
1507impl ProviderDataCollection {
1508 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1509 let choice_and_watcher = buffer.and_then(|buffer| {
1510 let file = buffer.read(cx).file()?;
1511
1512 if !file.is_local() || file.is_private() {
1513 return None;
1514 }
1515
1516 let zeta = zeta.read(cx);
1517 let choice = zeta.data_collection_choice.clone();
1518
1519 let license_detection_watcher = zeta
1520 .license_detection_watchers
1521 .get(&file.worktree_id(cx))
1522 .cloned()?;
1523
1524 Some((choice, license_detection_watcher))
1525 });
1526
1527 if let Some((choice, watcher)) = choice_and_watcher {
1528 ProviderDataCollection {
1529 choice: Some(choice),
1530 license_detection_watcher: Some(watcher),
1531 }
1532 } else {
1533 ProviderDataCollection {
1534 choice: None,
1535 license_detection_watcher: None,
1536 }
1537 }
1538 }
1539
1540 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1541 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1542 }
1543
1544 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1545 self.choice
1546 .as_ref()
1547 .is_some_and(|choice| choice.read(cx).is_enabled())
1548 }
1549
1550 fn is_project_open_source(&self) -> bool {
1551 self.license_detection_watcher
1552 .as_ref()
1553 .is_some_and(|watcher| watcher.is_project_open_source())
1554 }
1555
1556 pub fn toggle(&mut self, cx: &mut App) {
1557 if let Some(choice) = self.choice.as_mut() {
1558 let new_choice = choice.update(cx, |choice, _cx| {
1559 let new_choice = choice.toggle();
1560 *choice = new_choice;
1561 new_choice
1562 });
1563
1564 db::write_and_log(cx, move || {
1565 KEY_VALUE_STORE.write_kvp(
1566 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1567 new_choice.is_enabled().to_string(),
1568 )
1569 });
1570 }
1571 }
1572}
1573
1574async fn llm_token_retry(
1575 llm_token: &LlmApiToken,
1576 client: &Arc<Client>,
1577 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1578) -> Result<Response<AsyncBody>> {
1579 let mut did_retry = false;
1580 let http_client = client.http_client();
1581 let mut token = llm_token.acquire(client).await?;
1582 loop {
1583 let request = build_request(token.clone())?;
1584 let response = http_client.send(request).await?;
1585
1586 if !did_retry
1587 && !response.status().is_success()
1588 && response
1589 .headers()
1590 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1591 .is_some()
1592 {
1593 did_retry = true;
1594 token = llm_token.refresh(client).await?;
1595 continue;
1596 }
1597
1598 return Ok(response);
1599 }
1600}
1601
/// Edit-prediction provider backed by [`Zeta`].
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests, oldest first; at most two are tracked.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id to assign to the next request.
    next_pending_completion_id: usize,
    /// Prediction currently offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider's buffer. Its inner fields are
    /// `None` when data collection is not possible there.
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used to throttle requests.
    last_request_timestamp: Instant,
}
1611
1612impl ZetaEditPredictionProvider {
1613 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1614
1615 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1616 Self {
1617 zeta,
1618 pending_completions: ArrayVec::new(),
1619 next_pending_completion_id: 0,
1620 current_completion: None,
1621 provider_data_collection,
1622 last_request_timestamp: Instant::now(),
1623 }
1624 }
1625}
1626
1627impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1628 fn name() -> &'static str {
1629 "zed-predict"
1630 }
1631
1632 fn display_name() -> &'static str {
1633 "Zed's Edit Predictions"
1634 }
1635
1636 fn show_completions_in_menu() -> bool {
1637 true
1638 }
1639
1640 fn show_tab_accept_marker() -> bool {
1641 true
1642 }
1643
1644 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1645 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1646
1647 if self.provider_data_collection.is_data_collection_enabled(cx) {
1648 DataCollectionState::Enabled {
1649 is_project_open_source,
1650 }
1651 } else {
1652 DataCollectionState::Disabled {
1653 is_project_open_source,
1654 }
1655 }
1656 }
1657
1658 fn toggle_data_collection(&mut self, cx: &mut App) {
1659 self.provider_data_collection.toggle(cx);
1660 }
1661
1662 fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1663 self.zeta.read(cx).usage(cx)
1664 }
1665
1666 fn is_enabled(
1667 &self,
1668 _buffer: &Entity<Buffer>,
1669 _cursor_position: language::Anchor,
1670 _cx: &App,
1671 ) -> bool {
1672 true
1673 }
1674 fn is_refreshing(&self) -> bool {
1675 !self.pending_completions.is_empty()
1676 }
1677
1678 fn refresh(
1679 &mut self,
1680 project: Option<Entity<Project>>,
1681 buffer: Entity<Buffer>,
1682 position: language::Anchor,
1683 _debounce: bool,
1684 cx: &mut Context<Self>,
1685 ) {
1686 if self.zeta.read(cx).update_required {
1687 return;
1688 }
1689
1690 if self
1691 .zeta
1692 .read(cx)
1693 .user_store
1694 .read_with(cx, |user_store, _cx| {
1695 user_store.account_too_young() || user_store.has_overdue_invoices()
1696 })
1697 {
1698 return;
1699 }
1700
1701 if let Some(current_completion) = self.current_completion.as_ref() {
1702 let snapshot = buffer.read(cx).snapshot();
1703 if current_completion
1704 .completion
1705 .interpolate(&snapshot)
1706 .is_some()
1707 {
1708 return;
1709 }
1710 }
1711
1712 let pending_completion_id = self.next_pending_completion_id;
1713 self.next_pending_completion_id += 1;
1714 let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1715 let last_request_timestamp = self.last_request_timestamp;
1716
1717 let task = cx.spawn(async move |this, cx| {
1718 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1719 .checked_duration_since(Instant::now())
1720 {
1721 cx.background_executor().timer(timeout).await;
1722 }
1723
1724 let completion_request = this.update(cx, |this, cx| {
1725 this.last_request_timestamp = Instant::now();
1726 this.zeta.update(cx, |zeta, cx| {
1727 zeta.request_completion(
1728 project.as_ref(),
1729 &buffer,
1730 position,
1731 can_collect_data,
1732 cx,
1733 )
1734 })
1735 });
1736
1737 let completion = match completion_request {
1738 Ok(completion_request) => {
1739 let completion_request = completion_request.await;
1740 completion_request.map(|c| {
1741 c.map(|completion| CurrentEditPrediction {
1742 buffer_id: buffer.entity_id(),
1743 completion,
1744 })
1745 })
1746 }
1747 Err(error) => Err(error),
1748 };
1749 let Some(new_completion) = completion
1750 .context("edit prediction failed")
1751 .log_err()
1752 .flatten()
1753 else {
1754 this.update(cx, |this, cx| {
1755 if this.pending_completions[0].id == pending_completion_id {
1756 this.pending_completions.remove(0);
1757 } else {
1758 this.pending_completions.clear();
1759 }
1760
1761 cx.notify();
1762 })
1763 .ok();
1764 return;
1765 };
1766
1767 this.update(cx, |this, cx| {
1768 if this.pending_completions[0].id == pending_completion_id {
1769 this.pending_completions.remove(0);
1770 } else {
1771 this.pending_completions.clear();
1772 }
1773
1774 if let Some(old_completion) = this.current_completion.as_ref() {
1775 let snapshot = buffer.read(cx).snapshot();
1776 if new_completion.should_replace_completion(old_completion, &snapshot) {
1777 this.zeta.update(cx, |zeta, cx| {
1778 zeta.completion_shown(&new_completion.completion, cx);
1779 });
1780 this.current_completion = Some(new_completion);
1781 }
1782 } else {
1783 this.zeta.update(cx, |zeta, cx| {
1784 zeta.completion_shown(&new_completion.completion, cx);
1785 });
1786 this.current_completion = Some(new_completion);
1787 }
1788
1789 cx.notify();
1790 })
1791 .ok();
1792 });
1793
1794 // We always maintain at most two pending completions. When we already
1795 // have two, we replace the newest one.
1796 if self.pending_completions.len() <= 1 {
1797 self.pending_completions.push(PendingCompletion {
1798 id: pending_completion_id,
1799 _task: task,
1800 });
1801 } else if self.pending_completions.len() == 2 {
1802 self.pending_completions.pop();
1803 self.pending_completions.push(PendingCompletion {
1804 id: pending_completion_id,
1805 _task: task,
1806 });
1807 }
1808 }
1809
1810 fn cycle(
1811 &mut self,
1812 _buffer: Entity<Buffer>,
1813 _cursor_position: language::Anchor,
1814 _direction: edit_prediction::Direction,
1815 _cx: &mut Context<Self>,
1816 ) {
1817 // Right now we don't support cycling.
1818 }
1819
1820 fn accept(&mut self, cx: &mut Context<Self>) {
1821 let completion_id = self
1822 .current_completion
1823 .as_ref()
1824 .map(|completion| completion.completion.id);
1825 if let Some(completion_id) = completion_id {
1826 self.zeta
1827 .update(cx, |zeta, cx| {
1828 zeta.accept_edit_prediction(completion_id, cx)
1829 })
1830 .detach();
1831 }
1832 self.pending_completions.clear();
1833 }
1834
1835 fn discard(&mut self, _cx: &mut Context<Self>) {
1836 self.pending_completions.clear();
1837 self.current_completion.take();
1838 }
1839
1840 fn suggest(
1841 &mut self,
1842 buffer: &Entity<Buffer>,
1843 cursor_position: language::Anchor,
1844 cx: &mut Context<Self>,
1845 ) -> Option<edit_prediction::EditPrediction> {
1846 let CurrentEditPrediction {
1847 buffer_id,
1848 completion,
1849 ..
1850 } = self.current_completion.as_mut()?;
1851
1852 // Invalidate previous completion if it was generated for a different buffer.
1853 if *buffer_id != buffer.entity_id() {
1854 self.current_completion.take();
1855 return None;
1856 }
1857
1858 let buffer = buffer.read(cx);
1859 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1860 self.current_completion.take();
1861 return None;
1862 };
1863
1864 let cursor_row = cursor_position.to_point(buffer).row;
1865 let (closest_edit_ix, (closest_edit_range, _)) =
1866 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1867 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1868 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1869 cmp::min(distance_from_start, distance_from_end)
1870 })?;
1871
1872 let mut edit_start_ix = closest_edit_ix;
1873 for (range, _) in edits[..edit_start_ix].iter().rev() {
1874 let distance_from_closest_edit =
1875 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1876 if distance_from_closest_edit <= 1 {
1877 edit_start_ix -= 1;
1878 } else {
1879 break;
1880 }
1881 }
1882
1883 let mut edit_end_ix = closest_edit_ix + 1;
1884 for (range, _) in &edits[edit_end_ix..] {
1885 let distance_from_closest_edit =
1886 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1887 if distance_from_closest_edit <= 1 {
1888 edit_end_ix += 1;
1889 } else {
1890 break;
1891 }
1892 }
1893
1894 Some(edit_prediction::EditPrediction {
1895 id: Some(completion.id.to_string().into()),
1896 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1897 edit_preview: Some(completion.edit_preview.clone()),
1898 })
1899 }
1900}
1901
/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1908
// Unit tests: prediction interpolation against later buffer edits, minimal-diff
// extraction from model output, and end-to-end requests against a fake HTTP server.
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    /// A prediction's edits must follow the buffer as the user edits:
    /// - typing part of a suggestion shrinks it;
    /// - undo restores it;
    /// - an edit that disagrees with what remains makes `interpolate` return `None`.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        // Prediction: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Unchanged buffer: edits come back as predicted.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting "rem" turns the replacement into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing the suggestion character by character consumes it.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Once "REM" is fully typed, only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Removing the typed "M" brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion by hand leaves only the insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that disagrees with the remaining suggestion invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// The model's rewritten excerpt is diffed against the buffer so that only
    /// minimal insertions are reported.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// A prediction that appends text after the buffer's trailing newline applies
    /// cleanly at end of file.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake backend: issues an LLM token, answers the predict endpoint with
        // `completion_response`, and returns 404 for everything else.
        let http_client = FakeHttpClient::create(move |req| async move {
            match (req.method(), req.uri().path()) {
                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&CreateLlmTokenResponse {
                            token: LlmToken("the-llm-token".to_string()),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                                .unwrap(),
                            output_excerpt: completion_response.to_string(),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                _ => Ok(http_client::Response::builder()
                    .status(404)
                    .body("Not Found".into())
                    .unwrap()),
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Requests one prediction, with the cursor at (1, 0), from a fake server that
    /// answers with `completion_response`. Returns the resulting edits as point
    /// ranges in the original buffer.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        // Fake backend: issues an LLM token, answers the predict endpoint with
        // `completion_response`, and returns 404 for everything else.
        let http_client = FakeHttpClient::create(move |req| {
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Captured before any edits so anchors are resolved against the original text.
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`. Ranges are
    /// anchored after their start and before their end.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back into offsets in `buffer`'s current state.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Initializes test logging once, at test-binary startup (via `ctor`).
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}