1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::{Dismissable, KEY_VALUE_STORE};
11use edit_prediction::DataCollectionState;
12pub use init::*;
13use license_detection::LicenseDetectionWatcher;
14pub use rate_completion_modal::*;
15
16use anyhow::{Context as _, Result, anyhow};
17use arrayvec::ArrayVec;
18use client::{Client, EditPredictionUsage, UserStore};
19use cloud_llm_client::{
20 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
21 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
22};
23use collections::{HashMap, HashSet, VecDeque};
24use futures::AsyncReadExt;
25use gpui::{
26 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
27 SharedString, Subscription, Task, actions,
28};
29use http_client::{AsyncBody, HttpClient, Method, Request, Response};
30use input_excerpt::excerpt_for_cursor_position;
31use language::{
32 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
33};
34use language_model::{LlmApiToken, RefreshLlmTokenListener};
35use project::{Project, ProjectPath};
36use release_channel::AppVersion;
37use settings::WorktreeId;
38use std::collections::hash_map;
39use std::mem;
40use std::str::FromStr;
41use std::{
42 cmp,
43 fmt::Write,
44 future::Future,
45 ops::Range,
46 path::Path,
47 rc::Rc,
48 sync::Arc,
49 time::{Duration, Instant},
50};
51use telemetry_events::EditPredictionRating;
52use thiserror::Error;
53use util::ResultExt;
54use uuid::Uuid;
55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
56use worktree::Worktree;
57
/// Marker for the cursor position; `parse_edits` strips it from the model output.
58const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `parse_edits` accepts at most one occurrence in the model output.
59const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker opening the editable region; `parse_edits` requires exactly one occurrence.
60const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker closing the editable region; `parse_edits` requires exactly one occurrence.
61const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// `push_event` merges a buffer change into the previous event when it directly follows it
/// (same buffer, matching version) and arrives within this interval.
62const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the data collection choice (read as "true" / "false").
63const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
64
/// Passed to `excerpt_for_cursor_position` when building the input excerpt
/// (presumably the token budget for context around the editable region — confirm in `input_excerpt`).
65const MAX_CONTEXT_TOKENS: usize = 150;
/// Passed to `excerpt_for_cursor_position` when building the input excerpt
/// (presumably the token budget for the editable region itself — confirm in `input_excerpt`).
66const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget handed to `prompt_for_events` for the event history part of the prompt.
67const MAX_EVENT_TOKENS: usize = 500;
/// `gather_context` sends no diagnostics at all once the number of groups reaches this value.
68const MAX_DIAGNOSTIC_GROUPS: usize = 10;
69
// When the history already holds this many events, `push_event` drops the oldest half
// before appending the new one.
70/// Maximum number of events to track.
71const MAX_EVENT_COUNT: usize = 16;
72
73actions!(
74 edit_prediction,
75 [
76 /// Clears the edit prediction history.
77 ClearHistory
78 ]
79);
80
/// Identifier of a single edit prediction, wrapping a [`Uuid`].
///
/// `process_completion_response` builds it from the `request_id` of the prediction response.
81#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
82pub struct EditPredictionId(Uuid);
83
84impl From<EditPredictionId> for gpui::ElementId {
85 fn from(value: EditPredictionId) -> Self {
86 gpui::ElementId::Uuid(value.0)
87 }
88}
89
90impl std::fmt::Display for EditPredictionId {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{}", self.0)
93 }
94}
95
96struct ZedPredictUpsell;
97
98impl Dismissable for ZedPredictUpsell {
99 const KEY: &'static str = "dismissed-edit-predict-upsell";
100
101 fn dismissed() -> bool {
102 // To make this backwards compatible with older versions of Zed, we
103 // check if the user has seen the previous Edit Prediction Onboarding
104 // before, by checking the data collection choice which was written to
105 // the database once the user clicked on "Accept and Enable"
106 if KEY_VALUE_STORE
107 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
108 .log_err()
109 .is_some_and(|s| s.is_some())
110 {
111 return true;
112 }
113
114 KEY_VALUE_STORE
115 .read_kvp(Self::KEY)
116 .log_err()
117 .is_some_and(|s| s.is_some())
118 }
119}
120
121pub fn should_show_upsell_modal() -> bool {
122 !ZedPredictUpsell::dismissed()
123}
124
/// Newtype that stores the shared [`Zeta`] entity as a gpui global.
///
/// `Zeta::register` installs it and `Zeta::global` reads it back.
125#[derive(Clone)]
126struct ZetaGlobal(Entity<Zeta>);
127
128impl Global for ZetaGlobal {}
129
/// A processed prediction together with the request inputs that produced it.
130#[derive(Clone)]
131pub struct EditPrediction {
    /// Identifier built from the `request_id` of the prediction response.
132 id: EditPredictionId,
    /// Full path of the buffer's file, or `"untitled"` when the buffer has no file.
133 path: Arc<Path>,
    /// Offset range of the editable region that was sent to the model.
134 excerpt_range: Range<usize>,
    /// Offset of the cursor in the snapshot taken when the request started.
135 cursor_offset: usize,
    /// Anchored replacements (range, new text) proposed by the prediction.
136 edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot taken when the response was processed; `edits` were interpolated against it.
137 snapshot: BufferSnapshot,
    /// Result of `Buffer::preview_edits` for `edits`.
138 edit_preview: EditPreview,
    /// Outline from the request body (empty when the body carried none).
139 input_outline: Arc<str>,
    /// Event history prompt sent with the request.
140 input_events: Arc<str>,
    /// Excerpt prompt sent with the request.
141 input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
142 output_excerpt: Arc<str>,
    /// Instant captured at the start of the request, before context gathering.
143 buffer_snapshotted_at: Instant,
    /// Instant captured when this value was constructed, i.e. after the response was processed.
144 response_received_at: Instant,
145}
146
147impl EditPrediction {
148 fn latency(&self) -> Duration {
149 self.response_received_at
150 .duration_since(self.buffer_snapshotted_at)
151 }
152
153 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
154 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
155 }
156}
157
158fn interpolate(
159 old_snapshot: &BufferSnapshot,
160 new_snapshot: &BufferSnapshot,
161 current_edits: Arc<[(Range<Anchor>, String)]>,
162) -> Option<Vec<(Range<Anchor>, String)>> {
163 let mut edits = Vec::new();
164
165 let mut model_edits = current_edits.iter().peekable();
166 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
167 while let Some((model_old_range, _)) = model_edits.peek() {
168 let model_old_range = model_old_range.to_offset(old_snapshot);
169 if model_old_range.end < user_edit.old.start {
170 let (model_old_range, model_new_text) = model_edits.next().unwrap();
171 edits.push((model_old_range.clone(), model_new_text.clone()));
172 } else {
173 break;
174 }
175 }
176
177 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
178 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
179 if user_edit.old == model_old_offset_range {
180 let user_new_text = new_snapshot
181 .text_for_range(user_edit.new.clone())
182 .collect::<String>();
183
184 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
185 if !model_suffix.is_empty() {
186 let anchor = old_snapshot.anchor_after(user_edit.old.end);
187 edits.push((anchor..anchor, model_suffix.to_string()));
188 }
189
190 model_edits.next();
191 continue;
192 }
193 }
194 }
195
196 return None;
197 }
198
199 edits.extend(model_edits.cloned());
200
201 if edits.is_empty() { None } else { Some(edits) }
202}
203
204impl std::fmt::Debug for EditPrediction {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("EditPrediction")
207 .field("id", &self.id)
208 .field("path", &self.path)
209 .field("edits", &self.edits)
210 .finish_non_exhaustive()
211 }
212}
213
/// Edit prediction state for the application; stored as a gpui global via `ZetaGlobal`.
214pub struct Zeta {
    /// Per-project event history and registered buffers, keyed by the project's entity id.
215 projects: HashMap<EntityId, ZetaProject>,
    /// Client used for requests to the prediction endpoints and for flushing telemetry.
216 client: Arc<Client>,
    /// Predictions that were shown, newest first; `completion_shown` keeps at most 50.
217 shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions rated through `rate_completion`.
218 rated_completions: HashSet<EditPredictionId>,
    /// Data collection choice, loaded from the key-value store when `Zeta` is created.
219 data_collection_choice: Entity<DataCollectionChoice>,
    /// Token used to authorize requests to the prediction endpoints.
220 llm_token: LlmApiToken,
    /// Subscription that refreshes `llm_token` whenever the `RefreshLlmTokenListener` emits.
221 _llm_token_subscription: Subscription,
222 /// Whether an update to a newer version of Zed is required to continue using Zeta.
223 update_required: bool,
    /// Store that provides, and is updated with, the edit prediction usage.
224 user_store: Entity<UserStore>,
    /// One license detection watcher per worktree handed to `register`.
225 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
226}
227
/// State that `Zeta` tracks for a single project.
228struct ZetaProject {
    /// Recent buffer change events; `push_event` bounds the length using `MAX_EVENT_COUNT`.
229 events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id.
230 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
231}
232
233impl Zeta {
234 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
235 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
236 }
237
238 pub fn register(
239 worktree: Option<Entity<Worktree>>,
240 client: Arc<Client>,
241 user_store: Entity<UserStore>,
242 cx: &mut App,
243 ) -> Entity<Self> {
244 let this = Self::global(cx).unwrap_or_else(|| {
245 let entity = cx.new(|cx| Self::new(client, user_store, cx));
246 cx.set_global(ZetaGlobal(entity.clone()));
247 entity
248 });
249
250 this.update(cx, move |this, cx| {
251 if let Some(worktree) = worktree {
252 let worktree_id = worktree.read(cx).id();
253 this.license_detection_watchers
254 .entry(worktree_id)
255 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
256 }
257 });
258
259 this
260 }
261
262 pub fn clear_history(&mut self) {
263 for zeta_project in self.projects.values_mut() {
264 zeta_project.events.clear();
265 }
266 }
267
268 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
269 self.user_store.read(cx).edit_prediction_usage()
270 }
271
272 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
273 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
274
275 let data_collection_choice = Self::load_data_collection_choices();
276 let data_collection_choice = cx.new(|_| data_collection_choice);
277
278 Self {
279 projects: HashMap::default(),
280 client,
281 shown_completions: VecDeque::new(),
282 rated_completions: HashSet::default(),
283 data_collection_choice,
284 llm_token: LlmApiToken::default(),
285 _llm_token_subscription: cx.subscribe(
286 &refresh_llm_token_listener,
287 |this, _listener, _event, cx| {
288 let client = this.client.clone();
289 let llm_token = this.llm_token.clone();
290 cx.spawn(async move |_this, _cx| {
291 llm_token.refresh(&client).await?;
292 anyhow::Ok(())
293 })
294 .detach_and_log_err(cx);
295 },
296 ),
297 update_required: false,
298 license_detection_watchers: HashMap::default(),
299 user_store,
300 }
301 }
302
303 fn get_or_init_zeta_project(
304 &mut self,
305 project: &Entity<Project>,
306 cx: &mut Context<Self>,
307 ) -> &mut ZetaProject {
308 let project_id = project.entity_id();
309 match self.projects.entry(project_id) {
310 hash_map::Entry::Occupied(entry) => entry.into_mut(),
311 hash_map::Entry::Vacant(entry) => {
312 cx.observe_release(project, move |this, _, _cx| {
313 this.projects.remove(&project_id);
314 })
315 .detach();
316 entry.insert(ZetaProject {
317 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
318 registered_buffers: HashMap::default(),
319 })
320 }
321 }
322 }
323
324 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
325 let events = &mut zeta_project.events;
326
327 if let Some(Event::BufferChange {
328 new_snapshot: last_new_snapshot,
329 timestamp: last_timestamp,
330 ..
331 }) = events.back_mut()
332 {
333 // Coalesce edits for the same buffer when they happen one after the other.
334 let Event::BufferChange {
335 old_snapshot,
336 new_snapshot,
337 timestamp,
338 } = &event;
339
340 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
341 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
342 && old_snapshot.version == last_new_snapshot.version
343 {
344 *last_new_snapshot = new_snapshot.clone();
345 *last_timestamp = *timestamp;
346 return;
347 }
348 }
349
350 if events.len() >= MAX_EVENT_COUNT {
351 // These are halved instead of popping to improve prompt caching.
352 events.drain(..MAX_EVENT_COUNT / 2);
353 }
354
355 events.push_back(event);
356 }
357
358 pub fn register_buffer(
359 &mut self,
360 buffer: &Entity<Buffer>,
361 project: &Entity<Project>,
362 cx: &mut Context<Self>,
363 ) {
364 let zeta_project = self.get_or_init_zeta_project(project, cx);
365 Self::register_buffer_impl(zeta_project, buffer, project, cx);
366 }
367
368 fn register_buffer_impl<'a>(
369 zeta_project: &'a mut ZetaProject,
370 buffer: &Entity<Buffer>,
371 project: &Entity<Project>,
372 cx: &mut Context<Self>,
373 ) -> &'a mut RegisteredBuffer {
374 let buffer_id = buffer.entity_id();
375 match zeta_project.registered_buffers.entry(buffer_id) {
376 hash_map::Entry::Occupied(entry) => entry.into_mut(),
377 hash_map::Entry::Vacant(entry) => {
378 let snapshot = buffer.read(cx).snapshot();
379 let project_entity_id = project.entity_id();
380 entry.insert(RegisteredBuffer {
381 snapshot,
382 _subscriptions: [
383 cx.subscribe(buffer, {
384 let project = project.downgrade();
385 move |this, buffer, event, cx| {
386 if let language::BufferEvent::Edited = event
387 && let Some(project) = project.upgrade()
388 {
389 this.report_changes_for_buffer(&buffer, &project, cx);
390 }
391 }
392 }),
393 cx.observe_release(buffer, move |this, _buffer, _cx| {
394 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
395 else {
396 return;
397 };
398 zeta_project.registered_buffers.remove(&buffer_id);
399 }),
400 ],
401 })
402 }
403 }
404 }
405
    /// Requests an edit prediction for `cursor` in `buffer` and turns the response
    /// into an `EditPrediction`.
    ///
    /// `perform_predict_edits` is the function that actually talks to the model;
    /// it receives the client, token, app version and request body, and resolves to
    /// the response plus optional usage information. `can_collect_data` is forwarded
    /// into the request body and also gates whether git information is gathered.
    ///
    /// The returned task resolves to `Ok(None)` when response processing yields no
    /// prediction, and to an error when context gathering, the request, or the
    /// processing fails.
406 fn request_completion_impl<F, R>(
407 &mut self,
408 project: &Entity<Project>,
409 buffer: &Entity<Buffer>,
410 cursor: language::Anchor,
411 can_collect_data: bool,
412 cx: &mut Context<Self>,
413 perform_predict_edits: F,
414 ) -> Task<Result<Option<EditPrediction>>>
415 where
416 F: FnOnce(PerformPredictEditsParams) -> R + 'static,
417 R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
418 + Send
419 + 'static,
420 {
421 let buffer = buffer.clone();
        // Start of the latency measurements taken below.
422 let buffer_snapshotted_at = Instant::now();
        // Records any not-yet-reported buffer change as an event and yields the current snapshot.
423 let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
424 let zeta = cx.entity();
425 let events = self.get_or_init_zeta_project(project, cx).events.clone();
426 let client = self.client.clone();
427 let llm_token = self.llm_token.clone();
428 let app_version = AppVersion::global(cx);
429
        // Git information is only looked up when data collection is allowed and the buffer has a file.
430 let git_info = if let (true, Some(file)) = (can_collect_data, snapshot.file()) {
431 git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
432 } else {
433 None
434 };
435
436 let full_path: Arc<Path> = snapshot
437 .file()
438 .map(|f| Arc::from(f.full_path(cx).as_path()))
439 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
440 let full_path_str = full_path.to_string_lossy().to_string();
441 let cursor_point = cursor.to_point(&snapshot);
442 let cursor_offset = cursor_point.to_offset(&snapshot);
        // The events prompt is built lazily inside the task spawned by `gather_context`.
443 let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
444 let gather_task = gather_context(
445 project,
446 full_path_str,
447 &snapshot,
448 cursor_point,
449 make_events_prompt,
450 can_collect_data,
451 git_info,
452 cx,
453 );
454
455 cx.spawn(async move |this, cx| {
456 let GatherContextOutput {
457 body,
458 editable_range,
459 } = gather_task.await?;
460 let done_gathering_context_at = Instant::now();
461
462 log::debug!(
463 "Events:\n{}\nExcerpt:\n{:?}",
464 body.input_events,
465 body.input_excerpt
466 );
467
            // Keep copies of the inputs; `body` is moved into the request below.
468 let input_outline = body.outline.clone().unwrap_or_default();
469 let input_events = body.input_events.clone();
470 let input_excerpt = body.input_excerpt.clone();
471
472 let response = perform_predict_edits(PerformPredictEditsParams {
473 client,
474 llm_token,
475 app_version,
476 body,
477 })
478 .await;
479 let (response, usage) = match response {
480 Ok(response) => response,
481 Err(err) => {
                    // A too-old client is remembered on `Zeta` and surfaced to the user with an
                    // update link; the error is still returned to the caller afterwards.
482 if err.is::<ZedUpdateRequiredError>() {
483 cx.update(|cx| {
484 zeta.update(cx, |zeta, _cx| {
485 zeta.update_required = true;
486 });
487
488 let error_message: SharedString = err.to_string().into();
489 show_app_notification(
490 NotificationId::unique::<ZedUpdateRequiredError>(),
491 cx,
492 move |cx| {
493 cx.new(|cx| {
494 ErrorMessagePrompt::new(error_message.clone(), cx)
495 .with_link_button(
496 "Update Zed",
497 "https://zed.dev/releases",
498 )
499 })
500 },
501 );
502 })
503 .ok();
504 }
505
506 return Err(err);
507 }
508 };
509
510 let received_response_at = Instant::now();
511 log::debug!("completion response: {}", &response.output_excerpt);
512
            // Failure to update the usage (the `update` call erroring) is deliberately ignored.
513 if let Some(usage) = usage {
514 this.update(cx, |this, cx| {
515 this.user_store.update(cx, |user_store, cx| {
516 user_store.update_edit_prediction_usage(usage, cx);
517 });
518 })
519 .ok();
520 }
521
522 let edit_prediction = Self::process_completion_response(
523 response,
524 buffer,
525 &snapshot,
526 editable_range,
527 cursor_offset,
528 full_path,
529 input_outline,
530 input_events,
531 input_excerpt,
532 buffer_snapshotted_at,
533 cx,
534 )
535 .await;
536
537 let finished_at = Instant::now();
538
            // A random `u8` is <= 2 for 3 of its 256 values, hence the "~1%" sampling rate.
539 // record latency for ~1% of requests
540 if rand::random::<u8>() <= 2 {
541 telemetry::event!(
542 "Edit Prediction Request",
543 context_latency = done_gathering_context_at
544 .duration_since(buffer_snapshotted_at)
545 .as_millis(),
546 request_latency = received_response_at
547 .duration_since(done_gathering_context_at)
548 .as_millis(),
549 process_latency = finished_at.duration_since(received_response_at).as_millis()
550 );
551 }
552
553 edit_prediction
554 })
555 }
556
557 #[cfg(any(test, feature = "test-support"))]
558 pub fn fake_completion(
559 &mut self,
560 project: &Entity<Project>,
561 buffer: &Entity<Buffer>,
562 position: language::Anchor,
563 response: PredictEditsResponse,
564 cx: &mut Context<Self>,
565 ) -> Task<Result<Option<EditPrediction>>> {
566 use std::future::ready;
567
568 self.request_completion_impl(project, buffer, position, false, cx, |_params| {
569 ready(Ok((response, None)))
570 })
571 }
572
573 pub fn request_completion(
574 &mut self,
575 project: &Entity<Project>,
576 buffer: &Entity<Buffer>,
577 position: language::Anchor,
578 can_collect_data: bool,
579 cx: &mut Context<Self>,
580 ) -> Task<Result<Option<EditPrediction>>> {
581 self.request_completion_impl(
582 project,
583 buffer,
584 position,
585 can_collect_data,
586 cx,
587 Self::perform_predict_edits,
588 )
589 }
590
591 pub fn perform_predict_edits(
592 params: PerformPredictEditsParams,
593 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
594 async move {
595 let PerformPredictEditsParams {
596 client,
597 llm_token,
598 app_version,
599 body,
600 ..
601 } = params;
602
603 let http_client = client.http_client();
604 let mut token = llm_token.acquire(&client).await?;
605 let mut did_retry = false;
606
607 loop {
608 let request_builder = http_client::Request::builder().method(Method::POST);
609 let request_builder =
610 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
611 request_builder.uri(predict_edits_url)
612 } else {
613 request_builder.uri(
614 http_client
615 .build_zed_llm_url("/predict_edits/v2", &[])?
616 .as_ref(),
617 )
618 };
619 let request = request_builder
620 .header("Content-Type", "application/json")
621 .header("Authorization", format!("Bearer {}", token))
622 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
623 .body(serde_json::to_string(&body)?.into())?;
624
625 let mut response = http_client.send(request).await?;
626
627 if let Some(minimum_required_version) = response
628 .headers()
629 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
630 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
631 {
632 anyhow::ensure!(
633 app_version >= minimum_required_version,
634 ZedUpdateRequiredError {
635 minimum_version: minimum_required_version
636 }
637 );
638 }
639
640 if response.status().is_success() {
641 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
642
643 let mut body = String::new();
644 response.body_mut().read_to_string(&mut body).await?;
645 return Ok((serde_json::from_str(&body)?, usage));
646 } else if !did_retry
647 && response
648 .headers()
649 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
650 .is_some()
651 {
652 did_retry = true;
653 token = llm_token.refresh(&client).await?;
654 } else {
655 let mut body = String::new();
656 response.body_mut().read_to_string(&mut body).await?;
657 anyhow::bail!(
658 "error predicting edits.\nStatus: {:?}\nBody: {}",
659 response.status(),
660 body
661 );
662 }
663 }
664 }
665 }
666
667 fn accept_edit_prediction(
668 &mut self,
669 request_id: EditPredictionId,
670 cx: &mut Context<Self>,
671 ) -> Task<Result<()>> {
672 let client = self.client.clone();
673 let llm_token = self.llm_token.clone();
674 let app_version = AppVersion::global(cx);
675 cx.spawn(async move |this, cx| {
676 let http_client = client.http_client();
677 let mut response = llm_token_retry(&llm_token, &client, |token| {
678 let request_builder = http_client::Request::builder().method(Method::POST);
679 let request_builder =
680 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
681 request_builder.uri(accept_prediction_url)
682 } else {
683 request_builder.uri(
684 http_client
685 .build_zed_llm_url("/predict_edits/accept", &[])?
686 .as_ref(),
687 )
688 };
689 Ok(request_builder
690 .header("Content-Type", "application/json")
691 .header("Authorization", format!("Bearer {}", token))
692 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
693 .body(
694 serde_json::to_string(&AcceptEditPredictionBody {
695 request_id: request_id.0,
696 })?
697 .into(),
698 )?)
699 })
700 .await?;
701
702 if let Some(minimum_required_version) = response
703 .headers()
704 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
705 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
706 && app_version < minimum_required_version
707 {
708 return Err(anyhow!(ZedUpdateRequiredError {
709 minimum_version: minimum_required_version
710 }));
711 }
712
713 if response.status().is_success() {
714 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
715 this.update(cx, |this, cx| {
716 this.user_store.update(cx, |user_store, cx| {
717 user_store.update_edit_prediction_usage(usage, cx);
718 });
719 })?;
720 }
721
722 Ok(())
723 } else {
724 let mut body = String::new();
725 response.body_mut().read_to_string(&mut body).await?;
726 Err(anyhow!(
727 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
728 response.status(),
729 body
730 ))
731 }
732 })
733 }
734
735 fn process_completion_response(
736 prediction_response: PredictEditsResponse,
737 buffer: Entity<Buffer>,
738 snapshot: &BufferSnapshot,
739 editable_range: Range<usize>,
740 cursor_offset: usize,
741 path: Arc<Path>,
742 input_outline: String,
743 input_events: String,
744 input_excerpt: String,
745 buffer_snapshotted_at: Instant,
746 cx: &AsyncApp,
747 ) -> Task<Result<Option<EditPrediction>>> {
748 let snapshot = snapshot.clone();
749 let request_id = prediction_response.request_id;
750 let output_excerpt = prediction_response.output_excerpt;
751 cx.spawn(async move |cx| {
752 let output_excerpt: Arc<str> = output_excerpt.into();
753
754 let edits: Arc<[(Range<Anchor>, String)]> = cx
755 .background_spawn({
756 let output_excerpt = output_excerpt.clone();
757 let editable_range = editable_range.clone();
758 let snapshot = snapshot.clone();
759 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
760 })
761 .await?
762 .into();
763
764 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
765 let edits = edits.clone();
766 |buffer, cx| {
767 let new_snapshot = buffer.snapshot();
768 let edits: Arc<[(Range<Anchor>, String)]> =
769 interpolate(&snapshot, &new_snapshot, edits)?.into();
770 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
771 }
772 })?
773 else {
774 return anyhow::Ok(None);
775 };
776
777 let edit_preview = edit_preview.await;
778
779 Ok(Some(EditPrediction {
780 id: EditPredictionId(request_id),
781 path,
782 excerpt_range: editable_range,
783 cursor_offset,
784 edits,
785 edit_preview,
786 snapshot,
787 input_outline: input_outline.into(),
788 input_events: input_events.into(),
789 input_excerpt: input_excerpt.into(),
790 output_excerpt,
791 buffer_snapshotted_at,
792 response_received_at: Instant::now(),
793 }))
794 })
795 }
796
797 fn parse_edits(
798 output_excerpt: Arc<str>,
799 editable_range: Range<usize>,
800 snapshot: &BufferSnapshot,
801 ) -> Result<Vec<(Range<Anchor>, String)>> {
802 let content = output_excerpt.replace(CURSOR_MARKER, "");
803
804 let start_markers = content
805 .match_indices(EDITABLE_REGION_START_MARKER)
806 .collect::<Vec<_>>();
807 anyhow::ensure!(
808 start_markers.len() == 1,
809 "expected exactly one start marker, found {}",
810 start_markers.len()
811 );
812
813 let end_markers = content
814 .match_indices(EDITABLE_REGION_END_MARKER)
815 .collect::<Vec<_>>();
816 anyhow::ensure!(
817 end_markers.len() == 1,
818 "expected exactly one end marker, found {}",
819 end_markers.len()
820 );
821
822 let sof_markers = content
823 .match_indices(START_OF_FILE_MARKER)
824 .collect::<Vec<_>>();
825 anyhow::ensure!(
826 sof_markers.len() <= 1,
827 "expected at most one start-of-file marker, found {}",
828 sof_markers.len()
829 );
830
831 let codefence_start = start_markers[0].0;
832 let content = &content[codefence_start..];
833
834 let newline_ix = content.find('\n').context("could not find newline")?;
835 let content = &content[newline_ix + 1..];
836
837 let codefence_end = content
838 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
839 .context("could not find end marker")?;
840 let new_text = &content[..codefence_end];
841
842 let old_text = snapshot
843 .text_for_range(editable_range.clone())
844 .collect::<String>();
845
846 Ok(Self::compute_edits(
847 old_text,
848 new_text,
849 editable_range.start,
850 snapshot,
851 ))
852 }
853
854 pub fn compute_edits(
855 old_text: String,
856 new_text: &str,
857 offset: usize,
858 snapshot: &BufferSnapshot,
859 ) -> Vec<(Range<Anchor>, String)> {
860 text_diff(&old_text, new_text)
861 .into_iter()
862 .map(|(mut old_range, new_text)| {
863 old_range.start += offset;
864 old_range.end += offset;
865
866 let prefix_len = common_prefix(
867 snapshot.chars_for_range(old_range.clone()),
868 new_text.chars(),
869 );
870 old_range.start += prefix_len;
871
872 let suffix_len = common_prefix(
873 snapshot.reversed_chars_for_range(old_range.clone()),
874 new_text[prefix_len..].chars().rev(),
875 );
876 old_range.end = old_range.end.saturating_sub(suffix_len);
877
878 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
879 let range = if old_range.is_empty() {
880 let anchor = snapshot.anchor_after(old_range.start);
881 anchor..anchor
882 } else {
883 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
884 };
885 (range, new_text)
886 })
887 .collect()
888 }
889
890 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
891 self.rated_completions.contains(&completion_id)
892 }
893
894 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
895 self.shown_completions.push_front(completion.clone());
896 if self.shown_completions.len() > 50 {
897 let completion = self.shown_completions.pop_back().unwrap();
898 self.rated_completions.remove(&completion.id);
899 }
900 cx.notify();
901 }
902
903 pub fn rate_completion(
904 &mut self,
905 completion: &EditPrediction,
906 rating: EditPredictionRating,
907 feedback: String,
908 cx: &mut Context<Self>,
909 ) {
910 self.rated_completions.insert(completion.id);
911 telemetry::event!(
912 "Edit Prediction Rated",
913 rating,
914 input_events = completion.input_events,
915 input_excerpt = completion.input_excerpt,
916 input_outline = completion.input_outline,
917 output_excerpt = completion.output_excerpt,
918 feedback
919 );
920 self.client.telemetry().flush_events().detach();
921 cx.notify();
922 }
923
924 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
925 self.shown_completions.iter()
926 }
927
928 pub fn shown_completions_len(&self) -> usize {
929 self.shown_completions.len()
930 }
931
932 fn report_changes_for_buffer(
933 &mut self,
934 buffer: &Entity<Buffer>,
935 project: &Entity<Project>,
936 cx: &mut Context<Self>,
937 ) -> BufferSnapshot {
938 let zeta_project = self.get_or_init_zeta_project(project, cx);
939 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
940
941 let new_snapshot = buffer.read(cx).snapshot();
942 if new_snapshot.version != registered_buffer.snapshot.version {
943 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
944 Self::push_event(
945 zeta_project,
946 Event::BufferChange {
947 old_snapshot,
948 new_snapshot: new_snapshot.clone(),
949 timestamp: Instant::now(),
950 },
951 );
952 }
953
954 new_snapshot
955 }
956
957 fn load_data_collection_choices() -> DataCollectionChoice {
958 let choice = KEY_VALUE_STORE
959 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
960 .log_err()
961 .flatten();
962
963 match choice.as_deref() {
964 Some("true") => DataCollectionChoice::Enabled,
965 Some("false") => DataCollectionChoice::Disabled,
966 Some(_) => {
967 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
968 DataCollectionChoice::NotAnswered
969 }
970 None => DataCollectionChoice::NotAnswered,
971 }
972 }
973}
974
/// Everything `Zeta::perform_predict_edits` needs to issue one prediction request.
975pub struct PerformPredictEditsParams {
    /// Client that provides the HTTP client and is used to acquire or refresh the token.
976 pub client: Arc<Client>,
    /// Token sent as the bearer credential of the request.
977 pub llm_token: LlmApiToken,
    /// Version sent in the version header and compared against the server's minimum version.
978 pub app_version: SemanticVersion,
    /// Request payload, serialized as JSON.
979 pub body: PredictEditsBody,
980}
981
/// Error returned when the server advertises a minimum version newer than the running app.
982#[derive(Error, Debug)]
983#[error(
984 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
985)]
986pub struct ZedUpdateRequiredError {
    /// Minimum version taken from the response header.
987 minimum_version: SemanticVersion,
988}
989
/// Returns the length in UTF-8 bytes of the longest run of leading characters
/// that `a` and `b` have in common.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
996
997fn git_info_for_file(
998 project: &Entity<Project>,
999 project_path: &ProjectPath,
1000 cx: &App,
1001) -> Option<PredictEditsGitInfo> {
1002 let git_store = project.read(cx).git_store().read(cx);
1003 if let Some((repository, _repo_path)) =
1004 git_store.repository_and_path_for_project_path(project_path, cx)
1005 {
1006 let repository = repository.read(cx);
1007 let head_sha = repository
1008 .head_commit
1009 .as_ref()
1010 .map(|head_commit| head_commit.sha.to_string());
1011 let remote_origin_url = repository.remote_origin_url.clone();
1012 let remote_upstream_url = repository.remote_upstream_url.clone();
1013 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1014 return None;
1015 }
1016 Some(PredictEditsGitInfo {
1017 head_sha,
1018 remote_origin_url,
1019 remote_upstream_url,
1020 })
1021 } else {
1022 None
1023 }
1024}
1025
/// Result of `gather_context`.
1026pub struct GatherContextOutput {
    /// Request body to send to the prediction endpoint.
1027 pub body: PredictEditsBody,
    /// Offset range of the input excerpt's editable region in the snapshot.
1028 pub editable_range: Range<usize>,
1029}
1030
1031pub fn gather_context(
1032 project: &Entity<Project>,
1033 full_path_str: String,
1034 snapshot: &BufferSnapshot,
1035 cursor_point: language::Point,
1036 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1037 can_collect_data: bool,
1038 git_info: Option<PredictEditsGitInfo>,
1039 cx: &App,
1040) -> Task<Result<GatherContextOutput>> {
1041 let local_lsp_store = project.read(cx).lsp_store().read(cx).as_local();
1042 let diagnostic_groups: Vec<(String, serde_json::Value)> =
1043 if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
1044 snapshot
1045 .diagnostic_groups(None)
1046 .into_iter()
1047 .filter_map(|(language_server_id, diagnostic_group)| {
1048 let language_server =
1049 local_lsp_store.running_language_server_for_id(language_server_id)?;
1050 let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1051 let language_server_name = language_server.name().to_string();
1052 let serialized = serde_json::to_value(diagnostic_group).unwrap();
1053 Some((language_server_name, serialized))
1054 })
1055 .collect::<Vec<_>>()
1056 } else {
1057 Vec::new()
1058 };
1059
1060 cx.background_spawn({
1061 let snapshot = snapshot.clone();
1062 async move {
1063 let diagnostic_groups = if diagnostic_groups.is_empty()
1064 || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1065 {
1066 None
1067 } else {
1068 Some(diagnostic_groups)
1069 };
1070
1071 let input_excerpt = excerpt_for_cursor_position(
1072 cursor_point,
1073 &full_path_str,
1074 &snapshot,
1075 MAX_REWRITE_TOKENS,
1076 MAX_CONTEXT_TOKENS,
1077 );
1078 let input_events = make_events_prompt();
1079 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1080
1081 let body = PredictEditsBody {
1082 input_events,
1083 input_excerpt: input_excerpt.prompt,
1084 can_collect_data,
1085 diagnostic_groups,
1086 git_info,
1087 outline: None,
1088 speculated_output: None,
1089 };
1090
1091 Ok(GatherContextOutput {
1092 body,
1093 editable_range,
1094 })
1095 }
1096 })
1097}
1098
1099fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1100 let mut result = String::new();
1101 for event in events.iter().rev() {
1102 let event_string = event.to_prompt();
1103 let event_tokens = tokens_for_bytes(event_string.len());
1104 if event_tokens > remaining_tokens {
1105 break;
1106 }
1107
1108 if !result.is_empty() {
1109 result.insert_str(0, "\n\n");
1110 }
1111 result.insert_str(0, &event_string);
1112 remaining_tokens -= event_tokens;
1113 }
1114 result
1115}
1116
/// Per-buffer bookkeeping: a stored snapshot plus the subscriptions tied to the buffer.
struct RegisteredBuffer {
    /// A snapshot of the buffer's contents (presumably the most recently observed one —
    /// confirm against the code that updates it).
    snapshot: BufferSnapshot,
    /// Two subscriptions owned by this entry. The field is underscore-prefixed and never
    /// read, so it is held for ownership only.
    _subscriptions: [gpui::Subscription; 2],
}
1121
/// A recorded user action that can be rendered into the prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change; the "old" side of the rendered diff.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change; the "new" side of the rendered diff.
        new_snapshot: BufferSnapshot,
        /// Time associated with the change. It is ignored when rendering the prompt.
        timestamp: Instant,
    },
}
1130
1131impl Event {
1132 fn to_prompt(&self) -> String {
1133 match self {
1134 Event::BufferChange {
1135 old_snapshot,
1136 new_snapshot,
1137 ..
1138 } => {
1139 let mut prompt = String::new();
1140
1141 let old_path = old_snapshot
1142 .file()
1143 .map(|f| f.path().as_ref())
1144 .unwrap_or(Path::new("untitled"));
1145 let new_path = new_snapshot
1146 .file()
1147 .map(|f| f.path().as_ref())
1148 .unwrap_or(Path::new("untitled"));
1149 if old_path != new_path {
1150 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1151 }
1152
1153 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1154 if !diff.is_empty() {
1155 write!(
1156 prompt,
1157 "User edited {:?}:\n```diff\n{}\n```",
1158 new_path, diff
1159 )
1160 .unwrap();
1161 }
1162
1163 prompt
1164 }
1165 }
1166 }
1167}
1168
/// A prediction held by the provider, tagged with the buffer it was produced for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
}
1174
1175impl CurrentEditPrediction {
1176 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1177 if self.buffer_id != old_completion.buffer_id {
1178 return true;
1179 }
1180
1181 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1182 return true;
1183 };
1184 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1185 return false;
1186 };
1187
1188 if old_edits.len() == 1 && new_edits.len() == 1 {
1189 let (old_range, old_text) = &old_edits[0];
1190 let (new_range, new_text) = &new_edits[0];
1191 new_range == old_range && new_text.starts_with(old_text)
1192 } else {
1193 true
1194 }
1195 }
1196}
1197
/// An in-flight prediction request tracked by the provider.
struct PendingCompletion {
    /// Identifier handed out from the provider's `next_pending_completion_id` counter.
    id: usize,
    /// The spawned request task. It is never read; it is stored so the task is owned by
    /// this entry (presumably dropping it cancels the request — confirm against gpui `Task`).
    _task: Task<()>,
}
1202
/// The user's answer to whether prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been given yet.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`Self::Enabled`]; an unanswered choice counts as disabled.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the choice is either [`Self::Enabled`] or [`Self::Disabled`].
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice. An unanswered choice toggles to [`Self::Enabled`].
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to [`DataCollectionChoice::Enabled`] and `false` to
    /// [`DataCollectionChoice::Disabled`]; a `bool` never yields `NotAnswered`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1242
/// Data-collection state attached to one edit prediction provider.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher consulted to decide whether the project counts as open source. It is
    /// `None` whenever `choice` is `None` (both are set together in `new`).
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1248
1249impl ProviderDataCollection {
1250 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1251 let choice_and_watcher = buffer.and_then(|buffer| {
1252 let file = buffer.read(cx).file()?;
1253
1254 if !file.is_local() || file.is_private() {
1255 return None;
1256 }
1257
1258 let zeta = zeta.read(cx);
1259 let choice = zeta.data_collection_choice.clone();
1260
1261 let license_detection_watcher = zeta
1262 .license_detection_watchers
1263 .get(&file.worktree_id(cx))
1264 .cloned()?;
1265
1266 Some((choice, license_detection_watcher))
1267 });
1268
1269 if let Some((choice, watcher)) = choice_and_watcher {
1270 ProviderDataCollection {
1271 choice: Some(choice),
1272 license_detection_watcher: Some(watcher),
1273 }
1274 } else {
1275 ProviderDataCollection {
1276 choice: None,
1277 license_detection_watcher: None,
1278 }
1279 }
1280 }
1281
1282 pub fn can_collect_data(&self, cx: &App) -> bool {
1283 self.is_data_collection_enabled(cx) && self.is_project_open_source()
1284 }
1285
1286 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1287 self.choice
1288 .as_ref()
1289 .is_some_and(|choice| choice.read(cx).is_enabled())
1290 }
1291
1292 fn is_project_open_source(&self) -> bool {
1293 self.license_detection_watcher
1294 .as_ref()
1295 .is_some_and(|watcher| watcher.is_project_open_source())
1296 }
1297
1298 pub fn toggle(&mut self, cx: &mut App) {
1299 if let Some(choice) = self.choice.as_mut() {
1300 let new_choice = choice.update(cx, |choice, _cx| {
1301 let new_choice = choice.toggle();
1302 *choice = new_choice;
1303 new_choice
1304 });
1305
1306 db::write_and_log(cx, move || {
1307 KEY_VALUE_STORE.write_kvp(
1308 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1309 new_choice.is_enabled().to_string(),
1310 )
1311 });
1312 }
1313 }
1314}
1315
1316async fn llm_token_retry(
1317 llm_token: &LlmApiToken,
1318 client: &Arc<Client>,
1319 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1320) -> Result<Response<AsyncBody>> {
1321 let mut did_retry = false;
1322 let http_client = client.http_client();
1323 let mut token = llm_token.acquire(client).await?;
1324 loop {
1325 let request = build_request(token.clone())?;
1326 let response = http_client.send(request).await?;
1327
1328 if !did_retry
1329 && !response.status().is_success()
1330 && response
1331 .headers()
1332 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1333 .is_some()
1334 {
1335 did_retry = true;
1336 token = llm_token.refresh(client).await?;
1337 continue;
1338 }
1339
1340 return Ok(response);
1341 }
1342}
1343
/// Edit prediction provider backed by a shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    /// The `Zeta` entity that performs the prediction requests.
    zeta: Entity<Zeta>,
    /// Requests currently in flight; the capacity bounds this at two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id assigned to the next request; incremented on every `refresh` that spawns one.
    next_pending_completion_id: usize,
    /// The prediction currently offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data collection state for this provider. This field is always present; when
    /// collection is not possible, that is recorded inside `ProviderDataCollection`
    /// itself (its `choice` is `None`).
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used to throttle the next one.
    last_request_timestamp: Instant,
}
1353
1354impl ZetaEditPredictionProvider {
1355 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1356
1357 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1358 Self {
1359 zeta,
1360 pending_completions: ArrayVec::new(),
1361 next_pending_completion_id: 0,
1362 current_completion: None,
1363 provider_data_collection,
1364 last_request_timestamp: Instant::now(),
1365 }
1366 }
1367}
1368
1369impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1370 fn name() -> &'static str {
1371 "zed-predict"
1372 }
1373
1374 fn display_name() -> &'static str {
1375 "Zed's Edit Predictions"
1376 }
1377
1378 fn show_completions_in_menu() -> bool {
1379 true
1380 }
1381
1382 fn show_tab_accept_marker() -> bool {
1383 true
1384 }
1385
1386 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1387 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1388
1389 if self.provider_data_collection.is_data_collection_enabled(cx) {
1390 DataCollectionState::Enabled {
1391 is_project_open_source,
1392 }
1393 } else {
1394 DataCollectionState::Disabled {
1395 is_project_open_source,
1396 }
1397 }
1398 }
1399
1400 fn toggle_data_collection(&mut self, cx: &mut App) {
1401 self.provider_data_collection.toggle(cx);
1402 }
1403
1404 fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1405 self.zeta.read(cx).usage(cx)
1406 }
1407
1408 fn is_enabled(
1409 &self,
1410 _buffer: &Entity<Buffer>,
1411 _cursor_position: language::Anchor,
1412 _cx: &App,
1413 ) -> bool {
1414 true
1415 }
1416 fn is_refreshing(&self) -> bool {
1417 !self.pending_completions.is_empty()
1418 }
1419
1420 fn refresh(
1421 &mut self,
1422 project: Option<Entity<Project>>,
1423 buffer: Entity<Buffer>,
1424 position: language::Anchor,
1425 _debounce: bool,
1426 cx: &mut Context<Self>,
1427 ) {
1428 if self.zeta.read(cx).update_required {
1429 return;
1430 }
1431 let Some(project) = project else {
1432 return;
1433 };
1434
1435 if self
1436 .zeta
1437 .read(cx)
1438 .user_store
1439 .read_with(cx, |user_store, _cx| {
1440 user_store.account_too_young() || user_store.has_overdue_invoices()
1441 })
1442 {
1443 return;
1444 }
1445
1446 if let Some(current_completion) = self.current_completion.as_ref() {
1447 let snapshot = buffer.read(cx).snapshot();
1448 if current_completion
1449 .completion
1450 .interpolate(&snapshot)
1451 .is_some()
1452 {
1453 return;
1454 }
1455 }
1456
1457 let pending_completion_id = self.next_pending_completion_id;
1458 self.next_pending_completion_id += 1;
1459 let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1460 let last_request_timestamp = self.last_request_timestamp;
1461
1462 let task = cx.spawn(async move |this, cx| {
1463 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1464 .checked_duration_since(Instant::now())
1465 {
1466 cx.background_executor().timer(timeout).await;
1467 }
1468
1469 let completion_request = this.update(cx, |this, cx| {
1470 this.last_request_timestamp = Instant::now();
1471 this.zeta.update(cx, |zeta, cx| {
1472 zeta.request_completion(&project, &buffer, position, can_collect_data, cx)
1473 })
1474 });
1475
1476 let completion = match completion_request {
1477 Ok(completion_request) => {
1478 let completion_request = completion_request.await;
1479 completion_request.map(|c| {
1480 c.map(|completion| CurrentEditPrediction {
1481 buffer_id: buffer.entity_id(),
1482 completion,
1483 })
1484 })
1485 }
1486 Err(error) => Err(error),
1487 };
1488 let Some(new_completion) = completion
1489 .context("edit prediction failed")
1490 .log_err()
1491 .flatten()
1492 else {
1493 this.update(cx, |this, cx| {
1494 if this.pending_completions[0].id == pending_completion_id {
1495 this.pending_completions.remove(0);
1496 } else {
1497 this.pending_completions.clear();
1498 }
1499
1500 cx.notify();
1501 })
1502 .ok();
1503 return;
1504 };
1505
1506 this.update(cx, |this, cx| {
1507 if this.pending_completions[0].id == pending_completion_id {
1508 this.pending_completions.remove(0);
1509 } else {
1510 this.pending_completions.clear();
1511 }
1512
1513 if let Some(old_completion) = this.current_completion.as_ref() {
1514 let snapshot = buffer.read(cx).snapshot();
1515 if new_completion.should_replace_completion(old_completion, &snapshot) {
1516 this.zeta.update(cx, |zeta, cx| {
1517 zeta.completion_shown(&new_completion.completion, cx);
1518 });
1519 this.current_completion = Some(new_completion);
1520 }
1521 } else {
1522 this.zeta.update(cx, |zeta, cx| {
1523 zeta.completion_shown(&new_completion.completion, cx);
1524 });
1525 this.current_completion = Some(new_completion);
1526 }
1527
1528 cx.notify();
1529 })
1530 .ok();
1531 });
1532
1533 // We always maintain at most two pending completions. When we already
1534 // have two, we replace the newest one.
1535 if self.pending_completions.len() <= 1 {
1536 self.pending_completions.push(PendingCompletion {
1537 id: pending_completion_id,
1538 _task: task,
1539 });
1540 } else if self.pending_completions.len() == 2 {
1541 self.pending_completions.pop();
1542 self.pending_completions.push(PendingCompletion {
1543 id: pending_completion_id,
1544 _task: task,
1545 });
1546 }
1547 }
1548
1549 fn cycle(
1550 &mut self,
1551 _buffer: Entity<Buffer>,
1552 _cursor_position: language::Anchor,
1553 _direction: edit_prediction::Direction,
1554 _cx: &mut Context<Self>,
1555 ) {
1556 // Right now we don't support cycling.
1557 }
1558
1559 fn accept(&mut self, cx: &mut Context<Self>) {
1560 let completion_id = self
1561 .current_completion
1562 .as_ref()
1563 .map(|completion| completion.completion.id);
1564 if let Some(completion_id) = completion_id {
1565 self.zeta
1566 .update(cx, |zeta, cx| {
1567 zeta.accept_edit_prediction(completion_id, cx)
1568 })
1569 .detach();
1570 }
1571 self.pending_completions.clear();
1572 }
1573
1574 fn discard(&mut self, _cx: &mut Context<Self>) {
1575 self.pending_completions.clear();
1576 self.current_completion.take();
1577 }
1578
1579 fn suggest(
1580 &mut self,
1581 buffer: &Entity<Buffer>,
1582 cursor_position: language::Anchor,
1583 cx: &mut Context<Self>,
1584 ) -> Option<edit_prediction::EditPrediction> {
1585 let CurrentEditPrediction {
1586 buffer_id,
1587 completion,
1588 ..
1589 } = self.current_completion.as_mut()?;
1590
1591 // Invalidate previous completion if it was generated for a different buffer.
1592 if *buffer_id != buffer.entity_id() {
1593 self.current_completion.take();
1594 return None;
1595 }
1596
1597 let buffer = buffer.read(cx);
1598 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1599 self.current_completion.take();
1600 return None;
1601 };
1602
1603 let cursor_row = cursor_position.to_point(buffer).row;
1604 let (closest_edit_ix, (closest_edit_range, _)) =
1605 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1606 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1607 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1608 cmp::min(distance_from_start, distance_from_end)
1609 })?;
1610
1611 let mut edit_start_ix = closest_edit_ix;
1612 for (range, _) in edits[..edit_start_ix].iter().rev() {
1613 let distance_from_closest_edit =
1614 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1615 if distance_from_closest_edit <= 1 {
1616 edit_start_ix -= 1;
1617 } else {
1618 break;
1619 }
1620 }
1621
1622 let mut edit_end_ix = closest_edit_ix + 1;
1623 for (range, _) in &edits[edit_end_ix..] {
1624 let distance_from_closest_edit =
1625 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1626 if distance_from_closest_edit <= 1 {
1627 edit_end_ix += 1;
1628 } else {
1629 break;
1630 }
1631 }
1632
1633 Some(edit_prediction::EditPrediction {
1634 id: Some(completion.id.to_string().into()),
1635 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1636 edit_preview: Some(completion.edit_preview.clone()),
1637 })
1638 }
1639}
1640
/// Estimates how many model tokens a string of `bytes` bytes takes up, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const ASSUMED_BYTES_PER_TOKEN: usize = 3;

    let estimated_tokens = bytes / ASSUMED_BYTES_PER_TOKEN;
    estimated_tokens
}
1647
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;
    use util::path;

    use super::*;

    /// Checks that a prediction's edits are re-anchored (or dropped) as the user types
    /// text that matches, partially matches, or contradicts the prediction.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(9..11, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_edits(&completion, &buffer, cx),
                Some(vec![(4..4, "M".to_string())])
            );

            // Deleting text the prediction depends on invalidates it entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_edits(&completion, &buffer, cx), None);
        });
    }

    /// Checks that the edits derived from a model response are minimal insertions rather
    /// than whole-line replacements.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// Checks that a prediction which appends a line at the very end of the buffer
    /// produces the expected text once applied.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        let (buffer, completion) =
            request_prediction(buffer_content, completion_response, cx).await;
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Installs the settings globals that the project-based tests set up before creating
    /// a `Project`.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
            Project::init_settings(cx);
        });
    }

    /// Creates a local buffer holding `buffer_content` and requests a prediction for it
    /// (cursor at row 1, column 0) from a `Zeta` whose HTTP client is faked.
    ///
    /// The fake client answers the token endpoint with a fixed token, answers every
    /// predict request with `completion_response` under a random request id, and answers
    /// anything else with 404. Panics if the request fails or yields no prediction.
    async fn request_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> (Entity<Buffer>, EditPrediction) {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |req| {
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));

        let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(&project, &buffer, cursor, false, cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        (buffer, completion)
    }

    /// Requests a prediction for `buffer_content` and returns its edits as point ranges
    /// resolved against the unmodified buffer.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let (buffer, completion) =
            request_prediction(buffer_content, completion_response, cx).await;
        // The buffer is not edited by the request, so this snapshot matches the one the
        // prediction was computed against.
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Interpolates `completion` against the buffer's current snapshot and converts the
    /// resulting edits to offset ranges; `None` when the prediction no longer applies.
    fn interpolated_edits(
        completion: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let edits = completion.interpolate(&buffer.read(cx).snapshot())?;
        Some(from_completion_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset-based edits for `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}