1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13use editor::Editor;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16use project::git_store::Repository;
17pub use rate_completion_modal::*;
18
19use anyhow::{Context as _, Result, anyhow};
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
23 PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
24 ZED_VERSION_HEADER_NAME,
25};
26use collections::{HashMap, HashSet, VecDeque};
27use futures::AsyncReadExt;
28use gpui::{
29 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
30 Subscription, Task, WeakEntity, actions,
31};
32use http_client::{AsyncBody, HttpClient, Method, Request, Response};
33use input_excerpt::excerpt_for_cursor_position;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
36};
37use language_model::{LlmApiToken, RefreshLlmTokenListener};
38use project::{Project, ProjectPath};
39use release_channel::AppVersion;
40use settings::WorktreeId;
41use std::str::FromStr;
42use std::{
43 cmp,
44 fmt::Write,
45 future::Future,
46 mem,
47 ops::Range,
48 path::Path,
49 rc::Rc,
50 sync::Arc,
51 time::{Duration, Instant},
52};
53use telemetry_events::EditPredictionRating;
54use thiserror::Error;
55use util::{ResultExt, maybe};
56use uuid::Uuid;
57use workspace::Workspace;
58use workspace::notifications::{ErrorMessagePrompt, NotificationId};
59use worktree::Worktree;
60
/// Marker the model may place in its output at the predicted cursor position; `parse_edits`
/// strips it before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `parse_edits` rejects model output containing it more than once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region. In the model output, the rewritten region starts on the line after
/// this marker.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region. In the model output it must directly follow a newline.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Back-to-back edits of the same buffer that arrive within this interval are merged into one
/// `Event::BufferChange` by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the persisted data-collection choice (`"true"` or `"false"`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): MAX_CONTEXT_TOKENS, MAX_REWRITE_TOKENS and MAX_DIAGNOSTIC_GROUPS are not referenced
// in this part of the file; they presumably bound the excerpt and diagnostics sent to the model —
// confirm at their use sites.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget handed to `prompt_for_events` when rendering the recent-edit history.
const MAX_EVENT_TOKENS: usize = 500;
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track. Once reached, `Zeta::push_event` discards the oldest half.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum interval required between recent project entries for them to be kept.
const MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of edit predictions to store for feedback; `Zeta::completion_shown` evicts the
/// oldest entry beyond this.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
95
96#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
97pub struct EditPredictionId(Uuid);
98
99impl From<EditPredictionId> for gpui::ElementId {
100 fn from(value: EditPredictionId) -> Self {
101 gpui::ElementId::Uuid(value.0)
102 }
103}
104
105impl std::fmt::Display for EditPredictionId {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 write!(f, "{}", self.0)
108 }
109}
110
111struct ZedPredictUpsell;
112
113impl Dismissable for ZedPredictUpsell {
114 const KEY: &'static str = "dismissed-edit-predict-upsell";
115
116 fn dismissed() -> bool {
117 // To make this backwards compatible with older versions of Zed, we
118 // check if the user has seen the previous Edit Prediction Onboarding
119 // before, by checking the data collection choice which was written to
120 // the database once the user clicked on "Accept and Enable"
121 if KEY_VALUE_STORE
122 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
123 .log_err()
124 .is_some_and(|s| s.is_some())
125 {
126 return true;
127 }
128
129 KEY_VALUE_STORE
130 .read_kvp(Self::KEY)
131 .log_err()
132 .is_some_and(|s| s.is_some())
133 }
134}
135
136pub fn should_show_upsell_modal() -> bool {
137 !ZedPredictUpsell::dismissed()
138}
139
/// App-global slot holding the shared [`Zeta`] entity; installed by `Zeta::register` and read
/// back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
144
/// A model prediction re-targeted onto a recent buffer snapshot, together with the prompt inputs
/// and raw model output that are kept for rating and telemetry.
#[derive(Clone)]
pub struct EditPrediction {
    /// Server request id of this prediction.
    id: EditPredictionId,
    /// Full path of the predicted file, or `untitled` for buffers without a file.
    path: Arc<Path>,
    /// Buffer offset range the model was allowed to rewrite (the editable range).
    excerpt_range: Range<usize>,
    /// Cursor offset in the buffer at request time.
    cursor_offset: usize,
    /// Predicted edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto when this prediction was built.
    snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    edit_preview: EditPreview,
    /// Outline part of the request body (empty when none was sent).
    input_outline: Arc<str>,
    /// Rendered recent-edit events sent to the model.
    input_events: Arc<str>,
    /// Buffer excerpt sent to the model.
    input_excerpt: Arc<str>,
    /// Raw model output the edits were parsed from.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// Set when this value is constructed, i.e. after the response was parsed and the edit
    /// preview computed — not the exact moment the response arrived.
    response_received_at: Instant,
}
161
162impl EditPrediction {
163 fn latency(&self) -> Duration {
164 self.response_received_at
165 .duration_since(self.buffer_snapshotted_at)
166 }
167
168 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
169 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
170 }
171}
172
/// Rebases model-predicted `current_edits` (anchored in `old_snapshot`) onto `new_snapshot`,
/// taking into account the edits the user made in between.
///
/// The user's edits (`edits_since` the old version) and the model edits are walked in order:
/// - model edits whose old range ends before the next user edit starts are kept unchanged;
/// - a user edit whose old range equals the next model edit's old range, and whose inserted text
///   is a prefix of the model's replacement, consumes that model edit; the not-yet-typed suffix
///   (if any) is re-emitted as an insertion at the end of the user's edit;
/// - any other user edit conflicts with the prediction and `None` is returned.
///
/// Model edits after the last user edit are kept as-is. Returns `None` if no edits remain.
/// NOTE(review): the walk assumes `current_edits` are sorted by position and non-overlapping;
/// this is relied on, not checked.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is tolerated only if it replaced exactly the range the next model edit
        // targets and typed a prefix of the model's replacement text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Only offer the part of the prediction the user has not typed yet.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit conflicts with the prediction.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
218
219impl std::fmt::Debug for EditPrediction {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 f.debug_struct("EditPrediction")
222 .field("id", &self.id)
223 .field("path", &self.path)
224 .field("edits", &self.edits)
225 .finish_non_exhaustive()
226 }
227}
228
/// Edit-prediction state shared across the app: recent edit history, observed buffers, the
/// shown/rated prediction history, and the LLM token used to call the prediction service.
///
/// A single instance is stored as an app global by `Zeta::register`. Fields drop in declaration
/// order.
pub struct Zeta {
    /// Workspace this instance was created for; invalid if none was provided.
    workspace: WeakEntity<Workspace>,
    client: Arc<Client>,
    /// Recent buffer-change events (bounded by `MAX_EVENT_COUNT`), rendered into each request.
    events: VecDeque<Event>,
    /// Observed buffers keyed by entity id, with the last snapshot reported for each.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions that were displayed, newest first, capped at `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    /// The user's data-collection choice, initially loaded from the key-value store.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token for the Zed LLM service.
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the global `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Receives the edit-prediction usage reported by the server.
    user_store: Entity<UserStore>,
    /// One license watcher per worktree passed to `Zeta::register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently active editors (capacity `MAX_RECENT_PROJECT_ENTRIES_COUNT`); presumably
    /// maintained by `handle_active_workspace_item_changed`, which is not visible here — confirm.
    recent_editors: VecDeque<RecentEditor>,
}
245
/// Entry of `Zeta::recent_editors`.
struct RecentEditor {
    /// The editor, held weakly so the history does not keep it alive.
    editor: WeakEntity<Editor>,
    /// When the editor was last seen active (maintained outside this view).
    last_active_at: Instant,
}
250
251impl Zeta {
252 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
253 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
254 }
255
256 pub fn register(
257 workspace: Option<Entity<Workspace>>,
258 worktree: Option<Entity<Worktree>>,
259 client: Arc<Client>,
260 user_store: Entity<UserStore>,
261 cx: &mut App,
262 ) -> Entity<Self> {
263 let this = Self::global(cx).unwrap_or_else(|| {
264 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
265 cx.set_global(ZetaGlobal(entity.clone()));
266 entity
267 });
268
269 this.update(cx, move |this, cx| {
270 if let Some(worktree) = worktree {
271 let worktree_id = worktree.read(cx).id();
272 this.license_detection_watchers
273 .entry(worktree_id)
274 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
275 }
276 });
277
278 this
279 }
280
281 pub fn clear_history(&mut self) {
282 self.events.clear();
283 }
284
285 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
286 self.user_store.read(cx).edit_prediction_usage()
287 }
288
289 fn new(
290 workspace: Option<Entity<Workspace>>,
291 client: Arc<Client>,
292 user_store: Entity<UserStore>,
293 cx: &mut Context<Self>,
294 ) -> Self {
295 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
296
297 let data_collection_choice = Self::load_data_collection_choices();
298 let data_collection_choice = cx.new(|_| data_collection_choice);
299
300 if let Some(workspace) = &workspace {
301 cx.subscribe(workspace, |this, _workspace, event, cx| match event {
302 workspace::Event::ActiveItemChanged => {
303 this.handle_active_workspace_item_changed(cx)
304 }
305 _ => {}
306 })
307 .detach();
308 }
309
310 Self {
311 workspace: workspace.map_or_else(
312 || WeakEntity::new_invalid(),
313 |workspace| workspace.downgrade(),
314 ),
315 client,
316 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
317 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
318 rated_completions: HashSet::default(),
319 registered_buffers: HashMap::default(),
320 data_collection_choice,
321 llm_token: LlmApiToken::default(),
322 _llm_token_subscription: cx.subscribe(
323 &refresh_llm_token_listener,
324 |this, _listener, _event, cx| {
325 let client = this.client.clone();
326 let llm_token = this.llm_token.clone();
327 cx.spawn(async move |_this, _cx| {
328 llm_token.refresh(&client).await?;
329 anyhow::Ok(())
330 })
331 .detach_and_log_err(cx);
332 },
333 ),
334 update_required: false,
335 license_detection_watchers: HashMap::default(),
336 user_store,
337 recent_editors: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
338 }
339 }
340
341 fn push_event(&mut self, event: Event) {
342 if let Some(Event::BufferChange {
343 new_snapshot: last_new_snapshot,
344 timestamp: last_timestamp,
345 ..
346 }) = self.events.back_mut()
347 {
348 // Coalesce edits for the same buffer when they happen one after the other.
349 let Event::BufferChange {
350 old_snapshot,
351 new_snapshot,
352 timestamp,
353 } = &event;
354
355 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
356 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
357 && old_snapshot.version == last_new_snapshot.version
358 {
359 *last_new_snapshot = new_snapshot.clone();
360 *last_timestamp = *timestamp;
361 return;
362 }
363 }
364
365 if self.events.len() >= MAX_EVENT_COUNT {
366 // These are halved instead of popping to improve prompt caching.
367 self.events.drain(..MAX_EVENT_COUNT / 2);
368 }
369
370 self.events.push_back(event);
371 }
372
373 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
374 let buffer_id = buffer.entity_id();
375 let weak_buffer = buffer.downgrade();
376
377 if let std::collections::hash_map::Entry::Vacant(entry) =
378 self.registered_buffers.entry(buffer_id)
379 {
380 let snapshot = buffer.read(cx).snapshot();
381
382 entry.insert(RegisteredBuffer {
383 snapshot,
384 _subscriptions: [
385 cx.subscribe(buffer, move |this, buffer, event, cx| {
386 this.handle_buffer_event(buffer, event, cx);
387 }),
388 cx.observe_release(buffer, move |this, _buffer, _cx| {
389 this.registered_buffers.remove(&weak_buffer.entity_id());
390 }),
391 ],
392 });
393 };
394 }
395
396 fn handle_buffer_event(
397 &mut self,
398 buffer: Entity<Buffer>,
399 event: &language::BufferEvent,
400 cx: &mut Context<Self>,
401 ) {
402 if let language::BufferEvent::Edited = event {
403 self.report_changes_for_buffer(&buffer, cx);
404 }
405 }
406
 /// Shared implementation behind `request_completion` and the test-only `fake_completion`.
 ///
 /// The following runs synchronously on the calling context:
 /// - record any unreported changes to `buffer` via `report_changes_for_buffer`, which also
 ///   yields the snapshot the request is built from;
 /// - resolve the cursor;
 /// - gather git info, only when `can_collect_data` is `CanCollectData(true)`;
 /// - start `gather_context`.
 ///
 /// The returned task then awaits the request body and calls `perform_predict_edits` with it.
 /// It forwards any returned usage to the user store. Finally it turns the response into an
 /// [`EditPrediction`] via `process_completion_response`, resolving to `Ok(None)` when nothing
 /// applicable remains.
 ///
 /// If the request fails with a `ZedUpdateRequiredError`, `update_required` is set. When
 /// `workspace` is `Some`, an "Update Zed" notification is also shown before the error is
 /// returned. `perform_predict_edits` is injectable so tests can answer without the network.
 fn request_completion_impl<F, R>(
     &mut self,
     workspace: Option<Entity<Workspace>>,
     project: Option<&Entity<Project>>,
     buffer: &Entity<Buffer>,
     cursor: language::Anchor,
     can_collect_data: CanCollectData,
     cx: &mut Context<Self>,
     perform_predict_edits: F,
 ) -> Task<Result<Option<EditPrediction>>>
 where
     F: FnOnce(PerformPredictEditsParams) -> R + 'static,
     R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
         + Send
         + 'static,
 {
     let buffer = buffer.clone();
     let buffer_snapshotted_at = Instant::now();
     let snapshot = self.report_changes_for_buffer(&buffer, cx);
     // Strong handle, used in the error path below to set `update_required`.
     let zeta = cx.entity();
     // The event history is captured here; edits made after this point are not in the prompt.
     let events = self.events.clone();
     let client = self.client.clone();
     let llm_token = self.llm_token.clone();
     let app_version = AppVersion::global(cx);

     let cursor_point = cursor.to_point(&snapshot);
     let cursor_offset = cursor_point.to_offset(&snapshot);
     // Git metadata is only gathered when data collection is allowed for this request.
     let git_info = if matches!(can_collect_data, CanCollectData(true)) {
         self.gather_git_info(
             cursor_point.clone(),
             &buffer_snapshotted_at,
             &snapshot,
             project.clone(),
             cx,
         )
     } else {
         None
     };

     let full_path: Arc<Path> = snapshot
         .file()
         .map(|f| Arc::from(f.full_path(cx).as_path()))
         .unwrap_or_else(|| Arc::from(Path::new("untitled")));
     let full_path_str = full_path.to_string_lossy().to_string();
     // Rendering of the event history is deferred to `gather_context`.
     let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
     let gather_task = gather_context(
         project,
         full_path_str,
         &snapshot,
         cursor_point,
         make_events_prompt,
         can_collect_data,
         git_info,
         cx,
     );

     cx.spawn(async move |this, cx| {
         let GatherContextOutput {
             body,
             editable_range,
         } = gather_task.await?;
         let done_gathering_context_at = Instant::now();

         log::debug!(
             "Events:\n{}\nExcerpt:\n{:?}",
             body.input_events,
             body.input_excerpt
         );

         // Keep copies of the prompt inputs; `body` is moved into the request below.
         let input_outline = body.outline.clone().unwrap_or_default();
         let input_events = body.input_events.clone();
         let input_excerpt = body.input_excerpt.clone();

         let response = perform_predict_edits(PerformPredictEditsParams {
             client,
             llm_token,
             app_version,
             body,
         })
         .await;
         let (response, usage) = match response {
             Ok(response) => response,
             Err(err) => {
                 // A too-old client is surfaced to the user before the error is propagated.
                 if err.is::<ZedUpdateRequiredError>() {
                     cx.update(|cx| {
                         zeta.update(cx, |zeta, _cx| {
                             zeta.update_required = true;
                         });

                         if let Some(workspace) = workspace {
                             workspace.update(cx, |workspace, cx| {
                                 workspace.show_notification(
                                     NotificationId::unique::<ZedUpdateRequiredError>(),
                                     cx,
                                     |cx| {
                                         cx.new(|cx| {
                                             ErrorMessagePrompt::new(err.to_string(), cx)
                                                 .with_link_button(
                                                     "Update Zed",
                                                     "https://zed.dev/releases",
                                                 )
                                         })
                                     },
                                 );
                             });
                         }
                     })
                     .ok();
                 }

                 return Err(err);
             }
         };

         let received_response_at = Instant::now();
         log::debug!("completion response: {}", &response.output_excerpt);

         if let Some(usage) = usage {
             this.update(cx, |this, cx| {
                 this.user_store.update(cx, |user_store, cx| {
                     user_store.update_edit_prediction_usage(usage, cx);
                 });
             })
             .ok();
         }

         let edit_prediction = Self::process_completion_response(
             response,
             buffer,
             &snapshot,
             editable_range,
             cursor_offset,
             full_path,
             input_outline,
             input_events,
             input_excerpt,
             buffer_snapshotted_at,
             cx,
         )
         .await;

         let finished_at = Instant::now();

         // record latency for ~1% of requests (`<= 2` accepts 3 of the 256 possible values)
         if rand::random::<u8>() <= 2 {
             telemetry::event!(
                 "Edit Prediction Request",
                 context_latency = done_gathering_context_at
                     .duration_since(buffer_snapshotted_at)
                     .as_millis(),
                 request_latency = received_response_at
                     .duration_since(done_gathering_context_at)
                     .as_millis(),
                 process_latency = finished_at.duration_since(received_response_at).as_millis()
             );
         }

         edit_prediction
     })
 }
567
 // Generates several example completions of various states to fill the Zeta completion modal
 //
 // Seven canned responses are run through `fake_completion` against a small local buffer. Once
 // all of them have finished, the `edits` of the entries at indices 2 and 3 of
 // `shown_completions` are cleared, presumably so the modal also shows predictions without
 // edits.
 //
 // NOTE(review): the awaited `Option<EditPrediction>` results are discarded, and nothing visible in
 // `request_completion_impl` pushes into `shown_completions`. Unless `completion_shown` is invoked
 // elsewhere, `get_mut(2).unwrap()` will panic inside the spawned task — verify.
 #[cfg(any(test, feature = "test-support"))]
 pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
     use language::Point;

     let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
         And maybe a short line

         Then a few lines

         and then another
         "#};

     let project = None;
     let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
     let position = buffer.read(cx).anchor_before(Point::new(1, 0));

     let completion_tasks = vec![
         self.fake_completion(
             project,
             &buffer,
             position,
             PredictEditsResponse {
                 request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 ", ),
             },
             cx,
         ),
         self.fake_completion(
             project,
             &buffer,
             position,
             PredictEditsResponse {
                 request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
             },
             cx,
         ),
         self.fake_completion(
             project,
             &buffer,
             position,
             PredictEditsResponse {
                 request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
             },
             cx,
         ),
         self.fake_completion(
             project,
             &buffer,
             position,
             PredictEditsResponse {
                 request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
             },
             cx,
         ),
         self.fake_completion(
             project,
             &buffer,
             position,
             PredictEditsResponse {
                 request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
             },
             cx,
         ),
         self.fake_completion(
             project,
             &buffer,
             position,
             PredictEditsResponse {
                 request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
 "#),
             },
             cx,
         ),
         self.fake_completion(
             project,
             &buffer,
             position,
             PredictEditsResponse {
                 request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
 "#),
             },
             cx,
         ),
     ];

     cx.spawn(async move |zeta, cx| {
         // Results are awaited only for completion; the returned predictions are not kept here.
         for task in completion_tasks {
             task.await.unwrap();
         }

         zeta.update(cx, |zeta, _cx| {
             zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
             zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
         })
         .ok();
     })
 }
720
721 #[cfg(any(test, feature = "test-support"))]
722 pub fn fake_completion(
723 &mut self,
724 project: Option<&Entity<Project>>,
725 buffer: &Entity<Buffer>,
726 position: language::Anchor,
727 response: PredictEditsResponse,
728 cx: &mut Context<Self>,
729 ) -> Task<Result<Option<EditPrediction>>> {
730 use std::future::ready;
731
732 self.request_completion_impl(
733 None,
734 project,
735 buffer,
736 position,
737 CanCollectData(false),
738 cx,
739 |_params| ready(Ok((response, None))),
740 )
741 }
742
743 pub fn request_completion(
744 &mut self,
745 project: Option<&Entity<Project>>,
746 buffer: &Entity<Buffer>,
747 position: language::Anchor,
748 can_collect_data: CanCollectData,
749 cx: &mut Context<Self>,
750 ) -> Task<Result<Option<EditPrediction>>> {
751 self.request_completion_impl(
752 self.workspace.upgrade(),
753 project,
754 buffer,
755 position,
756 can_collect_data,
757 cx,
758 Self::perform_predict_edits,
759 )
760 }
761
 /// Sends `params.body` to the edit-prediction endpoint and parses the JSON response.
 ///
 /// The URL comes from the `ZED_PREDICT_EDITS_URL` environment variable when it is set.
 /// Otherwise the Zed LLM service's `/predict_edits/v2` endpoint is used. If a non-success
 /// response carries `EXPIRED_LLM_TOKEN_HEADER_NAME`, the token is refreshed and the request is
 /// retried once.
 ///
 /// On success, also returns the usage parsed from the response headers, when present.
 ///
 /// # Errors
 /// - `ZedUpdateRequiredError` if the response advertises a minimum required version newer than
 ///   `app_version`. This is checked before the status code.
 /// - An error carrying the status and body for any other non-success response.
 /// - Token acquisition/refresh, URL building, serialization, transport and JSON parsing errors.
 pub fn perform_predict_edits(
     params: PerformPredictEditsParams,
 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
     async move {
         let PerformPredictEditsParams {
             client,
             llm_token,
             app_version,
             body,
             ..
         } = params;

         let http_client = client.http_client();
         let mut token = llm_token.acquire(&client).await?;
         let mut did_retry = false;

         // Loops at most twice: the second iteration only happens after a token refresh.
         loop {
             let request_builder = http_client::Request::builder().method(Method::POST);
             let request_builder =
                 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                     request_builder.uri(predict_edits_url)
                 } else {
                     request_builder.uri(
                         http_client
                             .build_zed_llm_url("/predict_edits/v2", &[])?
                             .as_ref(),
                     )
                 };
             let request = request_builder
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {}", token))
                 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                 .body(serde_json::to_string(&body)?.into())?;

             let mut response = http_client.send(request).await?;

             // An unparsable version header is ignored rather than treated as an error.
             if let Some(minimum_required_version) = response
                 .headers()
                 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
             {
                 anyhow::ensure!(
                     app_version >= minimum_required_version,
                     ZedUpdateRequiredError {
                         minimum_version: minimum_required_version
                     }
                 );
             }

             if response.status().is_success() {
                 let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
                 return Ok((serde_json::from_str(&body)?, usage));
             } else if !did_retry
                 && response
                     .headers()
                     .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                     .is_some()
             {
                 did_retry = true;
                 token = llm_token.refresh(&client).await?;
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
                 anyhow::bail!(
                     "error predicting edits.\nStatus: {:?}\nBody: {}",
                     response.status(),
                     body
                 );
             }
         }
     }
 }
837
 /// Notifies the server that the prediction `request_id` was accepted.
 ///
 /// Posts to `ZED_ACCEPT_PREDICTION_URL` when that environment variable is set. Otherwise it
 /// posts to the Zed LLM service's `/predict_edits/accept` endpoint, with the request built
 /// inside `llm_token_retry`. (That helper is defined elsewhere; presumably it refreshes the
 /// token and retries on expiry — confirm.)
 ///
 /// # Errors
 /// - `ZedUpdateRequiredError` if the server requires a newer Zed.
 /// - An error with the status and body on any other non-success response.
 ///
 /// On success, any usage headers are forwarded to the user store.
 fn accept_edit_prediction(
     &mut self,
     request_id: EditPredictionId,
     cx: &mut Context<Self>,
 ) -> Task<Result<()>> {
     let client = self.client.clone();
     let llm_token = self.llm_token.clone();
     let app_version = AppVersion::global(cx);
     cx.spawn(async move |this, cx| {
         let http_client = client.http_client();
         let mut response = llm_token_retry(&llm_token, &client, |token| {
             let request_builder = http_client::Request::builder().method(Method::POST);
             let request_builder =
                 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                     request_builder.uri(accept_prediction_url)
                 } else {
                     request_builder.uri(
                         http_client
                             .build_zed_llm_url("/predict_edits/accept", &[])?
                             .as_ref(),
                     )
                 };
             Ok(request_builder
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {}", token))
                 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                 .body(
                     serde_json::to_string(&AcceptEditPredictionBody {
                         request_id: request_id.0,
                     })?
                     .into(),
                 )?)
         })
         .await?;

         // Same version gate as `perform_predict_edits`; unparsable headers are ignored.
         if let Some(minimum_required_version) = response
             .headers()
             .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
             .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
             && app_version < minimum_required_version
         {
             return Err(anyhow!(ZedUpdateRequiredError {
                 minimum_version: minimum_required_version
             }));
         }

         if response.status().is_success() {
             if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
                 this.update(cx, |this, cx| {
                     this.user_store.update(cx, |user_store, cx| {
                         user_store.update_edit_prediction_usage(usage, cx);
                     });
                 })?;
             }

             Ok(())
         } else {
             let mut body = String::new();
             response.body_mut().read_to_string(&mut body).await?;
             Err(anyhow!(
                 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
                 response.status(),
                 body
             ))
         }
     })
 }
905
 /// Turns a raw model response into an [`EditPrediction`].
 ///
 /// The process runs in three steps:
 /// 1. Parse the edits out of the response's `output_excerpt` on a background task. Parsing is
 ///    relative to `snapshot`, the buffer state the request was built from.
 /// 2. Re-target the edits onto the buffer's current snapshot with `interpolate`.
 /// 3. Compute an edit preview for the re-targeted edits.
 ///
 /// Resolves to `Ok(None)` when the user's intervening edits invalidated the prediction or no
 /// edits remain. Parse errors are propagated.
 fn process_completion_response(
     prediction_response: PredictEditsResponse,
     buffer: Entity<Buffer>,
     snapshot: &BufferSnapshot,
     editable_range: Range<usize>,
     cursor_offset: usize,
     path: Arc<Path>,
     input_outline: String,
     input_events: String,
     input_excerpt: String,
     buffer_snapshotted_at: Instant,
     cx: &AsyncApp,
 ) -> Task<Result<Option<EditPrediction>>> {
     let snapshot = snapshot.clone();
     let request_id = prediction_response.request_id;
     let output_excerpt = prediction_response.output_excerpt;
     cx.spawn(async move |cx| {
         let output_excerpt: Arc<str> = output_excerpt.into();

         // Parsing and diffing happen off the foreground thread.
         let edits: Arc<[(Range<Anchor>, String)]> = cx
             .background_spawn({
                 let output_excerpt = output_excerpt.clone();
                 let editable_range = editable_range.clone();
                 let snapshot = snapshot.clone();
                 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
             })
             .await?
             .into();

         // The buffer may have changed since the request; rebase the edits onto its current
         // state, and bail out with `None` if they no longer apply.
         let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
             let edits = edits.clone();
             |buffer, cx| {
                 let new_snapshot = buffer.snapshot();
                 let edits: Arc<[(Range<Anchor>, String)]> =
                     interpolate(&snapshot, &new_snapshot, edits)?.into();
                 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
             }
         })?
         else {
             return anyhow::Ok(None);
         };

         let edit_preview = edit_preview.await;

         Ok(Some(EditPrediction {
             id: EditPredictionId(request_id),
             path,
             excerpt_range: editable_range,
             cursor_offset,
             edits,
             edit_preview,
             snapshot,
             input_outline: input_outline.into(),
             input_events: input_events.into(),
             input_excerpt: input_excerpt.into(),
             output_excerpt,
             buffer_snapshotted_at,
             response_received_at: Instant::now(),
         }))
     })
 }
967
968 fn parse_edits(
969 output_excerpt: Arc<str>,
970 editable_range: Range<usize>,
971 snapshot: &BufferSnapshot,
972 ) -> Result<Vec<(Range<Anchor>, String)>> {
973 let content = output_excerpt.replace(CURSOR_MARKER, "");
974
975 let start_markers = content
976 .match_indices(EDITABLE_REGION_START_MARKER)
977 .collect::<Vec<_>>();
978 anyhow::ensure!(
979 start_markers.len() == 1,
980 "expected exactly one start marker, found {}",
981 start_markers.len()
982 );
983
984 let end_markers = content
985 .match_indices(EDITABLE_REGION_END_MARKER)
986 .collect::<Vec<_>>();
987 anyhow::ensure!(
988 end_markers.len() == 1,
989 "expected exactly one end marker, found {}",
990 end_markers.len()
991 );
992
993 let sof_markers = content
994 .match_indices(START_OF_FILE_MARKER)
995 .collect::<Vec<_>>();
996 anyhow::ensure!(
997 sof_markers.len() <= 1,
998 "expected at most one start-of-file marker, found {}",
999 sof_markers.len()
1000 );
1001
1002 let codefence_start = start_markers[0].0;
1003 let content = &content[codefence_start..];
1004
1005 let newline_ix = content.find('\n').context("could not find newline")?;
1006 let content = &content[newline_ix + 1..];
1007
1008 let codefence_end = content
1009 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1010 .context("could not find end marker")?;
1011 let new_text = &content[..codefence_end];
1012
1013 let old_text = snapshot
1014 .text_for_range(editable_range.clone())
1015 .collect::<String>();
1016
1017 Ok(Self::compute_edits(
1018 old_text,
1019 new_text,
1020 editable_range.start,
1021 snapshot,
1022 ))
1023 }
1024
1025 pub fn compute_edits(
1026 old_text: String,
1027 new_text: &str,
1028 offset: usize,
1029 snapshot: &BufferSnapshot,
1030 ) -> Vec<(Range<Anchor>, String)> {
1031 text_diff(&old_text, new_text)
1032 .into_iter()
1033 .map(|(mut old_range, new_text)| {
1034 old_range.start += offset;
1035 old_range.end += offset;
1036
1037 let prefix_len = common_prefix(
1038 snapshot.chars_for_range(old_range.clone()),
1039 new_text.chars(),
1040 );
1041 old_range.start += prefix_len;
1042
1043 let suffix_len = common_prefix(
1044 snapshot.reversed_chars_for_range(old_range.clone()),
1045 new_text[prefix_len..].chars().rev(),
1046 );
1047 old_range.end = old_range.end.saturating_sub(suffix_len);
1048
1049 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1050 let range = if old_range.is_empty() {
1051 let anchor = snapshot.anchor_after(old_range.start);
1052 anchor..anchor
1053 } else {
1054 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1055 };
1056 (range, new_text)
1057 })
1058 .collect()
1059 }
1060
1061 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1062 self.rated_completions.contains(&completion_id)
1063 }
1064
1065 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1066 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1067 let completion = self.shown_completions.pop_back().unwrap();
1068 self.rated_completions.remove(&completion.id);
1069 }
1070 self.shown_completions.push_front(completion.clone());
1071 cx.notify();
1072 }
1073
1074 pub fn rate_completion(
1075 &mut self,
1076 completion: &EditPrediction,
1077 rating: EditPredictionRating,
1078 feedback: String,
1079 cx: &mut Context<Self>,
1080 ) {
1081 self.rated_completions.insert(completion.id);
1082 telemetry::event!(
1083 "Edit Prediction Rated",
1084 rating,
1085 input_events = completion.input_events,
1086 input_excerpt = completion.input_excerpt,
1087 input_outline = completion.input_outline,
1088 output_excerpt = completion.output_excerpt,
1089 feedback
1090 );
1091 self.client.telemetry().flush_events().detach();
1092 cx.notify();
1093 }
1094
1095 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1096 self.shown_completions.iter()
1097 }
1098
1099 pub fn shown_completions_len(&self) -> usize {
1100 self.shown_completions.len()
1101 }
1102
1103 fn report_changes_for_buffer(
1104 &mut self,
1105 buffer: &Entity<Buffer>,
1106 cx: &mut Context<Self>,
1107 ) -> BufferSnapshot {
1108 self.register_buffer(buffer, cx);
1109
1110 let registered_buffer = self
1111 .registered_buffers
1112 .get_mut(&buffer.entity_id())
1113 .unwrap();
1114 let new_snapshot = buffer.read(cx).snapshot();
1115
1116 if new_snapshot.version != registered_buffer.snapshot.version {
1117 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1118 self.push_event(Event::BufferChange {
1119 old_snapshot,
1120 new_snapshot: new_snapshot.clone(),
1121 timestamp: Instant::now(),
1122 });
1123 }
1124
1125 new_snapshot
1126 }
1127
1128 fn load_data_collection_choices() -> DataCollectionChoice {
1129 let choice = KEY_VALUE_STORE
1130 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1131 .log_err()
1132 .flatten();
1133
1134 match choice.as_deref() {
1135 Some("true") => DataCollectionChoice::Enabled,
1136 Some("false") => DataCollectionChoice::Disabled,
1137 Some(_) => {
1138 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1139 DataCollectionChoice::NotAnswered
1140 }
1141 None => DataCollectionChoice::NotAnswered,
1142 }
1143 }
1144
1145 fn gather_git_info(
1146 &mut self,
1147 cursor_point: language::Point,
1148 buffer_snapshotted_at: &Instant,
1149 snapshot: &BufferSnapshot,
1150 project: Option<&Entity<Project>>,
1151 cx: &mut Context<Self>,
1152 ) -> Option<PredictEditsGitInfo> {
1153 let project = project?.read(cx);
1154 let file = snapshot.file()?;
1155 let project_path = ProjectPath::from_file(file.as_ref(), cx);
1156 let entry = project.entry_for_path(&project_path, cx)?;
1157 if !worktree_entry_eligible_for_collection(&entry) {
1158 return None;
1159 }
1160
1161 let git_store = project.git_store().read(cx);
1162 let (repository, repo_path) =
1163 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1164 let repo_path_str = repo_path.to_str()?;
1165
1166 repository.update(cx, |repository, cx| {
1167 let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1168 let remote_origin_url = repository.remote_origin_url.clone();
1169 let remote_upstream_url = repository.remote_upstream_url.clone();
1170 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1171
1172 Some(PredictEditsGitInfo {
1173 input_path: Some(repo_path_str.to_string()),
1174 cursor_point: Some(to_cloud_llm_client_point(cursor_point)),
1175 head_sha: Some(head_sha),
1176 remote_origin_url,
1177 remote_upstream_url,
1178 recent_files: Some(recent_files),
1179 })
1180 })
1181 }
1182
1183 fn handle_active_workspace_item_changed(&mut self, cx: &Context<Self>) {
1184 if let Some(active_editor) = self
1185 .workspace
1186 .read_with(cx, |workspace, cx| {
1187 workspace
1188 .active_item(cx)
1189 .and_then(|item| item.act_as::<Editor>(cx))
1190 })
1191 .ok()
1192 .flatten()
1193 {
1194 let now = Instant::now();
1195 let new_recent = RecentEditor {
1196 editor: active_editor.downgrade(),
1197 last_active_at: now,
1198 };
1199 if let Some(existing_ix) = self
1200 .recent_editors
1201 .iter()
1202 .rposition(|recent| &recent.editor == &new_recent.editor)
1203 {
1204 self.recent_editors.remove(existing_ix);
1205 }
1206 // filter out rapid changes in active item, particularly since this can happen rapidly when
1207 // a workspace is loaded.
1208 if let Some(previous_recent) = self.recent_editors.back_mut()
1209 && now.duration_since(previous_recent.last_active_at)
1210 < MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES
1211 {
1212 *previous_recent = new_recent;
1213 return;
1214 }
1215 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1216 self.recent_editors.pop_front();
1217 }
1218 self.recent_editors.push_back(new_recent);
1219 }
1220 }
1221
    /// Collects the recently active editors whose files belong to `repository`.
    ///
    /// Walks `recent_editors` from the most to the least recently active entry. It returns one
    /// `PredictEditsRecentFile` per usable editor, carrying the file's repo-relative path, the
    /// editor's cursor position, and how many milliseconds before `now` it was last active.
    ///
    /// Entries that can never be reported are pruned from `recent_editors` along the way:
    /// - the editor can no longer be updated (e.g. it was dropped);
    /// - there is no cursor buffer or file;
    /// - the worktree entry is not eligible for collection;
    /// - the repo path is not UTF-8 or exceeds `MAX_RECENT_FILE_PATH_LENGTH`;
    /// - the age doesn't fit in `active_to_now_ms`.
    ///
    /// Entries that merely belong to a different repository are kept. If the workspace can't
    /// be read, nothing is pruned and an empty list is returned.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &mut App,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::with_capacity(self.recent_editors.len());
        // Iterate back to front: newest entries come first, and removing index `ix` never
        // shifts entries that are still to be visited.
        for ix in (0..self.recent_editors.len()).rev() {
            let recent_editor = &self.recent_editors[ix];
            // `None` means the entry should be dropped from `recent_editors`.
            let keep_entry = recent_editor
                .editor
                .update(cx, |editor, cx| {
                    maybe!({
                        let (buffer, cursor_point, _) = editor.cursor_buffer_point(cx)?;
                        let file = buffer.read(cx).file()?;
                        let project_path = ProjectPath {
                            worktree_id: file.worktree_id(cx),
                            path: file.path().clone(),
                        };
                        let entry = project.read(cx).entry_for_path(&project_path, cx)?;
                        if !worktree_entry_eligible_for_collection(entry) {
                            return None;
                        }
                        let Some(repo_path) =
                            repository.project_path_to_repo_path(&project_path, cx)
                        else {
                            // entry not removed since later queries may involve other repositories
                            return Some(());
                        };
                        // paths may not be valid UTF-8
                        let repo_path_str = repo_path.to_str()?;
                        if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                            return None;
                        }
                        let active_to_now_ms = now
                            .duration_since(recent_editor.last_active_at)
                            .as_millis()
                            .try_into()
                            .ok()?;
                        results.push(PredictEditsRecentFile {
                            path: repo_path_str.to_string(),
                            cursor_point: to_cloud_llm_client_point(cursor_point),
                            active_to_now_ms,
                        });
                        Some(())
                    })
                })
                .ok()
                .flatten();
            if keep_entry.is_none() {
                self.recent_editors.remove(ix);
            }
        }
        results
    }
1283}
1284
1285fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1286 cloud_llm_client::Point {
1287 row: point.row,
1288 column: point.column,
1289 }
1290}
1291
1292fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1293 entry.is_file()
1294 && entry.is_created()
1295 && !entry.is_ignored
1296 && !entry.is_private
1297 && !entry.is_external
1298 && !entry.is_fifo
1299}
1300
/// Everything needed to send a predict-edits request.
pub struct PerformPredictEditsParams {
    /// Client used for the HTTP request.
    pub client: Arc<Client>,
    /// Token used to authenticate with the LLM service.
    pub llm_token: LlmApiToken,
    /// Version of the running Zed client.
    pub app_version: SemanticVersion,
    /// The request payload.
    pub body: PredictEditsBody,
}
1307
/// Error signalling that edit predictions require at least `minimum_version` of Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The oldest Zed version that may keep using edit predictions.
    minimum_version: SemanticVersion,
}
1315
/// Returns the UTF-8 byte length of the longest run of equal leading chars of `a` and `b`.
///
/// Works for reversed iterators as well, in which case it measures a common suffix.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1322
/// Result of `gather_context`: the request body plus the excerpt's editable range.
pub struct GatherContextOutput {
    /// The request payload to send.
    pub body: PredictEditsBody,
    /// Buffer-offset range of the editable region of the excerpt included in `body`.
    pub editable_range: Range<usize>,
}
1327
1328pub fn gather_context(
1329 project: Option<&Entity<Project>>,
1330 full_path_str: String,
1331 snapshot: &BufferSnapshot,
1332 cursor_point: language::Point,
1333 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1334 can_collect_data: CanCollectData,
1335 git_info: Option<PredictEditsGitInfo>,
1336 cx: &App,
1337) -> Task<Result<GatherContextOutput>> {
1338 let local_lsp_store =
1339 project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1340 let diagnostic_groups: Vec<(String, serde_json::Value)> =
1341 if matches!(can_collect_data, CanCollectData(true))
1342 && let Some(local_lsp_store) = local_lsp_store
1343 {
1344 snapshot
1345 .diagnostic_groups(None)
1346 .into_iter()
1347 .filter_map(|(language_server_id, diagnostic_group)| {
1348 let language_server =
1349 local_lsp_store.running_language_server_for_id(language_server_id)?;
1350 let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1351 let language_server_name = language_server.name().to_string();
1352 let serialized = serde_json::to_value(diagnostic_group).unwrap();
1353 Some((language_server_name, serialized))
1354 })
1355 .collect::<Vec<_>>()
1356 } else {
1357 Vec::new()
1358 };
1359
1360 cx.background_spawn({
1361 let snapshot = snapshot.clone();
1362 async move {
1363 let diagnostic_groups = if diagnostic_groups.is_empty()
1364 || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1365 {
1366 None
1367 } else {
1368 Some(diagnostic_groups)
1369 };
1370
1371 let input_excerpt = excerpt_for_cursor_position(
1372 cursor_point,
1373 &full_path_str,
1374 &snapshot,
1375 MAX_REWRITE_TOKENS,
1376 MAX_CONTEXT_TOKENS,
1377 );
1378 let input_events = make_events_prompt();
1379 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1380
1381 let body = PredictEditsBody {
1382 input_events,
1383 input_excerpt: input_excerpt.prompt,
1384 can_collect_data: can_collect_data.0,
1385 diagnostic_groups,
1386 git_info,
1387 outline: None,
1388 speculated_output: None,
1389 };
1390
1391 Ok(GatherContextOutput {
1392 body,
1393 editable_range,
1394 })
1395 }
1396 })
1397}
1398
1399fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1400 let mut result = String::new();
1401 for event in events.iter().rev() {
1402 let event_string = event.to_prompt();
1403 let event_tokens = tokens_for_bytes(event_string.len());
1404 if event_tokens > remaining_tokens {
1405 break;
1406 }
1407
1408 if !result.is_empty() {
1409 result.insert_str(0, "\n\n");
1410 }
1411 result.insert_str(0, &event_string);
1412 remaining_tokens -= event_tokens;
1413 }
1414 result
1415}
1416
/// Bookkeeping for a buffer that `Zeta` tracks for edit history.
struct RegisteredBuffer {
    /// The last snapshot recorded for this buffer; new snapshots are compared against it to
    /// detect changes.
    snapshot: BufferSnapshot,
    /// Never read; presumably held so the subscriptions stay alive as long as the buffer is
    /// registered.
    _subscriptions: [gpui::Subscription; 2],
}
1421
/// A user action recorded for inclusion in the prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents and/or path changed between two snapshots.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1430
1431impl Event {
1432 fn to_prompt(&self) -> String {
1433 match self {
1434 Event::BufferChange {
1435 old_snapshot,
1436 new_snapshot,
1437 ..
1438 } => {
1439 let mut prompt = String::new();
1440
1441 let old_path = old_snapshot
1442 .file()
1443 .map(|f| f.path().as_ref())
1444 .unwrap_or(Path::new("untitled"));
1445 let new_path = new_snapshot
1446 .file()
1447 .map(|f| f.path().as_ref())
1448 .unwrap_or(Path::new("untitled"));
1449 if old_path != new_path {
1450 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1451 }
1452
1453 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1454 if !diff.is_empty() {
1455 write!(
1456 prompt,
1457 "User edited {:?}:\n```diff\n{}\n```",
1458 new_path, diff
1459 )
1460 .unwrap();
1461 }
1462
1463 prompt
1464 }
1465 }
1466 }
1467}
1468
/// The prediction currently offered by the provider, tagged with the buffer it is for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1474
1475impl CurrentEditPrediction {
1476 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1477 if self.buffer_id != old_completion.buffer_id {
1478 return true;
1479 }
1480
1481 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1482 return true;
1483 };
1484 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1485 return false;
1486 };
1487
1488 if old_edits.len() == 1 && new_edits.len() == 1 {
1489 let (old_range, old_text) = &old_edits[0];
1490 let (new_range, new_text) = &new_edits[0];
1491 new_range == old_range && new_text.starts_with(old_text)
1492 } else {
1493 true
1494 }
1495 }
1496}
1497
/// An in-flight prediction request owned by the provider.
struct PendingCompletion {
    /// Identifies the request so its task can locate its own entry when it finishes.
    id: usize,
    /// The request task; held while queued (presumably dropping it cancels the request).
    _task: Task<()>,
}
1502
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has made either choice.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// Returns the choice a toggle switches to; an unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Disabled | Self::NotAnswered => Self::Enabled,
            Self::Enabled => Self::Disabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` maps to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
1542
/// Data-collection state for the buffer a prediction provider serves.
pub struct ProviderDataCollection {
    /// When set to `None`, data collection is not possible for the provider's buffer.
    choice: Option<Entity<DataCollectionChoice>>,
    /// Reports whether the buffer's worktree is an open-source project. It is `None` exactly
    /// when `choice` is `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1548
/// Final verdict on whether data from a prediction request may be collected, as computed by
/// `ProviderDataCollection::can_collect_data`.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1551
1552impl ProviderDataCollection {
1553 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1554 let choice_and_watcher = buffer.and_then(|buffer| {
1555 let file = buffer.read(cx).file()?;
1556
1557 if !file.is_local() || file.is_private() {
1558 return None;
1559 }
1560
1561 let zeta = zeta.read(cx);
1562 let choice = zeta.data_collection_choice.clone();
1563
1564 let license_detection_watcher = zeta
1565 .license_detection_watchers
1566 .get(&file.worktree_id(cx))
1567 .cloned()?;
1568
1569 Some((choice, license_detection_watcher))
1570 });
1571
1572 if let Some((choice, watcher)) = choice_and_watcher {
1573 ProviderDataCollection {
1574 choice: Some(choice),
1575 license_detection_watcher: Some(watcher),
1576 }
1577 } else {
1578 ProviderDataCollection {
1579 choice: None,
1580 license_detection_watcher: None,
1581 }
1582 }
1583 }
1584
1585 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1586 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1587 }
1588
1589 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1590 self.choice
1591 .as_ref()
1592 .is_some_and(|choice| choice.read(cx).is_enabled())
1593 }
1594
1595 fn is_project_open_source(&self) -> bool {
1596 self.license_detection_watcher
1597 .as_ref()
1598 .is_some_and(|watcher| watcher.is_project_open_source())
1599 }
1600
1601 pub fn toggle(&mut self, cx: &mut App) {
1602 if let Some(choice) = self.choice.as_mut() {
1603 let new_choice = choice.update(cx, |choice, _cx| {
1604 let new_choice = choice.toggle();
1605 *choice = new_choice;
1606 new_choice
1607 });
1608
1609 db::write_and_log(cx, move || {
1610 KEY_VALUE_STORE.write_kvp(
1611 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1612 new_choice.is_enabled().to_string(),
1613 )
1614 });
1615 }
1616 }
1617}
1618
1619async fn llm_token_retry(
1620 llm_token: &LlmApiToken,
1621 client: &Arc<Client>,
1622 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1623) -> Result<Response<AsyncBody>> {
1624 let mut did_retry = false;
1625 let http_client = client.http_client();
1626 let mut token = llm_token.acquire(client).await?;
1627 loop {
1628 let request = build_request(token.clone())?;
1629 let response = http_client.send(request).await?;
1630
1631 if !did_retry
1632 && !response.status().is_success()
1633 && response
1634 .headers()
1635 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1636 .is_some()
1637 {
1638 did_retry = true;
1639 token = llm_token.refresh(client).await?;
1640 continue;
1641 }
1642
1643 return Ok(response);
1644 }
1645}
1646
/// Edit-prediction provider backed by a shared `Zeta` instance.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight prediction requests, oldest first; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id given to the next prediction request; incremented for every request.
    next_pending_completion_id: usize,
    /// The prediction currently offered to the user, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for the provider's buffer. This field is always present; its
    /// inner fields are `None` when data collection is not possible for that buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
1656
1657impl ZetaEditPredictionProvider {
1658 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1659
1660 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1661 Self {
1662 zeta,
1663 pending_completions: ArrayVec::new(),
1664 next_pending_completion_id: 0,
1665 current_completion: None,
1666 provider_data_collection,
1667 last_request_timestamp: Instant::now(),
1668 }
1669 }
1670}
1671
// Connects `Zeta` predictions to the editor's edit-prediction provider interface.
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled, together with whether the project was
    /// detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// Always enabled; this provider does not filter by buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// A refresh is in progress while any prediction request is still pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Requests a new prediction for `position` in `buffer`, unless none is needed.
    ///
    /// Returns early in three cases:
    /// - `Zeta` reports that an update is required;
    /// - the user's account is too young or has overdue invoices;
    /// - the current prediction can still be interpolated onto the buffer.
    ///
    /// Otherwise it spawns a task that:
    /// 1. waits until `THROTTLE_TIMEOUT` has elapsed since the previous request;
    /// 2. asks `Zeta` for a prediction;
    /// 3. installs the result as the current prediction when appropriate.
    ///
    /// `_debounce` is ignored; throttling is used instead.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction (and skip the request) while it still applies to the
        // latest buffer contents.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out whatever remains of THROTTLE_TIMEOUT since the last request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(
                        project.as_ref(),
                        &buffer,
                        position,
                        can_collect_data,
                        cx,
                    )
                })
            });

            // Merge "provider gone" and request errors, and tag the prediction with its buffer.
            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged. With no prediction, only the pending-queue cleanup runs.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this is the oldest pending request, dequeue just it. Otherwise it is
                    // the newer of two requests and finished first, so the older one is
                    // discarded as well.
                    // NOTE(review): `[0]` assumes this task's entry is still queued.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same pending-queue cleanup as in the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Show the new prediction if there is no current one, or if it should
                // supersede the current one.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current prediction (if any) to `Zeta` and drops all pending
    /// requests. The current prediction itself is not cleared here.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the edits of the current prediction nearest to the cursor.
    ///
    /// The current prediction is cleared (and `None` returned) when it was made for a
    /// different buffer or no longer interpolates onto `buffer`. Otherwise the edit whose
    /// start or end row is closest to the cursor row is selected. Surrounding edits are
    /// included while they lie within one row of that closest edit.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards over edits ending at most one row before the closest edit starts.
        // NOTE(review): the row subtractions here and below assume `edits` is sorted by
        // position; otherwise the u32 arithmetic would underflow.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards over edits starting at most one row after the closest edit ends.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1946
/// Estimates how many model tokens `bytes` bytes of text amount to, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    // Assumed bytes per token when limiting model input. Deliberately low, so token counts
    // are overestimated and limits err on the safe side.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1953
1954#[cfg(test)]
1955mod tests {
1956 use client::UserStore;
1957 use client::test::FakeServer;
1958 use clock::FakeSystemClock;
1959 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1960 use gpui::TestAppContext;
1961 use http_client::FakeHttpClient;
1962 use indoc::indoc;
1963 use language::Point;
1964 use settings::SettingsStore;
1965
1966 use super::*;
1967
1968 #[gpui::test]
1969 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1970 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1971 let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1972 to_completion_edits(
1973 [(2..5, "REM".to_string()), (9..11, "".to_string())],
1974 &buffer,
1975 cx,
1976 )
1977 .into()
1978 });
1979
1980 let edit_preview = cx
1981 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1982 .await;
1983
1984 let completion = EditPrediction {
1985 edits,
1986 edit_preview,
1987 path: Path::new("").into(),
1988 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1989 id: EditPredictionId(Uuid::new_v4()),
1990 excerpt_range: 0..0,
1991 cursor_offset: 0,
1992 input_outline: "".into(),
1993 input_events: "".into(),
1994 input_excerpt: "".into(),
1995 output_excerpt: "".into(),
1996 buffer_snapshotted_at: Instant::now(),
1997 response_received_at: Instant::now(),
1998 };
1999
2000 cx.update(|cx| {
2001 assert_eq!(
2002 from_completion_edits(
2003 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2004 &buffer,
2005 cx
2006 ),
2007 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2008 );
2009
2010 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2011 assert_eq!(
2012 from_completion_edits(
2013 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2014 &buffer,
2015 cx
2016 ),
2017 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
2018 );
2019
2020 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2021 assert_eq!(
2022 from_completion_edits(
2023 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2024 &buffer,
2025 cx
2026 ),
2027 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2028 );
2029
2030 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2031 assert_eq!(
2032 from_completion_edits(
2033 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2034 &buffer,
2035 cx
2036 ),
2037 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2038 );
2039
2040 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2041 assert_eq!(
2042 from_completion_edits(
2043 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2044 &buffer,
2045 cx
2046 ),
2047 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2048 );
2049
2050 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2051 assert_eq!(
2052 from_completion_edits(
2053 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2054 &buffer,
2055 cx
2056 ),
2057 vec![(9..11, "".to_string())]
2058 );
2059
2060 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2061 assert_eq!(
2062 from_completion_edits(
2063 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2064 &buffer,
2065 cx
2066 ),
2067 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2068 );
2069
2070 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2071 assert_eq!(
2072 from_completion_edits(
2073 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2074 &buffer,
2075 cx
2076 ),
2077 vec![(4..4, "M".to_string())]
2078 );
2079
2080 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2081 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2082 })
2083 }
2084
2085 #[gpui::test]
2086 async fn test_clean_up_diff(cx: &mut TestAppContext) {
2087 cx.update(|cx| {
2088 let settings_store = SettingsStore::test(cx);
2089 cx.set_global(settings_store);
2090 client::init_settings(cx);
2091 });
2092
2093 let edits = edits_for_prediction(
2094 indoc! {"
2095 fn main() {
2096 let word_1 = \"lorem\";
2097 let range = word.len()..word.len();
2098 }
2099 "},
2100 indoc! {"
2101 <|editable_region_start|>
2102 fn main() {
2103 let word_1 = \"lorem\";
2104 let range = word_1.len()..word_1.len();
2105 }
2106
2107 <|editable_region_end|>
2108 "},
2109 cx,
2110 )
2111 .await;
2112 assert_eq!(
2113 edits,
2114 [
2115 (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2116 (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2117 ]
2118 );
2119
2120 let edits = edits_for_prediction(
2121 indoc! {"
2122 fn main() {
2123 let story = \"the quick\"
2124 }
2125 "},
2126 indoc! {"
2127 <|editable_region_start|>
2128 fn main() {
2129 let story = \"the quick brown fox jumps over the lazy dog\";
2130 }
2131
2132 <|editable_region_end|>
2133 "},
2134 cx,
2135 )
2136 .await;
2137 assert_eq!(
2138 edits,
2139 [
2140 (
2141 Point::new(1, 26)..Point::new(1, 26),
2142 " brown fox jumps over the lazy dog".to_string()
2143 ),
2144 (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2145 ]
2146 );
2147 }
2148
2149 #[gpui::test]
2150 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2151 cx.update(|cx| {
2152 let settings_store = SettingsStore::test(cx);
2153 cx.set_global(settings_store);
2154 client::init_settings(cx);
2155 });
2156
2157 let buffer_content = "lorem\n";
2158 let completion_response = indoc! {"
2159 ```animals.js
2160 <|start_of_file|>
2161 <|editable_region_start|>
2162 lorem
2163 ipsum
2164 <|editable_region_end|>
2165 ```"};
2166
2167 let http_client = FakeHttpClient::create(move |req| async move {
2168 match (req.method(), req.uri().path()) {
2169 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2170 .status(200)
2171 .body(
2172 serde_json::to_string(&CreateLlmTokenResponse {
2173 token: LlmToken("the-llm-token".to_string()),
2174 })
2175 .unwrap()
2176 .into(),
2177 )
2178 .unwrap()),
2179 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2180 .status(200)
2181 .body(
2182 serde_json::to_string(&PredictEditsResponse {
2183 request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2184 .unwrap(),
2185 output_excerpt: completion_response.to_string(),
2186 })
2187 .unwrap()
2188 .into(),
2189 )
2190 .unwrap()),
2191 _ => Ok(http_client::Response::builder()
2192 .status(404)
2193 .body("Not Found".into())
2194 .unwrap()),
2195 }
2196 });
2197
2198 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2199 cx.update(|cx| {
2200 RefreshLlmTokenListener::register(client.clone(), cx);
2201 });
2202 // Construct the fake server to authenticate.
2203 let _server = FakeServer::for_client(42, &client, cx).await;
2204 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2205 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2206
2207 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2208 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2209 let completion_task = zeta.update(cx, |zeta, cx| {
2210 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2211 });
2212
2213 let completion = completion_task.await.unwrap().unwrap();
2214 buffer.update(cx, |buffer, cx| {
2215 buffer.edit(completion.edits.iter().cloned(), None, cx)
2216 });
2217 assert_eq!(
2218 buffer.read_with(cx, |buffer, _| buffer.text()),
2219 "lorem\nipsum"
2220 );
2221 }
2222
2223 async fn edits_for_prediction(
2224 buffer_content: &str,
2225 completion_response: &str,
2226 cx: &mut TestAppContext,
2227 ) -> Vec<(Range<Point>, String)> {
2228 let completion_response = completion_response.to_string();
2229 let http_client = FakeHttpClient::create(move |req| {
2230 let completion = completion_response.clone();
2231 async move {
2232 match (req.method(), req.uri().path()) {
2233 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2234 .status(200)
2235 .body(
2236 serde_json::to_string(&CreateLlmTokenResponse {
2237 token: LlmToken("the-llm-token".to_string()),
2238 })
2239 .unwrap()
2240 .into(),
2241 )
2242 .unwrap()),
2243 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2244 .status(200)
2245 .body(
2246 serde_json::to_string(&PredictEditsResponse {
2247 request_id: Uuid::new_v4(),
2248 output_excerpt: completion,
2249 })
2250 .unwrap()
2251 .into(),
2252 )
2253 .unwrap()),
2254 _ => Ok(http_client::Response::builder()
2255 .status(404)
2256 .body("Not Found".into())
2257 .unwrap()),
2258 }
2259 }
2260 });
2261
2262 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2263 cx.update(|cx| {
2264 RefreshLlmTokenListener::register(client.clone(), cx);
2265 });
2266 // Construct the fake server to authenticate.
2267 let _server = FakeServer::for_client(42, &client, cx).await;
2268 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2269 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2270
2271 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2272 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2273 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2274 let completion_task = zeta.update(cx, |zeta, cx| {
2275 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2276 });
2277
2278 let completion = completion_task.await.unwrap().unwrap();
2279 completion
2280 .edits
2281 .iter()
2282 .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2283 .collect::<Vec<_>>()
2284 }
2285
2286 fn to_completion_edits(
2287 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2288 buffer: &Entity<Buffer>,
2289 cx: &App,
2290 ) -> Vec<(Range<Anchor>, String)> {
2291 let buffer = buffer.read(cx);
2292 iterator
2293 .into_iter()
2294 .map(|(range, text)| {
2295 (
2296 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2297 text,
2298 )
2299 })
2300 .collect()
2301 }
2302
2303 fn from_completion_edits(
2304 editor_edits: &[(Range<Anchor>, String)],
2305 buffer: &Entity<Buffer>,
2306 cx: &App,
2307 ) -> Vec<(Range<usize>, String)> {
2308 let buffer = buffer.read(cx);
2309 editor_edits
2310 .iter()
2311 .map(|(range, text)| {
2312 (
2313 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2314 text.clone(),
2315 )
2316 })
2317 .collect()
2318 }
2319
    /// Sets up zlog's test logging for this test binary.
    ///
    /// The `#[ctor::ctor]` attribute makes this function run when the binary is
    /// loaded, before `main`. As a result it does not need to be called from
    /// any individual test.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2324}