1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::{Dismissable, KEY_VALUE_STORE};
11use edit_prediction::DataCollectionState;
12pub use init::*;
13use license_detection::LicenseDetectionWatcher;
14pub use rate_completion_modal::*;
15
16use anyhow::{Context as _, Result, anyhow};
17use arrayvec::ArrayVec;
18use client::{Client, EditPredictionUsage, UserStore};
19use cloud_llm_client::{
20 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
21 PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
22};
23use collections::{HashMap, HashSet, VecDeque};
24use futures::AsyncReadExt;
25use gpui::{
26 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
27 Subscription, Task, WeakEntity, actions,
28};
29use http_client::{AsyncBody, HttpClient, Method, Request, Response};
30use input_excerpt::excerpt_for_cursor_position;
31use language::{
32 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
33};
34use language_model::{LlmApiToken, RefreshLlmTokenListener};
35use project::{Project, ProjectPath};
36use release_channel::AppVersion;
37use settings::WorktreeId;
38use std::str::FromStr;
39use std::{
40 cmp,
41 fmt::Write,
42 future::Future,
43 mem,
44 ops::Range,
45 path::Path,
46 rc::Rc,
47 sync::Arc,
48 time::{Duration, Instant},
49};
50use telemetry_events::EditPredictionRating;
51use thiserror::Error;
52use util::ResultExt;
53use uuid::Uuid;
54use workspace::Workspace;
55use workspace::notifications::{ErrorMessagePrompt, NotificationId};
56use worktree::Worktree;
57
/// Cursor-position marker; `Zeta::parse_edits` strips it from the model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; `Zeta::parse_edits` rejects output containing more than one.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in model output; `Zeta::parse_edits` requires exactly one.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region in model output; `Zeta::parse_edits` requires exactly one.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Largest gap between two consecutive changes to the same buffer for which
/// `Zeta::push_event` still merges them into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the data collection choice
/// (read by `Zeta::load_data_collection_choices`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): MAX_CONTEXT_TOKENS, MAX_REWRITE_TOKENS and MAX_DIAGNOSTIC_GROUPS
// are not referenced in this part of the file; presumably the context-gathering
// code consumes them — confirm.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget handed to `prompt_for_events` when the events prompt is built.
const MAX_EVENT_TOKENS: usize = 500;
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track. Once the queue reaches this length,
/// `Zeta::push_event` drops the oldest half.
const MAX_EVENT_COUNT: usize = 16;

actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
80
/// Identifier of an edit prediction. It wraps the `request_id` [`Uuid`] of the
/// response the prediction was built from.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
83
84impl From<EditPredictionId> for gpui::ElementId {
85 fn from(value: EditPredictionId) -> Self {
86 gpui::ElementId::Uuid(value.0)
87 }
88}
89
90impl std::fmt::Display for EditPredictionId {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{}", self.0)
93 }
94}
95
96struct ZedPredictUpsell;
97
98impl Dismissable for ZedPredictUpsell {
99 const KEY: &'static str = "dismissed-edit-predict-upsell";
100
101 fn dismissed() -> bool {
102 // To make this backwards compatible with older versions of Zed, we
103 // check if the user has seen the previous Edit Prediction Onboarding
104 // before, by checking the data collection choice which was written to
105 // the database once the user clicked on "Accept and Enable"
106 if KEY_VALUE_STORE
107 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
108 .log_err()
109 .is_some_and(|s| s.is_some())
110 {
111 return true;
112 }
113
114 KEY_VALUE_STORE
115 .read_kvp(Self::KEY)
116 .log_err()
117 .is_some_and(|s| s.is_some())
118 }
119}
120
121pub fn should_show_upsell_modal() -> bool {
122 !ZedPredictUpsell::dismissed()
123}
124
/// GPUI global wrapper holding the shared [`Zeta`] entity; installed by
/// `Zeta::register` and looked up by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
129
/// A prediction produced from one predict-edits response, together with the
/// inputs that were sent and the timing of the request.
#[derive(Clone)]
pub struct EditPrediction {
    /// Built from the response's `request_id`.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `"untitled"` when the buffer has no file.
    path: Arc<Path>,
    /// Editable range (buffer offsets) reported by context gathering.
    excerpt_range: Range<usize>,
    /// Offset of the cursor in the snapshot taken when the request started.
    cursor_offset: usize,
    /// Predicted edits, already interpolated onto `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot taken while the response was processed; `edits` apply to it.
    snapshot: BufferSnapshot,
    /// Preview obtained from `Buffer::preview_edits` for `edits`.
    edit_preview: EditPreview,
    /// Outline sent with the request (empty when the request had none).
    input_outline: Arc<str>,
    /// Events prompt sent with the request.
    input_events: Arc<str>,
    /// Excerpt sent with the request.
    input_excerpt: Arc<str>,
    /// Raw `output_excerpt` returned in the response.
    output_excerpt: Arc<str>,
    /// Taken at the start of the request, just before the buffer is snapshotted.
    buffer_snapshotted_at: Instant,
    /// Taken when this value is constructed, i.e. after the response has been
    /// parsed and the edit preview computed.
    response_received_at: Instant,
}
146
147impl EditPrediction {
148 fn latency(&self) -> Duration {
149 self.response_received_at
150 .duration_since(self.buffer_snapshotted_at)
151 }
152
153 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
154 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
155 }
156}
157
/// Adapts model edits that were computed against `old_snapshot` to the user
/// edits made between `old_snapshot` and `new_snapshot`.
///
/// Returns the edits that are still applicable, or `None` when a user edit
/// cannot be reconciled with the model edits or when no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Model edits that end strictly before this user edit are unaffected
        // by it and are carried over unchanged.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is acceptable only if it replaced exactly the range
        // of the next model edit with text that is a prefix of the model's
        // replacement.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Keep only the part of the model's text the user has not
                    // typed yet, as an insertion after the replaced range.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit invalidates the prediction.
        return None;
    }

    // Model edits positioned after the last user edit are kept as they are.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
203
204impl std::fmt::Debug for EditPrediction {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("EditPrediction")
207 .field("id", &self.id)
208 .field("path", &self.path)
209 .field("edits", &self.edits)
210 .finish_non_exhaustive()
211 }
212}
213
/// State backing edit predictions: tracks buffer changes, issues prediction
/// requests, and remembers which predictions were shown and rated.
pub struct Zeta {
    /// Workspace used to show the "update required" notification, when it can be upgraded.
    workspace: Option<WeakEntity<Workspace>>,
    /// Supplies the HTTP client for requests and the telemetry handle.
    client: Arc<Client>,
    /// Recent buffer-change events; cloned into each request to build the events prompt.
    events: VecDeque<Event>,
    /// Buffers being observed, keyed by entity id, each with its last seen snapshot.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions that were shown, most recent first (see `completion_shown`).
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions that have been rated.
    rated_completions: HashSet<EditPredictionId>,
    /// Data collection choice, loaded from the key-value store at construction.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token sent as the bearer credential on prediction requests.
    llm_token: LlmApiToken,
    /// Keeps alive the subscription that refreshes `llm_token` when the
    /// refresh listener emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Holds the edit prediction usage reported by the server.
    user_store: Entity<UserStore>,
    /// One license detection watcher per worktree passed to `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
229
230impl Zeta {
231 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
232 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
233 }
234
235 pub fn register(
236 workspace: Option<WeakEntity<Workspace>>,
237 worktree: Option<Entity<Worktree>>,
238 client: Arc<Client>,
239 user_store: Entity<UserStore>,
240 cx: &mut App,
241 ) -> Entity<Self> {
242 let this = Self::global(cx).unwrap_or_else(|| {
243 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
244 cx.set_global(ZetaGlobal(entity.clone()));
245 entity
246 });
247
248 this.update(cx, move |this, cx| {
249 if let Some(worktree) = worktree {
250 let worktree_id = worktree.read(cx).id();
251 this.license_detection_watchers
252 .entry(worktree_id)
253 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
254 }
255 });
256
257 this
258 }
259
260 pub fn clear_history(&mut self) {
261 self.events.clear();
262 }
263
264 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
265 self.user_store.read(cx).edit_prediction_usage()
266 }
267
268 fn new(
269 workspace: Option<WeakEntity<Workspace>>,
270 client: Arc<Client>,
271 user_store: Entity<UserStore>,
272 cx: &mut Context<Self>,
273 ) -> Self {
274 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
275
276 let data_collection_choice = Self::load_data_collection_choices();
277 let data_collection_choice = cx.new(|_| data_collection_choice);
278
279 Self {
280 workspace,
281 client,
282 events: VecDeque::new(),
283 shown_completions: VecDeque::new(),
284 rated_completions: HashSet::default(),
285 registered_buffers: HashMap::default(),
286 data_collection_choice,
287 llm_token: LlmApiToken::default(),
288 _llm_token_subscription: cx.subscribe(
289 &refresh_llm_token_listener,
290 |this, _listener, _event, cx| {
291 let client = this.client.clone();
292 let llm_token = this.llm_token.clone();
293 cx.spawn(async move |_this, _cx| {
294 llm_token.refresh(&client).await?;
295 anyhow::Ok(())
296 })
297 .detach_and_log_err(cx);
298 },
299 ),
300 update_required: false,
301 license_detection_watchers: HashMap::default(),
302 user_store,
303 }
304 }
305
306 fn push_event(&mut self, event: Event) {
307 if let Some(Event::BufferChange {
308 new_snapshot: last_new_snapshot,
309 timestamp: last_timestamp,
310 ..
311 }) = self.events.back_mut()
312 {
313 // Coalesce edits for the same buffer when they happen one after the other.
314 let Event::BufferChange {
315 old_snapshot,
316 new_snapshot,
317 timestamp,
318 } = &event;
319
320 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
321 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
322 && old_snapshot.version == last_new_snapshot.version
323 {
324 *last_new_snapshot = new_snapshot.clone();
325 *last_timestamp = *timestamp;
326 return;
327 }
328 }
329
330 self.events.push_back(event);
331 if self.events.len() >= MAX_EVENT_COUNT {
332 // These are halved instead of popping to improve prompt caching.
333 self.events.drain(..MAX_EVENT_COUNT / 2);
334 }
335 }
336
337 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
338 let buffer_id = buffer.entity_id();
339 let weak_buffer = buffer.downgrade();
340
341 if let std::collections::hash_map::Entry::Vacant(entry) =
342 self.registered_buffers.entry(buffer_id)
343 {
344 let snapshot = buffer.read(cx).snapshot();
345
346 entry.insert(RegisteredBuffer {
347 snapshot,
348 _subscriptions: [
349 cx.subscribe(buffer, move |this, buffer, event, cx| {
350 this.handle_buffer_event(buffer, event, cx);
351 }),
352 cx.observe_release(buffer, move |this, _buffer, _cx| {
353 this.registered_buffers.remove(&weak_buffer.entity_id());
354 }),
355 ],
356 });
357 };
358 }
359
360 fn handle_buffer_event(
361 &mut self,
362 buffer: Entity<Buffer>,
363 event: &language::BufferEvent,
364 cx: &mut Context<Self>,
365 ) {
366 if let language::BufferEvent::Edited = event {
367 self.report_changes_for_buffer(&buffer, cx);
368 }
369 }
370
    /// Shared implementation behind `request_completion` and the test-only
    /// `fake_completion`.
    ///
    /// Snapshots `buffer`, gathers the request context around `cursor`, obtains
    /// a response through `perform_predict_edits`, and turns that response into
    /// an [`EditPrediction`]. `perform_predict_edits` is injected so callers can
    /// choose how the response is produced.
    ///
    /// The returned task resolves to `Ok(None)` when processing the response
    /// yields no prediction, and to `Err` when context gathering, the request,
    /// or response processing fails. If the request fails with
    /// `ZedUpdateRequiredError`, `update_required` is set and, when `workspace`
    /// is provided, a notification linking to the releases page is shown
    /// before the error is returned.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        let buffer_snapshotted_at = Instant::now();
        // Also records any pending change to this buffer as an event, so the
        // events cloned below include it.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Git info is only looked up when data collection is allowed and the
        // buffer belongs to a file in a project.
        let git_info = if let (true, Some(project), Some(file)) =
            (can_collect_data, project, snapshot.file())
        {
            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Copy the inputs out before `body` is moved into the request;
            // they are stored on the resulting prediction.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // A failure to update the usage is ignored.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests (`<= 2` matches 3 of the 256 `u8` values)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
527
    // Generates several example completions of various states to fill the Zeta completion modal
    //
    // Seven canned responses are fed through `fake_completion` against a
    // throwaway local buffer. Once all tasks have finished, the edits of the
    // shown completions at indices 2 and 3 are replaced with an empty list.
    // Panics if a fake completion fails or those indices do not exist.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 And maybe a short line

 Then a few lines

 and then another
 "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Every fake request uses the start of the second line as its cursor position.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await the tasks in order; any failure panics via `unwrap`.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Blank out the edits of two entries so the list contains
            // completions without edits as well.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
680
681 #[cfg(any(test, feature = "test-support"))]
682 pub fn fake_completion(
683 &mut self,
684 project: Option<&Entity<Project>>,
685 buffer: &Entity<Buffer>,
686 position: language::Anchor,
687 response: PredictEditsResponse,
688 cx: &mut Context<Self>,
689 ) -> Task<Result<Option<EditPrediction>>> {
690 use std::future::ready;
691
692 self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
693 ready(Ok((response, None)))
694 })
695 }
696
697 pub fn request_completion(
698 &mut self,
699 project: Option<&Entity<Project>>,
700 buffer: &Entity<Buffer>,
701 position: language::Anchor,
702 can_collect_data: bool,
703 cx: &mut Context<Self>,
704 ) -> Task<Result<Option<EditPrediction>>> {
705 let workspace = self
706 .workspace
707 .as_ref()
708 .and_then(|workspace| workspace.upgrade());
709 self.request_completion_impl(
710 workspace,
711 project,
712 buffer,
713 position,
714 can_collect_data,
715 cx,
716 Self::perform_predict_edits,
717 )
718 }
719
    /// Sends the predict-edits request described by `params` and returns the
    /// parsed response together with the usage parsed from the response
    /// headers (if that parsing succeeds).
    ///
    /// The request is a POST to the URL in the `ZED_PREDICT_EDITS_URL`
    /// environment variable when set, otherwise to the `/predict_edits/v2`
    /// LLM endpoint.
    ///
    /// # Errors
    ///
    /// Fails with `ZedUpdateRequiredError` when the response advertises a
    /// minimum required version newer than `app_version`, and with a generic
    /// error carrying status and body for any other unsuccessful response. A
    /// response flagged with the expired-token header is retried once after
    /// refreshing the token.
    pub fn perform_predict_edits(
        params: PerformPredictEditsParams,
    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
        async move {
            let PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
                ..
            } = params;

            let http_client = client.http_client();
            let mut token = llm_token.acquire(&client).await?;
            // Ensures the expired-token retry below happens at most once.
            let mut did_retry = false;

            loop {
                let request_builder = http_client::Request::builder().method(Method::POST);
                let request_builder =
                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                        request_builder.uri(predict_edits_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/v2", &[])?
                                .as_ref(),
                        )
                    };
                let request = request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(serde_json::to_string(&body)?.into())?;

                let mut response = http_client.send(request).await?;

                // The version check runs before the status check. A header
                // that is missing or fails to parse is ignored.
                if let Some(minimum_required_version) = response
                    .headers()
                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                {
                    anyhow::ensure!(
                        app_version >= minimum_required_version,
                        ZedUpdateRequiredError {
                            minimum_version: minimum_required_version
                        }
                    );
                }

                if response.status().is_success() {
                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Ok((serde_json::from_str(&body)?, usage));
                } else if !did_retry
                    && response
                        .headers()
                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                        .is_some()
                {
                    // Refresh the token and go around the loop once more.
                    did_retry = true;
                    token = llm_token.refresh(&client).await?;
                } else {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "error predicting edits.\nStatus: {:?}\nBody: {}",
                        response.status(),
                        body
                    );
                }
            }
        }
    }
795
796 fn accept_edit_prediction(
797 &mut self,
798 request_id: EditPredictionId,
799 cx: &mut Context<Self>,
800 ) -> Task<Result<()>> {
801 let client = self.client.clone();
802 let llm_token = self.llm_token.clone();
803 let app_version = AppVersion::global(cx);
804 cx.spawn(async move |this, cx| {
805 let http_client = client.http_client();
806 let mut response = llm_token_retry(&llm_token, &client, |token| {
807 let request_builder = http_client::Request::builder().method(Method::POST);
808 let request_builder =
809 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
810 request_builder.uri(accept_prediction_url)
811 } else {
812 request_builder.uri(
813 http_client
814 .build_zed_llm_url("/predict_edits/accept", &[])?
815 .as_ref(),
816 )
817 };
818 Ok(request_builder
819 .header("Content-Type", "application/json")
820 .header("Authorization", format!("Bearer {}", token))
821 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
822 .body(
823 serde_json::to_string(&AcceptEditPredictionBody {
824 request_id: request_id.0,
825 })?
826 .into(),
827 )?)
828 })
829 .await?;
830
831 if let Some(minimum_required_version) = response
832 .headers()
833 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
834 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
835 && app_version < minimum_required_version
836 {
837 return Err(anyhow!(ZedUpdateRequiredError {
838 minimum_version: minimum_required_version
839 }));
840 }
841
842 if response.status().is_success() {
843 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
844 this.update(cx, |this, cx| {
845 this.user_store.update(cx, |user_store, cx| {
846 user_store.update_edit_prediction_usage(usage, cx);
847 });
848 })?;
849 }
850
851 Ok(())
852 } else {
853 let mut body = String::new();
854 response.body_mut().read_to_string(&mut body).await?;
855 Err(anyhow!(
856 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
857 response.status(),
858 body
859 ))
860 }
861 })
862 }
863
    /// Converts a prediction response into an [`EditPrediction`].
    ///
    /// The response text is parsed into edits against `snapshot` on a
    /// background task. Those edits are then interpolated onto the buffer's
    /// current snapshot and an edit preview is computed. The task resolves to
    /// `Ok(None)` when interpolation yields nothing, and to `Err` when
    /// parsing fails or the buffer cannot be read.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing and diffing run on a background task.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // The buffer may have changed since `snapshot` was taken, so the
            // edits are re-based onto its current snapshot; `None` from
            // `interpolate` ends processing without a prediction.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                // Taken here, i.e. after parsing and preview computation.
                response_received_at: Instant::now(),
            }))
        })
    }
925
926 fn parse_edits(
927 output_excerpt: Arc<str>,
928 editable_range: Range<usize>,
929 snapshot: &BufferSnapshot,
930 ) -> Result<Vec<(Range<Anchor>, String)>> {
931 let content = output_excerpt.replace(CURSOR_MARKER, "");
932
933 let start_markers = content
934 .match_indices(EDITABLE_REGION_START_MARKER)
935 .collect::<Vec<_>>();
936 anyhow::ensure!(
937 start_markers.len() == 1,
938 "expected exactly one start marker, found {}",
939 start_markers.len()
940 );
941
942 let end_markers = content
943 .match_indices(EDITABLE_REGION_END_MARKER)
944 .collect::<Vec<_>>();
945 anyhow::ensure!(
946 end_markers.len() == 1,
947 "expected exactly one end marker, found {}",
948 end_markers.len()
949 );
950
951 let sof_markers = content
952 .match_indices(START_OF_FILE_MARKER)
953 .collect::<Vec<_>>();
954 anyhow::ensure!(
955 sof_markers.len() <= 1,
956 "expected at most one start-of-file marker, found {}",
957 sof_markers.len()
958 );
959
960 let codefence_start = start_markers[0].0;
961 let content = &content[codefence_start..];
962
963 let newline_ix = content.find('\n').context("could not find newline")?;
964 let content = &content[newline_ix + 1..];
965
966 let codefence_end = content
967 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
968 .context("could not find end marker")?;
969 let new_text = &content[..codefence_end];
970
971 let old_text = snapshot
972 .text_for_range(editable_range.clone())
973 .collect::<String>();
974
975 Ok(Self::compute_edits(
976 old_text,
977 new_text,
978 editable_range.start,
979 snapshot,
980 ))
981 }
982
    /// Diffs `old_text` against `new_text` and converts the result into
    /// anchored edits on `snapshot`.
    ///
    /// `offset` is added to every diff range so that ranges relative to
    /// `old_text` become buffer offsets. Each edit is then narrowed by
    /// dropping the leading and trailing characters that the replaced buffer
    /// text and the replacement have in common.
    pub fn compute_edits(
        old_text: String,
        new_text: &str,
        offset: usize,
        snapshot: &BufferSnapshot,
    ) -> Vec<(Range<Anchor>, String)> {
        text_diff(&old_text, new_text)
            .into_iter()
            .map(|(mut old_range, new_text)| {
                // Translate from `old_text`-relative to buffer offsets.
                old_range.start += offset;
                old_range.end += offset;

                // Bytes shared at the front of the replaced text and the replacement.
                let prefix_len = common_prefix(
                    snapshot.chars_for_range(old_range.clone()),
                    new_text.chars(),
                );
                old_range.start += prefix_len;

                // Bytes shared at the back, measured on what remains after
                // the prefix has been removed from both sides.
                let suffix_len = common_prefix(
                    snapshot.reversed_chars_for_range(old_range.clone()),
                    new_text[prefix_len..].chars().rev(),
                );
                old_range.end = old_range.end.saturating_sub(suffix_len);

                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
                // An empty range is a pure insertion and uses one anchor for both ends.
                let range = if old_range.is_empty() {
                    let anchor = snapshot.anchor_after(old_range.start);
                    anchor..anchor
                } else {
                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
                };
                (range, new_text)
            })
            .collect()
    }
1018
1019 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1020 self.rated_completions.contains(&completion_id)
1021 }
1022
1023 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1024 self.shown_completions.push_front(completion.clone());
1025 if self.shown_completions.len() > 50 {
1026 let completion = self.shown_completions.pop_back().unwrap();
1027 self.rated_completions.remove(&completion.id);
1028 }
1029 cx.notify();
1030 }
1031
    /// Marks `completion` as rated, emits an "Edit Prediction Rated"
    /// telemetry event carrying the rating, the prediction's inputs and
    /// output, and the user's `feedback`, then flushes telemetry and notifies
    /// observers.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Request a flush right away; the returned task is detached.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1052
1053 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1054 self.shown_completions.iter()
1055 }
1056
1057 pub fn shown_completions_len(&self) -> usize {
1058 self.shown_completions.len()
1059 }
1060
1061 fn report_changes_for_buffer(
1062 &mut self,
1063 buffer: &Entity<Buffer>,
1064 cx: &mut Context<Self>,
1065 ) -> BufferSnapshot {
1066 self.register_buffer(buffer, cx);
1067
1068 let registered_buffer = self
1069 .registered_buffers
1070 .get_mut(&buffer.entity_id())
1071 .unwrap();
1072 let new_snapshot = buffer.read(cx).snapshot();
1073
1074 if new_snapshot.version != registered_buffer.snapshot.version {
1075 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1076 self.push_event(Event::BufferChange {
1077 old_snapshot,
1078 new_snapshot: new_snapshot.clone(),
1079 timestamp: Instant::now(),
1080 });
1081 }
1082
1083 new_snapshot
1084 }
1085
1086 fn load_data_collection_choices() -> DataCollectionChoice {
1087 let choice = KEY_VALUE_STORE
1088 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1089 .log_err()
1090 .flatten();
1091
1092 match choice.as_deref() {
1093 Some("true") => DataCollectionChoice::Enabled,
1094 Some("false") => DataCollectionChoice::Disabled,
1095 Some(_) => {
1096 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1097 DataCollectionChoice::NotAnswered
1098 }
1099 None => DataCollectionChoice::NotAnswered,
1100 }
1101 }
1102}
1103
/// Inputs bundled together for performing a predict-edits request.
pub struct PerformPredictEditsParams {
    /// Shared handle to the `Client`.
    pub client: Arc<Client>,
    /// LLM API token handle.
    pub llm_token: LlmApiToken,
    /// Application version — presumably the running Zed version; confirm against the caller.
    pub app_version: SemanticVersion,
    /// The `PredictEditsBody` payload for the request.
    pub body: PredictEditsBody,
}
1110
/// Error whose message tells the user that edit predictions require updating Zed
/// to at least `minimum_version`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Version interpolated into the error message.
    minimum_version: SemanticVersion,
}
1118
/// Returns the length, in UTF-8 bytes, of the longest leading run of characters on
/// which `a` and `b` agree.
///
/// Comparison stops at the first differing pair or when either iterator ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1125
1126fn git_info_for_file(
1127 project: &Entity<Project>,
1128 project_path: &ProjectPath,
1129 cx: &App,
1130) -> Option<PredictEditsGitInfo> {
1131 let git_store = project.read(cx).git_store().read(cx);
1132 if let Some((repository, _repo_path)) =
1133 git_store.repository_and_path_for_project_path(project_path, cx)
1134 {
1135 let repository = repository.read(cx);
1136 let head_sha = repository
1137 .head_commit
1138 .as_ref()
1139 .map(|head_commit| head_commit.sha.to_string());
1140 let remote_origin_url = repository.remote_origin_url.clone();
1141 let remote_upstream_url = repository.remote_upstream_url.clone();
1142 if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1143 return None;
1144 }
1145 Some(PredictEditsGitInfo {
1146 head_sha,
1147 remote_origin_url,
1148 remote_upstream_url,
1149 })
1150 } else {
1151 None
1152 }
1153}
1154
/// Result of `gather_context`: the request body plus the editable range it was built for.
pub struct GatherContextOutput {
    /// Request payload assembled from the excerpt, events, diagnostics and git info.
    pub body: PredictEditsBody,
    /// The excerpt's editable range, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
}
1159
1160pub fn gather_context(
1161 project: Option<&Entity<Project>>,
1162 full_path_str: String,
1163 snapshot: &BufferSnapshot,
1164 cursor_point: language::Point,
1165 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1166 can_collect_data: bool,
1167 git_info: Option<PredictEditsGitInfo>,
1168 cx: &App,
1169) -> Task<Result<GatherContextOutput>> {
1170 let local_lsp_store =
1171 project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1172 let diagnostic_groups: Vec<(String, serde_json::Value)> =
1173 if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
1174 snapshot
1175 .diagnostic_groups(None)
1176 .into_iter()
1177 .filter_map(|(language_server_id, diagnostic_group)| {
1178 let language_server =
1179 local_lsp_store.running_language_server_for_id(language_server_id)?;
1180 let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1181 let language_server_name = language_server.name().to_string();
1182 let serialized = serde_json::to_value(diagnostic_group).unwrap();
1183 Some((language_server_name, serialized))
1184 })
1185 .collect::<Vec<_>>()
1186 } else {
1187 Vec::new()
1188 };
1189
1190 cx.background_spawn({
1191 let snapshot = snapshot.clone();
1192 async move {
1193 let diagnostic_groups = if diagnostic_groups.is_empty()
1194 || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1195 {
1196 None
1197 } else {
1198 Some(diagnostic_groups)
1199 };
1200
1201 let input_excerpt = excerpt_for_cursor_position(
1202 cursor_point,
1203 &full_path_str,
1204 &snapshot,
1205 MAX_REWRITE_TOKENS,
1206 MAX_CONTEXT_TOKENS,
1207 );
1208 let input_events = make_events_prompt();
1209 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1210
1211 let body = PredictEditsBody {
1212 input_events,
1213 input_excerpt: input_excerpt.prompt,
1214 can_collect_data,
1215 diagnostic_groups,
1216 git_info,
1217 outline: None,
1218 speculated_output: None,
1219 };
1220
1221 Ok(GatherContextOutput {
1222 body,
1223 editable_range,
1224 })
1225 }
1226 })
1227}
1228
1229fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1230 let mut result = String::new();
1231 for event in events.iter().rev() {
1232 let event_string = event.to_prompt();
1233 let event_tokens = tokens_for_bytes(event_string.len());
1234 if event_tokens > remaining_tokens {
1235 break;
1236 }
1237
1238 if !result.is_empty() {
1239 result.insert_str(0, "\n\n");
1240 }
1241 result.insert_str(0, &event_string);
1242 remaining_tokens -= event_tokens;
1243 }
1244 result
1245}
1246
/// Per-buffer bookkeeping: the last snapshot recorded for the buffer plus the
/// subscriptions held for it.
struct RegisteredBuffer {
    /// Snapshot most recently recorded for the buffer; its version is compared against
    /// fresh snapshots to detect changes.
    snapshot: BufferSnapshot,
    /// Subscriptions stored with this entry and never read directly — presumably held
    /// only to keep them alive; confirm where the entry is created.
    _subscriptions: [gpui::Subscription; 2],
}
1251
/// An editing event that can be rendered into the prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents moved from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change.
        new_snapshot: BufferSnapshot,
        /// When the event was recorded.
        timestamp: Instant,
    },
}
1260
1261impl Event {
1262 fn to_prompt(&self) -> String {
1263 match self {
1264 Event::BufferChange {
1265 old_snapshot,
1266 new_snapshot,
1267 ..
1268 } => {
1269 let mut prompt = String::new();
1270
1271 let old_path = old_snapshot
1272 .file()
1273 .map(|f| f.path().as_ref())
1274 .unwrap_or(Path::new("untitled"));
1275 let new_path = new_snapshot
1276 .file()
1277 .map(|f| f.path().as_ref())
1278 .unwrap_or(Path::new("untitled"));
1279 if old_path != new_path {
1280 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1281 }
1282
1283 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1284 if !diff.is_empty() {
1285 write!(
1286 prompt,
1287 "User edited {:?}:\n```diff\n{}\n```",
1288 new_path, diff
1289 )
1290 .unwrap();
1291 }
1292
1293 prompt
1294 }
1295 }
1296 }
1297}
1298
/// An edit prediction together with the identity of the buffer it was produced for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction belongs to.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
}
1304
1305impl CurrentEditPrediction {
1306 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1307 if self.buffer_id != old_completion.buffer_id {
1308 return true;
1309 }
1310
1311 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1312 return true;
1313 };
1314 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1315 return false;
1316 };
1317
1318 if old_edits.len() == 1 && new_edits.len() == 1 {
1319 let (old_range, old_text) = &old_edits[0];
1320 let (new_range, new_text) = &new_edits[0];
1321 new_range == old_range && new_text.starts_with(old_text)
1322 } else {
1323 true
1324 }
1325 }
1326}
1327
/// An in-flight prediction request tracked by the provider.
struct PendingCompletion {
    /// Identifier assigned from the provider's `next_pending_completion_id` counter.
    id: usize,
    /// Handle to the spawned work; stored but never read.
    /// NOTE(review): presumably dropping it cancels the task — confirm against gpui's `Task`.
    _task: Task<()>,
}
1332
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// `true` for every variant except [`DataCollectionChoice::NotAnswered`].
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1372
/// Data-collection state attached to a prediction provider: the shared user choice
/// and the license watcher consulted for the open-source check.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher queried to decide whether the project is open source; `None` is
    /// treated as "not open source".
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1378
1379impl ProviderDataCollection {
1380 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1381 let choice_and_watcher = buffer.and_then(|buffer| {
1382 let file = buffer.read(cx).file()?;
1383
1384 if !file.is_local() || file.is_private() {
1385 return None;
1386 }
1387
1388 let zeta = zeta.read(cx);
1389 let choice = zeta.data_collection_choice.clone();
1390
1391 let license_detection_watcher = zeta
1392 .license_detection_watchers
1393 .get(&file.worktree_id(cx))
1394 .cloned()?;
1395
1396 Some((choice, license_detection_watcher))
1397 });
1398
1399 if let Some((choice, watcher)) = choice_and_watcher {
1400 ProviderDataCollection {
1401 choice: Some(choice),
1402 license_detection_watcher: Some(watcher),
1403 }
1404 } else {
1405 ProviderDataCollection {
1406 choice: None,
1407 license_detection_watcher: None,
1408 }
1409 }
1410 }
1411
1412 pub fn can_collect_data(&self, cx: &App) -> bool {
1413 self.is_data_collection_enabled(cx) && self.is_project_open_source()
1414 }
1415
1416 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1417 self.choice
1418 .as_ref()
1419 .is_some_and(|choice| choice.read(cx).is_enabled())
1420 }
1421
1422 fn is_project_open_source(&self) -> bool {
1423 self.license_detection_watcher
1424 .as_ref()
1425 .is_some_and(|watcher| watcher.is_project_open_source())
1426 }
1427
1428 pub fn toggle(&mut self, cx: &mut App) {
1429 if let Some(choice) = self.choice.as_mut() {
1430 let new_choice = choice.update(cx, |choice, _cx| {
1431 let new_choice = choice.toggle();
1432 *choice = new_choice;
1433 new_choice
1434 });
1435
1436 db::write_and_log(cx, move || {
1437 KEY_VALUE_STORE.write_kvp(
1438 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1439 new_choice.is_enabled().to_string(),
1440 )
1441 });
1442 }
1443 }
1444}
1445
1446async fn llm_token_retry(
1447 llm_token: &LlmApiToken,
1448 client: &Arc<Client>,
1449 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1450) -> Result<Response<AsyncBody>> {
1451 let mut did_retry = false;
1452 let http_client = client.http_client();
1453 let mut token = llm_token.acquire(client).await?;
1454 loop {
1455 let request = build_request(token.clone())?;
1456 let response = http_client.send(request).await?;
1457
1458 if !did_retry
1459 && !response.status().is_success()
1460 && response
1461 .headers()
1462 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1463 .is_some()
1464 {
1465 did_retry = true;
1466 token = llm_token.refresh(client).await?;
1467 continue;
1468 }
1469
1470 return Ok(response);
1471 }
1472}
1473
/// Edit-prediction provider backed by a shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    /// The `Zeta` entity that performs the actual prediction requests.
    zeta: Entity<Zeta>,
    /// In-flight requests; at most two are tracked at a time.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next pending completion; incremented on every refresh that
    /// starts a request.
    next_pending_completion_id: usize,
    /// The prediction currently offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider. Always present; its inner choice is
    /// `None` when data collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the last request was issued; used to throttle subsequent requests.
    last_request_timestamp: Instant,
}
1483
1484impl ZetaEditPredictionProvider {
1485 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1486
1487 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1488 Self {
1489 zeta,
1490 pending_completions: ArrayVec::new(),
1491 next_pending_completion_id: 0,
1492 current_completion: None,
1493 provider_data_collection,
1494 last_request_timestamp: Instant::now(),
1495 }
1496 }
1497}
1498
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Stable identifier of this provider.
    fn name() -> &'static str {
        "zed-predict"
    }

    /// Human-readable provider name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    /// This provider always opts in to showing completions in the menu.
    fn show_completions_in_menu() -> bool {
        true
    }

    /// This provider always opts in to the tab-accept marker.
    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled, together with the open-source
    /// status of the project in either case.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Flips the data-collection choice via `ProviderDataCollection::toggle`.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    /// Forwards to `Zeta::usage`.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// Always enabled, regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }

    /// `true` while at least one request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Starts a new prediction request for `buffer` at `position`, unless one is
    /// not needed or not allowed.
    ///
    /// No request is made when an update is required, when the user's account is too
    /// young or has overdue invoices, or when the current prediction still
    /// interpolates against the buffer's present contents.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // The existing prediction is still applicable to the current buffer state,
        // so keep it instead of requesting a new one.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out whatever remains of THROTTLE_TIMEOUT since the
            // previous request was issued.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(
                        project.as_ref(),
                        &buffer,
                        position,
                        can_collect_data,
                        cx,
                    )
                })
            });

            // Tag a successful prediction with the buffer it was requested for.
            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Failure (logged) or no prediction: just retire the pending entry.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this task is the oldest pending entry, remove only it;
                    // otherwise drop every pending entry.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same retirement rule as in the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Install the new prediction unless an existing one should be kept
                // (see `should_replace_completion`); report it as shown when installed.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    /// Cycling between alternative predictions is a no-op for this provider.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current prediction (if any) to `Zeta` and drops all
    /// pending requests. The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the part of the current prediction to offer near `cursor_position`.
    ///
    /// The current prediction is cleared (and `None` returned) when it belongs to a
    /// different buffer or no longer interpolates against the buffer. Otherwise the
    /// edit closest to the cursor row is picked, and the returned slice is extended
    /// in both directions over neighbouring edits that lie within one row of it.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        // Pick the edit whose start or end row is nearest to the cursor row.
        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Walk backwards, including preceding edits that end within one row of the
        // closest edit's start.
        // NOTE(review): the row subtraction assumes edits are ordered by position.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Walk forwards, including following edits that start within one row of the
        // closest edit's end.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1773
/// Estimates how many model tokens a string of `bytes` bytes occupies, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed number of string bytes per token when limiting model input. The guess
    /// is deliberately low so that estimates err on the side of underestimating limits.
    const GUESSED_BYTES_PER_TOKEN: usize = 3;

    let estimated_tokens = bytes / GUESSED_BYTES_PER_TOKEN;
    estimated_tokens
}
1780
1781#[cfg(test)]
1782mod tests {
1783 use client::UserStore;
1784 use client::test::FakeServer;
1785 use clock::FakeSystemClock;
1786 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1787 use gpui::TestAppContext;
1788 use http_client::FakeHttpClient;
1789 use indoc::indoc;
1790 use language::Point;
1791 use settings::SettingsStore;
1792
1793 use super::*;
1794
/// Checks how `EditPrediction::interpolate` adapts a fixed prediction as the buffer
/// is edited: typed-in parts of a predicted insertion shrink the remaining edit,
/// undo restores it, and a conflicting edit invalidates the prediction.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // The prediction: replace 2..5 with "REM" and delete 9..11.
    let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
        to_completion_edits(
            [(2..5, "REM".to_string()), (9..11, "".to_string())],
            &buffer,
            cx,
        )
        .into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let completion = EditPrediction {
        edits,
        edit_preview,
        path: Path::new("").into(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId(Uuid::new_v4()),
        excerpt_range: 0..0,
        cursor_offset: 0,
        input_outline: "".into(),
        input_events: "".into(),
        input_excerpt: "".into(),
        output_excerpt: "".into(),
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Untouched buffer: the prediction comes back unchanged.
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Deleting the replaced text turns the first edit into a pure insertion and
        // shifts the second edit left.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
        );

        // Undo restores the original prediction.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Typing the first predicted character leaves "EM" still to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
        );

        // Typing "E" leaves only "M".
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Typing "M" completes the first edit; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".to_string())]
        );

        // Removing the just-typed "M" brings the "M" insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Performing the predicted deletion by hand leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string())]
        );

        // An edit the prediction cannot absorb makes interpolation fail.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1911
/// Verifies that a model response rewriting a whole region is reduced to the
/// minimal set of edits against the original buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        client::init_settings(cx);
    });

    // Case 1: the response appends "_1" to two identifier uses; two small insertions
    // are expected.
    let edits = edits_for_prediction(
        indoc! {"
 fn main() {
 let word_1 = \"lorem\";
 let range = word.len()..word.len();
 }
 "},
        indoc! {"
 <|editable_region_start|>
 fn main() {
 let word_1 = \"lorem\";
 let range = word_1.len()..word_1.len();
 }

 <|editable_region_end|>
 "},
        cx,
    )
    .await;
    assert_eq!(
        edits,
        [
            (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
            (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
        ]
    );

    // Case 2: the response extends a string literal and adds the missing `;`; the
    // two additions are expected as separate insertions.
    let edits = edits_for_prediction(
        indoc! {"
 fn main() {
 let story = \"the quick\"
 }
 "},
        indoc! {"
 <|editable_region_start|>
 fn main() {
 let story = \"the quick brown fox jumps over the lazy dog\";
 }

 <|editable_region_end|>
 "},
        cx,
    )
    .await;
    assert_eq!(
        edits,
        [
            (
                Point::new(1, 26)..Point::new(1, 26),
                " brown fox jumps over the lazy dog".to_string()
            ),
            (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
        ]
    );
}
1975
/// Requests a prediction with the cursor at the end of the buffer and checks that
/// applying the returned edits yields the predicted text.
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        client::init_settings(cx);
    });

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
 ```animals.js
 <|start_of_file|>
 <|editable_region_start|>
 lorem
 ipsum
 <|editable_region_end|>
 ```"};

    // Fake backend: hands out an LLM token and answers the prediction endpoint with
    // the canned response above; every other route is a 404.
    let http_client = FakeHttpClient::create(move |req| async move {
        match (req.method(), req.uri().path()) {
            (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&CreateLlmTokenResponse {
                        token: LlmToken("the-llm-token".to_string()),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap()),
            (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                            .unwrap(),
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap()),
            _ => Ok(http_client::Response::builder()
                .status(404)
                .body("Not Found".into())
                .unwrap()),
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    // Construct the fake server to authenticate.
    let _server = FakeServer::for_client(42, &client, cx).await;
    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
    let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    // Cursor on the (empty) line after "lorem", i.e. at the end of the buffer.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    let completion_task = zeta.update(cx, |zeta, cx| {
        zeta.request_completion(None, &buffer, cursor, false, cx)
    });

    let completion = completion_task.await.unwrap().unwrap();
    buffer.update(cx, |buffer, cx| {
        buffer.edit(completion.edits.iter().cloned(), None, cx)
    });
    assert_eq!(
        buffer.read_with(cx, |buffer, _| buffer.text()),
        "lorem\nipsum"
    );
}
2049
/// Test helper: runs a prediction for `buffer_content` against a fake backend that
/// always answers with `completion_response`, and returns the resulting edits as
/// point ranges in the original buffer snapshot.
///
/// The cursor is placed at row 1, column 0.
async fn edits_for_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> Vec<(Range<Point>, String)> {
    let completion_response = completion_response.to_string();
    // Fake backend: token endpoint plus the prediction endpoint; anything else is a 404.
    let http_client = FakeHttpClient::create(move |req| {
        let completion = completion_response.clone();
        async move {
            match (req.method(), req.uri().path()) {
                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&CreateLlmTokenResponse {
                            token: LlmToken("the-llm-token".to_string()),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::new_v4(),
                            output_excerpt: completion,
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                _ => Ok(http_client::Response::builder()
                    .status(404)
                    .body("Not Found".into())
                    .unwrap()),
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    // Construct the fake server to authenticate.
    let _server = FakeServer::for_client(42, &client, cx).await;
    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
    let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    // Snapshot taken before the request; edit anchors are resolved against it below.
    let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    let completion_task = zeta.update(cx, |zeta, cx| {
        zeta.request_completion(None, &buffer, cursor, false, cx)
    });

    let completion = completion_task.await.unwrap().unwrap();
    completion
        .edits
        .iter()
        .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
        .collect::<Vec<_>>()
}
2112
2113 fn to_completion_edits(
2114 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2115 buffer: &Entity<Buffer>,
2116 cx: &App,
2117 ) -> Vec<(Range<Anchor>, String)> {
2118 let buffer = buffer.read(cx);
2119 iterator
2120 .into_iter()
2121 .map(|(range, text)| {
2122 (
2123 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2124 text,
2125 )
2126 })
2127 .collect()
2128 }
2129
2130 fn from_completion_edits(
2131 editor_edits: &[(Range<Anchor>, String)],
2132 buffer: &Entity<Buffer>,
2133 cx: &App,
2134 ) -> Vec<(Range<usize>, String)> {
2135 let buffer = buffer.read(cx);
2136 editor_edits
2137 .iter()
2138 .map(|(range, text)| {
2139 (
2140 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2141 text.clone(),
2142 )
2143 })
2144 .collect()
2145 }
2146
/// Initializes test logging via `zlog::init_test`.
/// NOTE(review): registered through `#[ctor::ctor]`, which presumably runs it at
/// binary load time, before any test — confirm against the `ctor` crate docs.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}
2151}