1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13use editor::Editor;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16use project::git_store::Repository;
17pub use rate_completion_modal::*;
18
19use anyhow::{Context as _, Result, anyhow};
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
23 PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
24 PredictEditsResponse, ZED_VERSION_HEADER_NAME,
25};
26use collections::{HashMap, HashSet, VecDeque};
27use futures::AsyncReadExt;
28use gpui::{
29 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
30 Subscription, Task, WeakEntity, actions,
31};
32use http_client::{AsyncBody, HttpClient, Method, Request, Response};
33use input_excerpt::excerpt_for_cursor_position;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
36};
37use language_model::{LlmApiToken, RefreshLlmTokenListener};
38use project::{Project, ProjectPath};
39use release_channel::AppVersion;
40use settings::WorktreeId;
41use std::str::FromStr;
42use std::{
43 cmp,
44 fmt::Write,
45 future::Future,
46 mem,
47 ops::Range,
48 path::Path,
49 rc::Rc,
50 sync::Arc,
51 time::{Duration, Instant},
52};
53use telemetry_events::EditPredictionRating;
54use thiserror::Error;
55use util::{ResultExt, maybe};
56use uuid::Uuid;
57use workspace::Workspace;
58use workspace::notifications::{ErrorMessagePrompt, NotificationId};
59use worktree::Worktree;
60
// Special tokens of the prediction prompt/output format. `parse_edits` removes the cursor
// marker from model output and uses the region markers to locate the rewritten text.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Consecutive edits to the same buffer that happen within this window are merged into a
/// single buffer-change event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Key-value store key under which the user's data collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets for the prompt. `MAX_EVENT_TOKENS` bounds the rendered event history; the
// context/rewrite budgets are presumably applied when building the excerpt (confirm in
// `input_excerpt`).
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;

/// Upper bound on the number of buffer-change events kept in the history.
const MAX_EVENT_COUNT: usize = 16;

/// Upper bound on the number of recently used files that are tracked.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum time that must pass between two recorded recent-file entries.
const MIN_TIME_BETWEEN_RECENT_FILES: Duration = Duration::from_millis(100);

/// Longest file path that is still included in the recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Budget, in bytes of JSON, for diagnostics included in the additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// Upper bound on the number of shown edit predictions retained for feedback/rating.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
89
// Declares the `ClearHistory` action under the `edit_prediction` namespace via gpui's
// `actions!` macro. The `///` line inside the macro is left unchanged because gpui
// presumably surfaces it as the action's description at runtime.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
97
98#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
99pub struct EditPredictionId(Uuid);
100
101impl From<EditPredictionId> for gpui::ElementId {
102 fn from(value: EditPredictionId) -> Self {
103 gpui::ElementId::Uuid(value.0)
104 }
105}
106
107impl std::fmt::Display for EditPredictionId {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 write!(f, "{}", self.0)
110 }
111}
112
113struct ZedPredictUpsell;
114
115impl Dismissable for ZedPredictUpsell {
116 const KEY: &'static str = "dismissed-edit-predict-upsell";
117
118 fn dismissed() -> bool {
119 // To make this backwards compatible with older versions of Zed, we
120 // check if the user has seen the previous Edit Prediction Onboarding
121 // before, by checking the data collection choice which was written to
122 // the database once the user clicked on "Accept and Enable"
123 if KEY_VALUE_STORE
124 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
125 .log_err()
126 .is_some_and(|s| s.is_some())
127 {
128 return true;
129 }
130
131 KEY_VALUE_STORE
132 .read_kvp(Self::KEY)
133 .log_err()
134 .is_some_and(|s| s.is_some())
135 }
136}
137
138pub fn should_show_upsell_modal() -> bool {
139 !ZedPredictUpsell::dismissed()
140}
141
/// gpui global holding the app-wide `Zeta` entity. It is set by `Zeta::register` and read
/// by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
146
147#[derive(Clone)]
148pub struct EditPrediction {
149 id: EditPredictionId,
150 path: Arc<Path>,
151 excerpt_range: Range<usize>,
152 cursor_offset: usize,
153 edits: Arc<[(Range<Anchor>, String)]>,
154 snapshot: BufferSnapshot,
155 edit_preview: EditPreview,
156 input_events: Arc<str>,
157 input_excerpt: Arc<str>,
158 output_excerpt: Arc<str>,
159 buffer_snapshotted_at: Instant,
160 response_received_at: Instant,
161}
162
163impl EditPrediction {
164 fn latency(&self) -> Duration {
165 self.response_received_at
166 .duration_since(self.buffer_snapshotted_at)
167 }
168
169 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
170 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
171 }
172}
173
174fn interpolate(
175 old_snapshot: &BufferSnapshot,
176 new_snapshot: &BufferSnapshot,
177 current_edits: Arc<[(Range<Anchor>, String)]>,
178) -> Option<Vec<(Range<Anchor>, String)>> {
179 let mut edits = Vec::new();
180
181 let mut model_edits = current_edits.iter().peekable();
182 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
183 while let Some((model_old_range, _)) = model_edits.peek() {
184 let model_old_range = model_old_range.to_offset(old_snapshot);
185 if model_old_range.end < user_edit.old.start {
186 let (model_old_range, model_new_text) = model_edits.next().unwrap();
187 edits.push((model_old_range.clone(), model_new_text.clone()));
188 } else {
189 break;
190 }
191 }
192
193 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
194 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
195 if user_edit.old == model_old_offset_range {
196 let user_new_text = new_snapshot
197 .text_for_range(user_edit.new.clone())
198 .collect::<String>();
199
200 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
201 if !model_suffix.is_empty() {
202 let anchor = old_snapshot.anchor_after(user_edit.old.end);
203 edits.push((anchor..anchor, model_suffix.to_string()));
204 }
205
206 model_edits.next();
207 continue;
208 }
209 }
210 }
211
212 return None;
213 }
214
215 edits.extend(model_edits.cloned());
216
217 if edits.is_empty() { None } else { Some(edits) }
218}
219
220impl std::fmt::Debug for EditPrediction {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 f.debug_struct("EditPrediction")
223 .field("id", &self.id)
224 .field("path", &self.path)
225 .field("edits", &self.edits)
226 .finish_non_exhaustive()
227 }
228}
229
/// App-wide edit prediction ("Zeta") state.
///
/// Tracks the following:
/// - buffer-change events used as prompt context;
/// - the predictions shown to the user (for rating);
/// - the LLM API token;
/// - per-worktree license detection.
pub struct Zeta {
    /// Workspace used to surface notifications; an invalid weak handle when none was given.
    workspace: WeakEntity<Workspace>,
    client: Arc<Client>,
    /// Recent buffer-change events, oldest first. Bounded by `MAX_EVENT_COUNT`; when full,
    /// the older half is dropped.
    events: VecDeque<Event>,
    /// Buffers being observed, keyed by entity id, with the last snapshot recorded for each.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions shown to the user, most recent first (bounded by
    /// `MAX_SHOWN_COMPLETION_COUNT`).
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of shown predictions that the user has rated.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: Entity<DataCollectionChoice>,
    llm_token: LlmApiToken,
    /// Keeps `llm_token` refreshed whenever the global refresh listener emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license detection watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently active editors. NOTE(review): maintained outside this view; confirm ordering
    /// and bounds against `MAX_RECENT_PROJECT_ENTRIES_COUNT`.
    recent_editors: VecDeque<RecentEditor>,
}

/// An editor that was recently active, together with when it was last active.
struct RecentEditor {
    editor: WeakEntity<Editor>,
    last_active_at: Instant,
}
251
252impl Zeta {
253 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
254 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
255 }
256
257 pub fn register(
258 workspace: Option<Entity<Workspace>>,
259 worktree: Option<Entity<Worktree>>,
260 client: Arc<Client>,
261 user_store: Entity<UserStore>,
262 cx: &mut App,
263 ) -> Entity<Self> {
264 let this = Self::global(cx).unwrap_or_else(|| {
265 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
266 cx.set_global(ZetaGlobal(entity.clone()));
267 entity
268 });
269
270 this.update(cx, move |this, cx| {
271 if let Some(worktree) = worktree {
272 let worktree_id = worktree.read(cx).id();
273 this.license_detection_watchers
274 .entry(worktree_id)
275 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
276 }
277 });
278
279 this
280 }
281
282 pub fn clear_history(&mut self) {
283 self.events.clear();
284 }
285
286 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
287 self.user_store.read(cx).edit_prediction_usage()
288 }
289
290 fn new(
291 workspace: Option<Entity<Workspace>>,
292 client: Arc<Client>,
293 user_store: Entity<UserStore>,
294 cx: &mut Context<Self>,
295 ) -> Self {
296 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
297
298 let data_collection_choice = Self::load_data_collection_choices();
299 let data_collection_choice = cx.new(|_| data_collection_choice);
300
301 if let Some(workspace) = &workspace {
302 cx.subscribe(workspace, |this, _workspace, event, cx| match event {
303 workspace::Event::ActiveItemChanged => {
304 this.handle_active_workspace_item_changed(cx)
305 }
306 _ => {}
307 })
308 .detach();
309 }
310
311 Self {
312 workspace: workspace.map_or_else(
313 || WeakEntity::new_invalid(),
314 |workspace| workspace.downgrade(),
315 ),
316 client,
317 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
318 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
319 rated_completions: HashSet::default(),
320 registered_buffers: HashMap::default(),
321 data_collection_choice,
322 llm_token: LlmApiToken::default(),
323 _llm_token_subscription: cx.subscribe(
324 &refresh_llm_token_listener,
325 |this, _listener, _event, cx| {
326 let client = this.client.clone();
327 let llm_token = this.llm_token.clone();
328 cx.spawn(async move |_this, _cx| {
329 llm_token.refresh(&client).await?;
330 anyhow::Ok(())
331 })
332 .detach_and_log_err(cx);
333 },
334 ),
335 update_required: false,
336 license_detection_watchers: HashMap::default(),
337 user_store,
338 recent_editors: VecDeque::new(),
339 }
340 }
341
342 fn push_event(&mut self, event: Event) {
343 if let Some(Event::BufferChange {
344 new_snapshot: last_new_snapshot,
345 timestamp: last_timestamp,
346 ..
347 }) = self.events.back_mut()
348 {
349 // Coalesce edits for the same buffer when they happen one after the other.
350 let Event::BufferChange {
351 old_snapshot,
352 new_snapshot,
353 timestamp,
354 } = &event;
355
356 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
357 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
358 && old_snapshot.version == last_new_snapshot.version
359 {
360 *last_new_snapshot = new_snapshot.clone();
361 *last_timestamp = *timestamp;
362 return;
363 }
364 }
365
366 if self.events.len() >= MAX_EVENT_COUNT {
367 // These are halved instead of popping to improve prompt caching.
368 self.events.drain(..MAX_EVENT_COUNT / 2);
369 }
370
371 self.events.push_back(event);
372 }
373
374 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
375 let buffer_id = buffer.entity_id();
376 let weak_buffer = buffer.downgrade();
377
378 if let std::collections::hash_map::Entry::Vacant(entry) =
379 self.registered_buffers.entry(buffer_id)
380 {
381 let snapshot = buffer.read(cx).snapshot();
382
383 entry.insert(RegisteredBuffer {
384 snapshot,
385 _subscriptions: [
386 cx.subscribe(buffer, move |this, buffer, event, cx| {
387 this.handle_buffer_event(buffer, event, cx);
388 }),
389 cx.observe_release(buffer, move |this, _buffer, _cx| {
390 this.registered_buffers.remove(&weak_buffer.entity_id());
391 }),
392 ],
393 });
394 };
395 }
396
397 fn handle_buffer_event(
398 &mut self,
399 buffer: Entity<Buffer>,
400 event: &language::BufferEvent,
401 cx: &mut Context<Self>,
402 ) {
403 if let language::BufferEvent::Edited = event {
404 self.report_changes_for_buffer(&buffer, cx);
405 }
406 }
407
/// Requests an edit prediction for `buffer` at `cursor`.
///
/// The network call is delegated to `perform_predict_edits`, so callers (e.g. tests) can
/// substitute a canned response.
///
/// Steps, in order:
/// 1. Records any pending change to `buffer` as an event and snapshots it.
/// 2. Builds the request body on the background executor via `gather_context`.
/// 3. If `can_collect_data` is true and the buffer has a file, starts gathering additional
///    context concurrently. Its result is only sent as telemetry after the prediction is
///    processed; it never affects the request.
/// 4. Sends the request. On `ZedUpdateRequiredError`, sets `update_required` and, when a
///    `workspace` is given, shows an "Update Zed" notification before returning the error.
/// 5. Forwards returned usage (if any) to the `UserStore`, then parses and re-bases the
///    response through `process_completion_response`.
///
/// Resolves to `Ok(None)` when `process_completion_response` finds no applicable edits
/// (see `interpolate`).
fn request_completion_impl<F, R>(
    &mut self,
    workspace: Option<Entity<Workspace>>,
    project: Option<Entity<Project>>,
    buffer: &Entity<Buffer>,
    cursor: language::Anchor,
    can_collect_data: CanCollectData,
    cx: &mut Context<Self>,
    perform_predict_edits: F,
) -> Task<Result<Option<EditPrediction>>>
where
    F: FnOnce(PerformPredictEditsParams) -> R + 'static,
    R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
        + Send
        + 'static,
{
    let buffer = buffer.clone();
    let buffer_snapshotted_at = Instant::now();
    // Pushes any pending edit as an event *before* `events` is cloned below, so the prompt
    // for this request includes it.
    let snapshot = self.report_changes_for_buffer(&buffer, cx);
    let zeta = cx.entity();
    let events = self.events.clone();
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);

    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().to_string();
    let cursor_point = cursor.to_point(&snapshot);
    let cursor_offset = cursor_point.to_offset(&snapshot);
    // Passed as a closure so `gather_context` decides when to render the events prompt
    // (it runs on the background executor below).
    let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
    let gather_task = cx.background_spawn(gather_context(
        full_path_str,
        snapshot.clone(),
        cursor_point,
        make_events_prompt,
        can_collect_data,
    ));

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            body,
            editable_range,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // Additional context is only gathered when data collection is allowed and the
        // buffer is backed by a file that maps to a project path.
        let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
            && let Some(file) = snapshot.file()
            && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
        {
            // This is async to reduce latency of the edit predictions request. The downside is
            // that it will see a slightly later state than was used when gathering context.
            let snapshot = snapshot.clone();
            let this = this.clone();
            Some(cx.spawn(async move |cx| {
                if let Ok(Some(task)) = this.update(cx, |this, cx| {
                    this.gather_additional_context(
                        cursor_point,
                        cursor_offset,
                        snapshot,
                        &buffer_snapshotted_at,
                        project_path,
                        project.as_ref(),
                        cx,
                    )
                }) {
                    Some(task.await)
                } else {
                    None
                }
            }))
        } else {
            None
        };

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        // Kept so the resulting `EditPrediction` can report exactly what was sent.
        let input_events = body.input_events.clone();
        let input_excerpt = body.input_excerpt.clone();

        let response = perform_predict_edits(PerformPredictEditsParams {
            client,
            llm_token,
            app_version,
            body,
        })
        .await;
        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // A version-gate rejection flags Zeta as needing an update and, when a
                // workspace is available, tells the user; the error is still returned.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        zeta.update(cx, |zeta, _cx| {
                            zeta.update_required = true;
                        });

                        if let Some(workspace) = workspace {
                            workspace.update(cx, |workspace, cx| {
                                workspace.show_notification(
                                    NotificationId::unique::<ZedUpdateRequiredError>(),
                                    cx,
                                    |cx| {
                                        cx.new(|cx| {
                                            ErrorMessagePrompt::new(err.to_string(), cx)
                                                .with_link_button(
                                                    "Update Zed",
                                                    "https://zed.dev/releases",
                                                )
                                        })
                                    },
                                );
                            });
                        }
                    })
                    .ok();
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        let request_id = response.request_id.clone();
        let edit_prediction = Self::process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            cursor_offset,
            full_path,
            input_events,
            input_excerpt,
            buffer_snapshotted_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // Record latency for roughly 1% of requests (a random u8 <= 2 is 3/256, about 1.2%).
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        // Additional context is reported in the background, keyed by the request id; it does
        // not delay or change the returned prediction.
        if let Some(additional_context_task) = additional_context_task {
            cx.background_spawn(async move {
                if let Some(additional_context) = additional_context_task.await {
                    telemetry::event!(
                        "Edit Prediction Additional Context",
                        request_id = request_id,
                        additional_context = additional_context
                    );
                }
            })
            .detach();
        }

        edit_prediction
    })
}
594
/// Generates several example completions of various states to fill the Zeta completion
/// modal (test support only).
///
/// Requests seven fake completions against a small local buffer and awaits them all.
/// It then clears the edits of the entries at indices 2 and 3 of `shown_completions`,
/// presumably to exercise the "no edits" state in the modal.
///
/// NOTE(review): nothing in this function (nor in `request_completion_impl`) pushes into
/// `shown_completions`. The `unwrap`s on indices 2 and 3 assume at least four entries were
/// recorded elsewhere (e.g. via `completion_shown`); confirm, otherwise this panics.
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
    use language::Point;

    let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
        And maybe a short line

        Then a few lines

        and then another
    "#};

    let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
    let position = buffer.read(cx).anchor_before(Point::new(1, 0));

    let completion_tasks = vec![
        self.fake_completion(
            None,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    "),
            },
            cx,
        ),
        self.fake_completion(
            None,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
            },
            cx,
        ),
        self.fake_completion(
            None,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
            },
            cx,
        ),
        self.fake_completion(
            None,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
            },
            cx,
        ),
        self.fake_completion(
            None,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
            },
            cx,
        ),
        self.fake_completion(
            None,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                    "#),
            },
            cx,
        ),
        self.fake_completion(
            None,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                    "#),
            },
            cx,
        ),
    ];

    cx.spawn(async move |zeta, cx| {
        // Fake requests must all succeed; results themselves are discarded.
        for task in completion_tasks {
            task.await.unwrap();
        }

        zeta.update(cx, |zeta, _cx| {
            zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
            zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
        })
        .ok();
    })
}
746
747 #[cfg(any(test, feature = "test-support"))]
748 pub fn fake_completion(
749 &mut self,
750 project: Option<Entity<Project>>,
751 buffer: &Entity<Buffer>,
752 position: language::Anchor,
753 response: PredictEditsResponse,
754 cx: &mut Context<Self>,
755 ) -> Task<Result<Option<EditPrediction>>> {
756 use std::future::ready;
757
758 self.request_completion_impl(
759 None,
760 project,
761 buffer,
762 position,
763 CanCollectData(false),
764 cx,
765 |_params| ready(Ok((response, None))),
766 )
767 }
768
769 pub fn request_completion(
770 &mut self,
771 project: Option<Entity<Project>>,
772 buffer: &Entity<Buffer>,
773 position: language::Anchor,
774 can_collect_data: CanCollectData,
775 cx: &mut Context<Self>,
776 ) -> Task<Result<Option<EditPrediction>>> {
777 self.request_completion_impl(
778 self.workspace.upgrade(),
779 project,
780 buffer,
781 position,
782 can_collect_data,
783 cx,
784 Self::perform_predict_edits,
785 )
786 }
787
788 pub fn perform_predict_edits(
789 params: PerformPredictEditsParams,
790 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
791 async move {
792 let PerformPredictEditsParams {
793 client,
794 llm_token,
795 app_version,
796 body,
797 ..
798 } = params;
799
800 let http_client = client.http_client();
801 let mut token = llm_token.acquire(&client).await?;
802 let mut did_retry = false;
803
804 loop {
805 let request_builder = http_client::Request::builder().method(Method::POST);
806 let request_builder =
807 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
808 request_builder.uri(predict_edits_url)
809 } else {
810 request_builder.uri(
811 http_client
812 .build_zed_llm_url("/predict_edits/v2", &[])?
813 .as_ref(),
814 )
815 };
816 let request = request_builder
817 .header("Content-Type", "application/json")
818 .header("Authorization", format!("Bearer {}", token))
819 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
820 .body(serde_json::to_string(&body)?.into())?;
821
822 let mut response = http_client.send(request).await?;
823
824 if let Some(minimum_required_version) = response
825 .headers()
826 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
827 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
828 {
829 anyhow::ensure!(
830 app_version >= minimum_required_version,
831 ZedUpdateRequiredError {
832 minimum_version: minimum_required_version
833 }
834 );
835 }
836
837 if response.status().is_success() {
838 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
839
840 let mut body = String::new();
841 response.body_mut().read_to_string(&mut body).await?;
842 return Ok((serde_json::from_str(&body)?, usage));
843 } else if !did_retry
844 && response
845 .headers()
846 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
847 .is_some()
848 {
849 did_retry = true;
850 token = llm_token.refresh(&client).await?;
851 } else {
852 let mut body = String::new();
853 response.body_mut().read_to_string(&mut body).await?;
854 anyhow::bail!(
855 "error predicting edits.\nStatus: {:?}\nBody: {}",
856 response.status(),
857 body
858 );
859 }
860 }
861 }
862 }
863
864 fn accept_edit_prediction(
865 &mut self,
866 request_id: EditPredictionId,
867 cx: &mut Context<Self>,
868 ) -> Task<Result<()>> {
869 let client = self.client.clone();
870 let llm_token = self.llm_token.clone();
871 let app_version = AppVersion::global(cx);
872 cx.spawn(async move |this, cx| {
873 let http_client = client.http_client();
874 let mut response = llm_token_retry(&llm_token, &client, |token| {
875 let request_builder = http_client::Request::builder().method(Method::POST);
876 let request_builder =
877 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
878 request_builder.uri(accept_prediction_url)
879 } else {
880 request_builder.uri(
881 http_client
882 .build_zed_llm_url("/predict_edits/accept", &[])?
883 .as_ref(),
884 )
885 };
886 Ok(request_builder
887 .header("Content-Type", "application/json")
888 .header("Authorization", format!("Bearer {}", token))
889 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
890 .body(
891 serde_json::to_string(&AcceptEditPredictionBody {
892 request_id: request_id.0,
893 })?
894 .into(),
895 )?)
896 })
897 .await?;
898
899 if let Some(minimum_required_version) = response
900 .headers()
901 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
902 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
903 && app_version < minimum_required_version
904 {
905 return Err(anyhow!(ZedUpdateRequiredError {
906 minimum_version: minimum_required_version
907 }));
908 }
909
910 if response.status().is_success() {
911 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
912 this.update(cx, |this, cx| {
913 this.user_store.update(cx, |user_store, cx| {
914 user_store.update_edit_prediction_usage(usage, cx);
915 });
916 })?;
917 }
918
919 Ok(())
920 } else {
921 let mut body = String::new();
922 response.body_mut().read_to_string(&mut body).await?;
923 Err(anyhow!(
924 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
925 response.status(),
926 body
927 ))
928 }
929 })
930 }
931
932 fn process_completion_response(
933 prediction_response: PredictEditsResponse,
934 buffer: Entity<Buffer>,
935 snapshot: &BufferSnapshot,
936 editable_range: Range<usize>,
937 cursor_offset: usize,
938 path: Arc<Path>,
939 input_events: String,
940 input_excerpt: String,
941 buffer_snapshotted_at: Instant,
942 cx: &AsyncApp,
943 ) -> Task<Result<Option<EditPrediction>>> {
944 let snapshot = snapshot.clone();
945 let request_id = prediction_response.request_id;
946 let output_excerpt = prediction_response.output_excerpt;
947 cx.spawn(async move |cx| {
948 let output_excerpt: Arc<str> = output_excerpt.into();
949
950 let edits: Arc<[(Range<Anchor>, String)]> = cx
951 .background_spawn({
952 let output_excerpt = output_excerpt.clone();
953 let editable_range = editable_range.clone();
954 let snapshot = snapshot.clone();
955 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
956 })
957 .await?
958 .into();
959
960 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
961 let edits = edits.clone();
962 |buffer, cx| {
963 let new_snapshot = buffer.snapshot();
964 let edits: Arc<[(Range<Anchor>, String)]> =
965 interpolate(&snapshot, &new_snapshot, edits)?.into();
966 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
967 }
968 })?
969 else {
970 return anyhow::Ok(None);
971 };
972
973 let edit_preview = edit_preview.await;
974
975 Ok(Some(EditPrediction {
976 id: EditPredictionId(request_id),
977 path,
978 excerpt_range: editable_range,
979 cursor_offset,
980 edits,
981 edit_preview,
982 snapshot,
983 input_events: input_events.into(),
984 input_excerpt: input_excerpt.into(),
985 output_excerpt,
986 buffer_snapshotted_at,
987 response_received_at: Instant::now(),
988 }))
989 })
990 }
991
992 fn parse_edits(
993 output_excerpt: Arc<str>,
994 editable_range: Range<usize>,
995 snapshot: &BufferSnapshot,
996 ) -> Result<Vec<(Range<Anchor>, String)>> {
997 let content = output_excerpt.replace(CURSOR_MARKER, "");
998
999 let start_markers = content
1000 .match_indices(EDITABLE_REGION_START_MARKER)
1001 .collect::<Vec<_>>();
1002 anyhow::ensure!(
1003 start_markers.len() == 1,
1004 "expected exactly one start marker, found {}",
1005 start_markers.len()
1006 );
1007
1008 let end_markers = content
1009 .match_indices(EDITABLE_REGION_END_MARKER)
1010 .collect::<Vec<_>>();
1011 anyhow::ensure!(
1012 end_markers.len() == 1,
1013 "expected exactly one end marker, found {}",
1014 end_markers.len()
1015 );
1016
1017 let sof_markers = content
1018 .match_indices(START_OF_FILE_MARKER)
1019 .collect::<Vec<_>>();
1020 anyhow::ensure!(
1021 sof_markers.len() <= 1,
1022 "expected at most one start-of-file marker, found {}",
1023 sof_markers.len()
1024 );
1025
1026 let codefence_start = start_markers[0].0;
1027 let content = &content[codefence_start..];
1028
1029 let newline_ix = content.find('\n').context("could not find newline")?;
1030 let content = &content[newline_ix + 1..];
1031
1032 let codefence_end = content
1033 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1034 .context("could not find end marker")?;
1035 let new_text = &content[..codefence_end];
1036
1037 let old_text = snapshot
1038 .text_for_range(editable_range.clone())
1039 .collect::<String>();
1040
1041 Ok(Self::compute_edits(
1042 old_text,
1043 new_text,
1044 editable_range.start,
1045 snapshot,
1046 ))
1047 }
1048
1049 pub fn compute_edits(
1050 old_text: String,
1051 new_text: &str,
1052 offset: usize,
1053 snapshot: &BufferSnapshot,
1054 ) -> Vec<(Range<Anchor>, String)> {
1055 text_diff(&old_text, new_text)
1056 .into_iter()
1057 .map(|(mut old_range, new_text)| {
1058 old_range.start += offset;
1059 old_range.end += offset;
1060
1061 let prefix_len = common_prefix(
1062 snapshot.chars_for_range(old_range.clone()),
1063 new_text.chars(),
1064 );
1065 old_range.start += prefix_len;
1066
1067 let suffix_len = common_prefix(
1068 snapshot.reversed_chars_for_range(old_range.clone()),
1069 new_text[prefix_len..].chars().rev(),
1070 );
1071 old_range.end = old_range.end.saturating_sub(suffix_len);
1072
1073 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1074 let range = if old_range.is_empty() {
1075 let anchor = snapshot.anchor_after(old_range.start);
1076 anchor..anchor
1077 } else {
1078 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1079 };
1080 (range, new_text)
1081 })
1082 .collect()
1083 }
1084
1085 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1086 self.rated_completions.contains(&completion_id)
1087 }
1088
1089 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1090 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1091 let completion = self.shown_completions.pop_back().unwrap();
1092 self.rated_completions.remove(&completion.id);
1093 }
1094 self.shown_completions.push_front(completion.clone());
1095 cx.notify();
1096 }
1097
    /// Marks `completion` as rated and reports the rating to telemetry.
    ///
    /// Emits an "Edit Prediction Rated" event with:
    /// - `rating`,
    /// - the free-form `feedback`,
    /// - the prediction's input events and input/output excerpts.
    ///
    /// It then requests an immediate telemetry flush; the flush task is detached, not awaited.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush right away so the rating is sent without waiting for the next batch.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1117
    /// Iterates over the shown-completions history, newest first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
1121
    /// Number of completions currently held in the shown-completions history.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1125
1126 fn report_changes_for_buffer(
1127 &mut self,
1128 buffer: &Entity<Buffer>,
1129 cx: &mut Context<Self>,
1130 ) -> BufferSnapshot {
1131 self.register_buffer(buffer, cx);
1132
1133 let registered_buffer = self
1134 .registered_buffers
1135 .get_mut(&buffer.entity_id())
1136 .unwrap();
1137 let new_snapshot = buffer.read(cx).snapshot();
1138
1139 if new_snapshot.version != registered_buffer.snapshot.version {
1140 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1141 self.push_event(Event::BufferChange {
1142 old_snapshot,
1143 new_snapshot: new_snapshot.clone(),
1144 timestamp: Instant::now(),
1145 });
1146 }
1147
1148 new_snapshot
1149 }
1150
1151 fn load_data_collection_choices() -> DataCollectionChoice {
1152 let choice = KEY_VALUE_STORE
1153 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1154 .log_err()
1155 .flatten();
1156
1157 match choice.as_deref() {
1158 Some("true") => DataCollectionChoice::Enabled,
1159 Some("false") => DataCollectionChoice::Disabled,
1160 Some(_) => {
1161 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1162 DataCollectionChoice::NotAnswered
1163 }
1164 None => DataCollectionChoice::NotAnswered,
1165 }
1166 }
1167
    /// Collects extra request context for the file at `project_path`: git metadata,
    /// diagnostics ordered by proximity to the cursor, and recently active files.
    ///
    /// Returns `None` (no additional context) when any of these holds:
    /// - there is no project, or no worktree entry for the path;
    /// - the entry is not eligible for collection;
    /// - the git store maps the path to no repository;
    /// - the repository-relative path is not valid UTF-8;
    /// - the repository has no head commit.
    ///
    /// Otherwise returns a background task that groups, sorts, and size-limits the diagnostics
    /// and assembles the `PredictEditsAdditionalContext`.
    fn gather_additional_context(
        &mut self,
        cursor_point: language::Point,
        cursor_offset: usize,
        snapshot: BufferSnapshot,
        buffer_snapshotted_at: &Instant,
        project_path: ProjectPath,
        project: Option<&Entity<Project>>,
        cx: &mut Context<Self>,
    ) -> Option<Task<PredictEditsAdditionalContext>> {
        let project = project?.read(cx);
        let entry = project.entry_for_path(&project_path, cx)?;
        if !worktree_entry_is_eligible_for_collection(&entry) {
            return None;
        }

        let git_store = project.git_store().read(cx);
        let (repository, repo_path) =
            git_store.repository_and_path_for_project_path(&project_path, cx)?;
        let repo_path_string = repo_path.to_str()?.to_string();

        // Diagnostics are only gathered from a local LSP store, and only for language
        // servers that are currently running; otherwise none are sent.
        let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
            snapshot
                .diagnostics
                .iter()
                .filter_map(|(language_server_id, diagnostics)| {
                    let language_server =
                        local_lsp_store.running_language_server_for_id(*language_server_id)?;
                    Some((
                        *language_server_id,
                        language_server.name(),
                        diagnostics.clone(),
                    ))
                })
                .collect()
        } else {
            Vec::new()
        };

        repository.update(cx, |repository, cx| {
            // Without a head commit there is no git info to send, so no context at all.
            let head_sha = repository.head_commit.as_ref()?.sha.to_string();
            let remote_origin_url = repository.remote_origin_url.clone();
            let remote_upstream_url = repository.remote_upstream_url.clone();
            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);

            // group, resolve, and select diagnostics on a background thread
            Some(cx.background_spawn(async move {
                let mut diagnostic_groups_with_name = Vec::new();
                for (language_server_id, language_server_name, diagnostics) in
                    diagnostics.into_iter()
                {
                    let mut groups = Vec::new();
                    diagnostics.groups(language_server_id, &mut groups, &snapshot);
                    diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
                        (
                            language_server_name.clone(),
                            group.resolve::<usize>(&snapshot),
                        )
                    }));
                }

                // sort by proximity to cursor
                // NOTE(review): a group whose primary range contains the cursor is keyed by its
                // distance to the nearer edge rather than 0, so it can rank behind a group that
                // merely starts or ends at the cursor — confirm this is intended.
                diagnostic_groups_with_name.sort_by_key(|(_, group)| {
                    let range = &group.entries[group.primary_ix].range;
                    if range.start >= cursor_offset {
                        range.start - cursor_offset
                    } else if cursor_offset >= range.end {
                        cursor_offset - range.end
                    } else {
                        (cursor_offset - range.start).min(range.end - cursor_offset)
                    }
                });

                // Take groups nearest-first until the running size (server name plus
                // serialized group) exceeds MAX_DIAGNOSTICS_BYTES; everything from the first
                // group that overflows onward is dropped and the result is flagged truncated.
                let mut diagnostic_groups = Vec::new();
                let mut diagnostic_groups_truncated = false;
                let mut diagnostics_byte_count = 0;
                for (name, group) in diagnostic_groups_with_name {
                    // Panics if a group fails to serialize.
                    let raw_value = serde_json::value::to_raw_value(&group).unwrap();
                    diagnostics_byte_count += name.0.len() + raw_value.get().len();
                    if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
                        diagnostic_groups_truncated = true;
                        break;
                    }
                    diagnostic_groups.push((name.to_string(), raw_value));
                }

                PredictEditsAdditionalContext {
                    input_path: repo_path_string,
                    cursor_point: to_cloud_llm_client_point(cursor_point),
                    cursor_offset: cursor_offset,
                    git_info: PredictEditsGitInfo {
                        head_sha: Some(head_sha),
                        remote_origin_url,
                        remote_upstream_url,
                    },
                    diagnostic_groups,
                    diagnostic_groups_truncated,
                    recent_files,
                }
            }))
        })
    }
1270
1271 fn handle_active_workspace_item_changed(&mut self, cx: &Context<Self>) {
1272 if !self.data_collection_choice.read(cx).is_enabled() {
1273 self.recent_editors.clear();
1274 return;
1275 }
1276 if let Some(active_editor) = self
1277 .workspace
1278 .read_with(cx, |workspace, cx| {
1279 workspace
1280 .active_item(cx)
1281 .and_then(|item| item.act_as::<Editor>(cx))
1282 })
1283 .ok()
1284 .flatten()
1285 {
1286 let now = Instant::now();
1287 let new_recent = RecentEditor {
1288 editor: active_editor.downgrade(),
1289 last_active_at: now,
1290 };
1291 if let Some(existing_ix) = self
1292 .recent_editors
1293 .iter()
1294 .rposition(|recent| &recent.editor == &new_recent.editor)
1295 {
1296 self.recent_editors.remove(existing_ix);
1297 }
1298 // filter out rapid changes in active item, particularly since this can happen rapidly when
1299 // a workspace is loaded.
1300 if let Some(previous_recent) = self.recent_editors.back_mut()
1301 && now.duration_since(previous_recent.last_active_at)
1302 < MIN_TIME_BETWEEN_RECENT_FILES
1303 {
1304 *previous_recent = new_recent;
1305 return;
1306 }
1307 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1308 self.recent_editors.pop_front();
1309 }
1310 self.recent_editors.push_back(new_recent);
1311 }
1312 }
1313
    /// Builds the `recent_files` list for a prediction request, newest first, pruning
    /// `recent_editors` along the way.
    ///
    /// Only files that `repository` can map to a repository path are reported, using paths
    /// relative to it. Entries in other repositories are kept for later requests.
    ///
    /// An entry is removed from `recent_editors` when any of these holds:
    /// - its editor can no longer be updated, or has no cursor buffer or no file;
    /// - the file is not eligible for collection, or has no eligible worktree entry;
    /// - its repository path is not UTF-8 or exceeds `MAX_RECENT_FILE_PATH_LENGTH`;
    /// - its age in milliseconds does not fit `active_to_now_ms`.
    ///
    /// Returns an empty list if the workspace can no longer be read.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &mut App,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::with_capacity(self.recent_editors.len());
        // Walk newest (back) to oldest; iterating indices in reverse keeps earlier indices
        // valid when an entry is removed below.
        for ix in (0..self.recent_editors.len()).rev() {
            let recent_editor = &self.recent_editors[ix];
            // `Some(())` keeps the entry; `None` (including a failed editor update) drops it.
            let keep_entry = recent_editor
                .editor
                .update(cx, |editor, cx| {
                    maybe!({
                        let (buffer, cursor_point, _) = editor.cursor_buffer_point(cx)?;
                        let file = buffer.read(cx).file()?;
                        if !file_is_eligible_for_collection(file.as_ref()) {
                            return None;
                        }
                        let project_path = ProjectPath {
                            worktree_id: file.worktree_id(cx),
                            path: file.path().clone(),
                        };
                        let entry = project.read(cx).entry_for_path(&project_path, cx)?;
                        if !worktree_entry_is_eligible_for_collection(entry) {
                            return None;
                        }
                        let Some(repo_path) =
                            repository.project_path_to_repo_path(&project_path, cx)
                        else {
                            // entry not removed since later queries may involve other repositories
                            return Some(());
                        };
                        // paths may not be valid UTF-8
                        let repo_path_str = repo_path.to_str()?;
                        if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                            return None;
                        }
                        let active_to_now_ms = now
                            .duration_since(recent_editor.last_active_at)
                            .as_millis()
                            .try_into()
                            .ok()?;
                        results.push(PredictEditsRecentFile {
                            path: repo_path_str.to_string(),
                            cursor_point: to_cloud_llm_client_point(cursor_point),
                            active_to_now_ms,
                        });
                        Some(())
                    })
                })
                .ok()
                .flatten();
            if keep_entry.is_none() {
                self.recent_editors.remove(ix);
            }
        }
        results
    }
1378}
1379
1380fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1381 cloud_llm_client::Point {
1382 row: point.row,
1383 column: point.column,
1384 }
1385}
1386
1387fn file_is_eligible_for_collection(file: &dyn File) -> bool {
1388 file.is_local() && !file.is_private()
1389}
1390
1391fn worktree_entry_is_eligible_for_collection(entry: &worktree::Entry) -> bool {
1392 entry.is_file()
1393 && entry.is_created()
1394 && !entry.is_ignored
1395 && !entry.is_private
1396 && !entry.is_external
1397 && !entry.is_fifo
1398}
1399
/// Inputs needed to issue a predict-edits request.
pub struct PerformPredictEditsParams {
    pub client: Arc<Client>,
    /// Token holder used to authenticate against the LLM service.
    pub llm_token: LlmApiToken,
    /// Presumably the running Zed version, reported to the server — confirm at call sites.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}

/// Error telling the user to update Zed to at least `minimum_version` to keep using edit
/// predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1414
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two char streams.
///
/// Comparison stops at the first mismatching pair or when either stream ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1421
1422pub struct GatherContextOutput {
1423 pub body: PredictEditsBody,
1424 pub editable_range: Range<usize>,
1425}
1426
1427pub async fn gather_context(
1428 full_path_str: String,
1429 snapshot: BufferSnapshot,
1430 cursor_point: language::Point,
1431 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1432 can_collect_data: CanCollectData,
1433) -> Result<GatherContextOutput> {
1434 let input_excerpt = excerpt_for_cursor_position(
1435 cursor_point,
1436 &full_path_str,
1437 &snapshot,
1438 MAX_REWRITE_TOKENS,
1439 MAX_CONTEXT_TOKENS,
1440 );
1441 let input_events = make_events_prompt();
1442 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1443
1444 let body = PredictEditsBody {
1445 input_events,
1446 input_excerpt: input_excerpt.prompt,
1447 can_collect_data: can_collect_data.0,
1448 diagnostic_groups: None,
1449 git_info: None,
1450 };
1451
1452 Ok(GatherContextOutput {
1453 body,
1454 editable_range,
1455 })
1456}
1457
1458fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1459 let mut result = String::new();
1460 for event in events.iter().rev() {
1461 let event_string = event.to_prompt();
1462 let event_tokens = tokens_for_bytes(event_string.len());
1463 if event_tokens > remaining_tokens {
1464 break;
1465 }
1466
1467 if !result.is_empty() {
1468 result.insert_str(0, "\n\n");
1469 }
1470 result.insert_str(0, &event_string);
1471 remaining_tokens -= event_tokens;
1472 }
1473 result
1474}
1475
1476struct RegisteredBuffer {
1477 snapshot: BufferSnapshot,
1478 _subscriptions: [gpui::Subscription; 2],
1479}
1480
1481#[derive(Clone)]
1482pub enum Event {
1483 BufferChange {
1484 old_snapshot: BufferSnapshot,
1485 new_snapshot: BufferSnapshot,
1486 timestamp: Instant,
1487 },
1488}
1489
1490impl Event {
1491 fn to_prompt(&self) -> String {
1492 match self {
1493 Event::BufferChange {
1494 old_snapshot,
1495 new_snapshot,
1496 ..
1497 } => {
1498 let mut prompt = String::new();
1499
1500 let old_path = old_snapshot
1501 .file()
1502 .map(|f| f.path().as_ref())
1503 .unwrap_or(Path::new("untitled"));
1504 let new_path = new_snapshot
1505 .file()
1506 .map(|f| f.path().as_ref())
1507 .unwrap_or(Path::new("untitled"));
1508 if old_path != new_path {
1509 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1510 }
1511
1512 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1513 if !diff.is_empty() {
1514 write!(
1515 prompt,
1516 "User edited {:?}:\n```diff\n{}\n```",
1517 new_path, diff
1518 )
1519 .unwrap();
1520 }
1521
1522 prompt
1523 }
1524 }
1525 }
1526}
1527
1528#[derive(Debug, Clone)]
1529struct CurrentEditPrediction {
1530 buffer_id: EntityId,
1531 completion: EditPrediction,
1532}
1533
1534impl CurrentEditPrediction {
1535 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1536 if self.buffer_id != old_completion.buffer_id {
1537 return true;
1538 }
1539
1540 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1541 return true;
1542 };
1543 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1544 return false;
1545 };
1546
1547 if old_edits.len() == 1 && new_edits.len() == 1 {
1548 let (old_range, old_text) = &old_edits[0];
1549 let (new_range, new_text) = &new_edits[0];
1550 new_range == old_range && new_text.starts_with(old_text)
1551 } else {
1552 true
1553 }
1554 }
1555}
1556
/// An in-flight prediction request owned by the provider.
struct PendingCompletion {
    /// Id the request's task uses to find itself in `pending_completions`.
    id: usize,
    /// The request task; only held, never read.
    _task: Task<()>,
}
1561
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// `true` once the user has chosen either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// The choice after the user flips the toggle.
    ///
    /// An enabled choice becomes disabled; anything else, including an unanswered prompt,
    /// becomes enabled.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1601
1602pub struct ProviderDataCollection {
1603 /// When set to None, data collection is not possible in the provider buffer
1604 choice: Option<Entity<DataCollectionChoice>>,
1605 license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1606}
1607
1608#[derive(Debug, Clone, Copy)]
1609pub struct CanCollectData(pub bool);
1610
1611impl ProviderDataCollection {
1612 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1613 let choice_and_watcher = buffer.and_then(|buffer| {
1614 let file = buffer.read(cx).file()?;
1615
1616 if !file_is_eligible_for_collection(file.as_ref()) {
1617 return None;
1618 }
1619
1620 let zeta = zeta.read(cx);
1621 let choice = zeta.data_collection_choice.clone();
1622
1623 let license_detection_watcher = zeta
1624 .license_detection_watchers
1625 .get(&file.worktree_id(cx))
1626 .cloned()?;
1627
1628 Some((choice, license_detection_watcher))
1629 });
1630
1631 if let Some((choice, watcher)) = choice_and_watcher {
1632 ProviderDataCollection {
1633 choice: Some(choice),
1634 license_detection_watcher: Some(watcher),
1635 }
1636 } else {
1637 ProviderDataCollection {
1638 choice: None,
1639 license_detection_watcher: None,
1640 }
1641 }
1642 }
1643
1644 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1645 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1646 }
1647
1648 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1649 self.choice
1650 .as_ref()
1651 .is_some_and(|choice| choice.read(cx).is_enabled())
1652 }
1653
1654 fn is_project_open_source(&self) -> bool {
1655 self.license_detection_watcher
1656 .as_ref()
1657 .is_some_and(|watcher| watcher.is_project_open_source())
1658 }
1659
1660 pub fn toggle(&mut self, cx: &mut App) {
1661 if let Some(choice) = self.choice.as_mut() {
1662 let new_choice = choice.update(cx, |choice, _cx| {
1663 let new_choice = choice.toggle();
1664 *choice = new_choice;
1665 new_choice
1666 });
1667
1668 db::write_and_log(cx, move || {
1669 KEY_VALUE_STORE.write_kvp(
1670 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1671 new_choice.is_enabled().to_string(),
1672 )
1673 });
1674 }
1675 }
1676}
1677
1678async fn llm_token_retry(
1679 llm_token: &LlmApiToken,
1680 client: &Arc<Client>,
1681 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1682) -> Result<Response<AsyncBody>> {
1683 let mut did_retry = false;
1684 let http_client = client.http_client();
1685 let mut token = llm_token.acquire(client).await?;
1686 loop {
1687 let request = build_request(token.clone())?;
1688 let response = http_client.send(request).await?;
1689
1690 if !did_retry
1691 && !response.status().is_success()
1692 && response
1693 .headers()
1694 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1695 .is_some()
1696 {
1697 did_retry = true;
1698 token = llm_token.refresh(client).await?;
1699 continue;
1700 }
1701
1702 return Ok(response);
1703 }
1704}
1705
1706pub struct ZetaEditPredictionProvider {
1707 zeta: Entity<Zeta>,
1708 pending_completions: ArrayVec<PendingCompletion, 2>,
1709 next_pending_completion_id: usize,
1710 current_completion: Option<CurrentEditPrediction>,
1711 /// None if this is entirely disabled for this provider
1712 provider_data_collection: ProviderDataCollection,
1713 last_request_timestamp: Instant,
1714}
1715
1716impl ZetaEditPredictionProvider {
1717 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1718
1719 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1720 Self {
1721 zeta,
1722 pending_completions: ArrayVec::new(),
1723 next_pending_completion_id: 0,
1724 current_completion: None,
1725 provider_data_collection,
1726 last_request_timestamp: Instant::now(),
1727 }
1728 }
1729}
1730
/// Adapts [`Zeta`] to the editor's `EditPredictionProvider` interface.
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled for this provider's buffer, and whether
    /// its project was detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Flips and persists the data-collection choice.
    ///
    /// Does nothing when data collection is not possible for this provider's buffer.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// True while at least one prediction request is in flight.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Starts a prediction request for `position` in `buffer`.
    ///
    /// Returns without requesting when any of these holds:
    /// - `update_required` is set;
    /// - the account is too young or has overdue invoices;
    /// - the current prediction still interpolates onto `buffer`'s latest snapshot.
    ///
    /// `_debounce` is ignored. Instead, each request first waits until `THROTTLE_TIMEOUT`
    /// has elapsed since the previous request started.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction while it still applies to the edited buffer.
        // NOTE(review): `current_completion.buffer_id` is not compared with `buffer` here;
        // confirm `interpolate` rejects snapshots of a different buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out whatever remains of the window since the last request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged. An error or an empty response only retires this request.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this is the oldest pending request, remove just it. Otherwise it is the
                    // newer of two, and the older one is discarded too.
                    // NOTE(review): indexing `[0]` assumes this request is still pending, i.e.
                    // that dropping a `PendingCompletion` stops its task — confirm for gpui `Task`.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same retirement rule as the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Adopt the new prediction if none is shown yet, or if it should supersede
                // the current one.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current prediction (if any) to Zeta and drops all pending
    /// requests.
    ///
    /// `current_completion` itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the part of the current prediction to show near `cursor_position`.
    ///
    /// The current prediction is dropped, and `None` returned, if it was made for a different
    /// buffer or no longer interpolates onto the buffer's snapshot.
    ///
    /// Otherwise:
    /// 1. Pick the edit whose start or end row is closest to the cursor row.
    /// 2. Widen the selection over neighbouring edits (in order) whose row gap to that
    ///    closest edit is at most one.
    /// 3. Stop widening in each direction at the first edit that is farther away.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // The row subtractions below assume edits are sorted and non-overlapping.
        // NOTE(review): confirm `interpolate` guarantees this, or they could underflow.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1999
/// Estimates how many model tokens `bytes` bytes of text will occupy, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed bytes per token when budgeting model input.
    ///
    /// Deliberately low, so token counts come out high and the input limits are respected
    /// conservatively.
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    bytes / BYTES_PER_TOKEN_GUESS
}
2006
2007#[cfg(test)]
2008mod tests {
2009 use client::UserStore;
2010 use client::test::FakeServer;
2011 use clock::FakeSystemClock;
2012 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
2013 use gpui::TestAppContext;
2014 use http_client::FakeHttpClient;
2015 use indoc::indoc;
2016 use language::Point;
2017 use settings::SettingsStore;
2018
2019 use super::*;
2020
2021 #[gpui::test]
2022 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
2023 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
2024 let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
2025 to_completion_edits(
2026 [(2..5, "REM".to_string()), (9..11, "".to_string())],
2027 &buffer,
2028 cx,
2029 )
2030 .into()
2031 });
2032
2033 let edit_preview = cx
2034 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
2035 .await;
2036
2037 let completion = EditPrediction {
2038 edits,
2039 edit_preview,
2040 path: Path::new("").into(),
2041 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
2042 id: EditPredictionId(Uuid::new_v4()),
2043 excerpt_range: 0..0,
2044 cursor_offset: 0,
2045 input_events: "".into(),
2046 input_excerpt: "".into(),
2047 output_excerpt: "".into(),
2048 buffer_snapshotted_at: Instant::now(),
2049 response_received_at: Instant::now(),
2050 };
2051
2052 cx.update(|cx| {
2053 assert_eq!(
2054 from_completion_edits(
2055 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2056 &buffer,
2057 cx
2058 ),
2059 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2060 );
2061
2062 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2063 assert_eq!(
2064 from_completion_edits(
2065 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2066 &buffer,
2067 cx
2068 ),
2069 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
2070 );
2071
2072 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2073 assert_eq!(
2074 from_completion_edits(
2075 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2076 &buffer,
2077 cx
2078 ),
2079 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2080 );
2081
2082 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2083 assert_eq!(
2084 from_completion_edits(
2085 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2086 &buffer,
2087 cx
2088 ),
2089 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2090 );
2091
2092 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2093 assert_eq!(
2094 from_completion_edits(
2095 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2096 &buffer,
2097 cx
2098 ),
2099 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2100 );
2101
2102 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2103 assert_eq!(
2104 from_completion_edits(
2105 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2106 &buffer,
2107 cx
2108 ),
2109 vec![(9..11, "".to_string())]
2110 );
2111
2112 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2113 assert_eq!(
2114 from_completion_edits(
2115 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2116 &buffer,
2117 cx
2118 ),
2119 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2120 );
2121
2122 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2123 assert_eq!(
2124 from_completion_edits(
2125 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2126 &buffer,
2127 cx
2128 ),
2129 vec![(4..4, "M".to_string())]
2130 );
2131
2132 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2133 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2134 })
2135 }
2136
2137 #[gpui::test]
2138 async fn test_clean_up_diff(cx: &mut TestAppContext) {
2139 cx.update(|cx| {
2140 let settings_store = SettingsStore::test(cx);
2141 cx.set_global(settings_store);
2142 client::init_settings(cx);
2143 });
2144
2145 let edits = edits_for_prediction(
2146 indoc! {"
2147 fn main() {
2148 let word_1 = \"lorem\";
2149 let range = word.len()..word.len();
2150 }
2151 "},
2152 indoc! {"
2153 <|editable_region_start|>
2154 fn main() {
2155 let word_1 = \"lorem\";
2156 let range = word_1.len()..word_1.len();
2157 }
2158
2159 <|editable_region_end|>
2160 "},
2161 cx,
2162 )
2163 .await;
2164 assert_eq!(
2165 edits,
2166 [
2167 (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2168 (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2169 ]
2170 );
2171
2172 let edits = edits_for_prediction(
2173 indoc! {"
2174 fn main() {
2175 let story = \"the quick\"
2176 }
2177 "},
2178 indoc! {"
2179 <|editable_region_start|>
2180 fn main() {
2181 let story = \"the quick brown fox jumps over the lazy dog\";
2182 }
2183
2184 <|editable_region_end|>
2185 "},
2186 cx,
2187 )
2188 .await;
2189 assert_eq!(
2190 edits,
2191 [
2192 (
2193 Point::new(1, 26)..Point::new(1, 26),
2194 " brown fox jumps over the lazy dog".to_string()
2195 ),
2196 (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2197 ]
2198 );
2199 }
2200
2201 #[gpui::test]
2202 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2203 cx.update(|cx| {
2204 let settings_store = SettingsStore::test(cx);
2205 cx.set_global(settings_store);
2206 client::init_settings(cx);
2207 });
2208
2209 let buffer_content = "lorem\n";
2210 let completion_response = indoc! {"
2211 ```animals.js
2212 <|start_of_file|>
2213 <|editable_region_start|>
2214 lorem
2215 ipsum
2216 <|editable_region_end|>
2217 ```"};
2218
2219 let http_client = FakeHttpClient::create(move |req| async move {
2220 match (req.method(), req.uri().path()) {
2221 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2222 .status(200)
2223 .body(
2224 serde_json::to_string(&CreateLlmTokenResponse {
2225 token: LlmToken("the-llm-token".to_string()),
2226 })
2227 .unwrap()
2228 .into(),
2229 )
2230 .unwrap()),
2231 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2232 .status(200)
2233 .body(
2234 serde_json::to_string(&PredictEditsResponse {
2235 request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2236 .unwrap(),
2237 output_excerpt: completion_response.to_string(),
2238 })
2239 .unwrap()
2240 .into(),
2241 )
2242 .unwrap()),
2243 _ => Ok(http_client::Response::builder()
2244 .status(404)
2245 .body("Not Found".into())
2246 .unwrap()),
2247 }
2248 });
2249
2250 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2251 cx.update(|cx| {
2252 RefreshLlmTokenListener::register(client.clone(), cx);
2253 });
2254 // Construct the fake server to authenticate.
2255 let _server = FakeServer::for_client(42, &client, cx).await;
2256 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2257 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2258
2259 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2260 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2261 let completion_task = zeta.update(cx, |zeta, cx| {
2262 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2263 });
2264
2265 let completion = completion_task.await.unwrap().unwrap();
2266 buffer.update(cx, |buffer, cx| {
2267 buffer.edit(completion.edits.iter().cloned(), None, cx)
2268 });
2269 assert_eq!(
2270 buffer.read_with(cx, |buffer, _| buffer.text()),
2271 "lorem\nipsum"
2272 );
2273 }
2274
2275 async fn edits_for_prediction(
2276 buffer_content: &str,
2277 completion_response: &str,
2278 cx: &mut TestAppContext,
2279 ) -> Vec<(Range<Point>, String)> {
2280 let completion_response = completion_response.to_string();
2281 let http_client = FakeHttpClient::create(move |req| {
2282 let completion = completion_response.clone();
2283 async move {
2284 match (req.method(), req.uri().path()) {
2285 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2286 .status(200)
2287 .body(
2288 serde_json::to_string(&CreateLlmTokenResponse {
2289 token: LlmToken("the-llm-token".to_string()),
2290 })
2291 .unwrap()
2292 .into(),
2293 )
2294 .unwrap()),
2295 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2296 .status(200)
2297 .body(
2298 serde_json::to_string(&PredictEditsResponse {
2299 request_id: Uuid::new_v4(),
2300 output_excerpt: completion,
2301 })
2302 .unwrap()
2303 .into(),
2304 )
2305 .unwrap()),
2306 _ => Ok(http_client::Response::builder()
2307 .status(404)
2308 .body("Not Found".into())
2309 .unwrap()),
2310 }
2311 }
2312 });
2313
2314 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2315 cx.update(|cx| {
2316 RefreshLlmTokenListener::register(client.clone(), cx);
2317 });
2318 // Construct the fake server to authenticate.
2319 let _server = FakeServer::for_client(42, &client, cx).await;
2320 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2321 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2322
2323 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2324 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2325 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2326 let completion_task = zeta.update(cx, |zeta, cx| {
2327 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2328 });
2329
2330 let completion = completion_task.await.unwrap().unwrap();
2331 completion
2332 .edits
2333 .iter()
2334 .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2335 .collect::<Vec<_>>()
2336 }
2337
2338 fn to_completion_edits(
2339 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2340 buffer: &Entity<Buffer>,
2341 cx: &App,
2342 ) -> Vec<(Range<Anchor>, String)> {
2343 let buffer = buffer.read(cx);
2344 iterator
2345 .into_iter()
2346 .map(|(range, text)| {
2347 (
2348 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2349 text,
2350 )
2351 })
2352 .collect()
2353 }
2354
2355 fn from_completion_edits(
2356 editor_edits: &[(Range<Anchor>, String)],
2357 buffer: &Entity<Buffer>,
2358 cx: &App,
2359 ) -> Vec<(Range<usize>, String)> {
2360 let buffer = buffer.read(cx);
2361 editor_edits
2362 .iter()
2363 .map(|(range, text)| {
2364 (
2365 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2366 text.clone(),
2367 )
2368 })
2369 .collect()
2370 }
2371
/// Installs test logging once, when the test binary is loaded.
///
/// `#[ctor::ctor]` makes this run as a load-time constructor, before any test
/// function executes. `zlog::init_test` presumably sets up the logger for tests
/// (its exact configuration is defined in `zlog`, not here).
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}
2376}