1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9use arrayvec::ArrayVec;
10pub(crate) use completion_diff_element::*;
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction::DataCollectionState;
13use editor::Editor;
14pub use init::*;
15use license_detection::LicenseDetectionWatcher;
16use project::git_store::Repository;
17pub use rate_completion_modal::*;
18
19use anyhow::{Context as _, Result, anyhow};
20use client::{Client, EditPredictionUsage, UserStore};
21use cloud_llm_client::{
22 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
23 PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
24 PredictEditsResponse, ZED_VERSION_HEADER_NAME,
25};
26use collections::{HashMap, HashSet, VecDeque};
27use futures::AsyncReadExt;
28use gpui::{
29 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
30 Subscription, Task, WeakEntity, actions,
31};
32use http_client::{AsyncBody, HttpClient, Method, Request, Response};
33use input_excerpt::excerpt_for_cursor_position;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
36};
37use language_model::{LlmApiToken, RefreshLlmTokenListener};
38use project::{Project, ProjectPath};
39use release_channel::AppVersion;
40use settings::WorktreeId;
41use std::str::FromStr;
42use std::{
43 cmp,
44 fmt::Write,
45 future::Future,
46 mem,
47 ops::Range,
48 path::Path,
49 rc::Rc,
50 sync::Arc,
51 time::{Duration, Instant},
52};
53use telemetry_events::EditPredictionRating;
54use thiserror::Error;
55use util::{ResultExt, maybe};
56use uuid::Uuid;
57use workspace::Workspace;
58use workspace::notifications::{ErrorMessagePrompt, NotificationId};
59use worktree::Worktree;
60
// --- Prompt / completion markers -------------------------------------------

/// Cursor marker; it is stripped from the model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; at most one may appear in a model response.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the region of the excerpt the model is allowed to rewrite.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the region of the excerpt the model is allowed to rewrite.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Consecutive edits to the same buffer that arrive within this window are
/// merged into a single history event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Key-value-store key holding the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// --- Token budgets -----------------------------------------------------------

const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget used when rendering the event history into the prompt.
const MAX_EVENT_TOKENS: usize = 500;

// --- Bookkeeping limits ------------------------------------------------------

/// Upper bound on buffered history events.
const MAX_EVENT_COUNT: usize = 16;

/// Upper bound on tracked recent files.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum spacing between two recent-file entries.
const MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES: Duration = Duration::from_millis(100);

/// Longest file path that is still listed among recent files.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// JSON byte budget for diagnostics included as additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// How many shown predictions are retained for rating/feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
89
// Declares the `ClearHistory` action under the `edit_prediction` action namespace.
// NOTE(review): its handler is registered outside this chunk — presumably it calls
// `Zeta::clear_history`; confirm.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
97
98#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
99pub struct EditPredictionId(Uuid);
100
101impl From<EditPredictionId> for gpui::ElementId {
102 fn from(value: EditPredictionId) -> Self {
103 gpui::ElementId::Uuid(value.0)
104 }
105}
106
107impl std::fmt::Display for EditPredictionId {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 write!(f, "{}", self.0)
110 }
111}
112
113struct ZedPredictUpsell;
114
115impl Dismissable for ZedPredictUpsell {
116 const KEY: &'static str = "dismissed-edit-predict-upsell";
117
118 fn dismissed() -> bool {
119 // To make this backwards compatible with older versions of Zed, we
120 // check if the user has seen the previous Edit Prediction Onboarding
121 // before, by checking the data collection choice which was written to
122 // the database once the user clicked on "Accept and Enable"
123 if KEY_VALUE_STORE
124 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
125 .log_err()
126 .is_some_and(|s| s.is_some())
127 {
128 return true;
129 }
130
131 KEY_VALUE_STORE
132 .read_kvp(Self::KEY)
133 .log_err()
134 .is_some_and(|s| s.is_some())
135 }
136}
137
138pub fn should_show_upsell_modal() -> bool {
139 !ZedPredictUpsell::dismissed()
140}
141
/// App-global handle to the single [`Zeta`] entity.
///
/// `Zeta::register` installs it; `Zeta::global` reads it back.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
146
/// One edit prediction, reconciled against the buffer state that existed when
/// the response was processed.
#[derive(Clone)]
pub struct EditPrediction {
    /// Id taken from the prediction response's `request_id`.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region that was sent to the model.
    excerpt_range: Range<usize>,
    /// Cursor offset at the time of the request.
    cursor_offset: usize,
    /// Predicted edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto during response processing.
    snapshot: BufferSnapshot,
    /// Result of `Buffer::preview_edits` for `edits`.
    edit_preview: EditPreview,
    /// Events portion of the request body; reported when the prediction is rated.
    input_events: Arc<str>,
    /// Excerpt portion of the request body; reported when the prediction is rated.
    input_excerpt: Arc<str>,
    /// Raw model output, markers included.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// Set once the response has been parsed and its preview computed (not when
    /// the HTTP response arrived), so `latency()` includes processing time.
    response_received_at: Instant,
}
162
163impl EditPrediction {
164 fn latency(&self) -> Duration {
165 self.response_received_at
166 .duration_since(self.buffer_snapshotted_at)
167 }
168
169 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
170 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
171 }
172}
173
174fn interpolate(
175 old_snapshot: &BufferSnapshot,
176 new_snapshot: &BufferSnapshot,
177 current_edits: Arc<[(Range<Anchor>, String)]>,
178) -> Option<Vec<(Range<Anchor>, String)>> {
179 let mut edits = Vec::new();
180
181 let mut model_edits = current_edits.iter().peekable();
182 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
183 while let Some((model_old_range, _)) = model_edits.peek() {
184 let model_old_range = model_old_range.to_offset(old_snapshot);
185 if model_old_range.end < user_edit.old.start {
186 let (model_old_range, model_new_text) = model_edits.next().unwrap();
187 edits.push((model_old_range.clone(), model_new_text.clone()));
188 } else {
189 break;
190 }
191 }
192
193 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
194 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
195 if user_edit.old == model_old_offset_range {
196 let user_new_text = new_snapshot
197 .text_for_range(user_edit.new.clone())
198 .collect::<String>();
199
200 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
201 if !model_suffix.is_empty() {
202 let anchor = old_snapshot.anchor_after(user_edit.old.end);
203 edits.push((anchor..anchor, model_suffix.to_string()));
204 }
205
206 model_edits.next();
207 continue;
208 }
209 }
210 }
211
212 return None;
213 }
214
215 edits.extend(model_edits.cloned());
216
217 if edits.is_empty() { None } else { Some(edits) }
218}
219
220impl std::fmt::Debug for EditPrediction {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 f.debug_struct("EditPrediction")
223 .field("id", &self.id)
224 .field("path", &self.path)
225 .field("edits", &self.edits)
226 .finish_non_exhaustive()
227 }
228}
229
/// Edit-prediction engine: tracks buffer edits, requests predictions, and keeps
/// shown/rated predictions for feedback.
pub struct Zeta {
    /// Workspace this instance was created for, or an invalid handle when none was given.
    workspace: WeakEntity<Workspace>,
    /// Client used for HTTP requests and telemetry.
    client: Arc<Client>,
    /// Recent buffer-change events (oldest first), capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Last reported snapshot plus subscriptions for each registered buffer, keyed by entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions shown to the user, newest first, capped at `MAX_SHOWN_COMPLETION_COUNT`.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of rated predictions; an id is removed when its prediction is evicted
    /// from `shown_completions`.
    rated_completions: HashSet<EditPredictionId>,
    /// Data-collection choice, loaded from the key-value store at construction.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM API token, refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Source and sink of edit-prediction usage information.
    user_store: Entity<UserStore>,
    /// One license watcher per worktree passed to `Zeta::register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently active editors; pre-allocated for `MAX_RECENT_PROJECT_ENTRIES_COUNT`.
    /// NOTE(review): populated outside this chunk — confirm the eviction policy.
    recent_editors: VecDeque<RecentEditor>,
}

/// An editor together with the last time it was active.
/// NOTE(review): only constructed outside this chunk — confirm semantics there.
struct RecentEditor {
    editor: WeakEntity<Editor>,
    last_active_at: Instant,
}
251
252impl Zeta {
253 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
254 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
255 }
256
257 pub fn register(
258 workspace: Option<Entity<Workspace>>,
259 worktree: Option<Entity<Worktree>>,
260 client: Arc<Client>,
261 user_store: Entity<UserStore>,
262 cx: &mut App,
263 ) -> Entity<Self> {
264 let this = Self::global(cx).unwrap_or_else(|| {
265 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
266 cx.set_global(ZetaGlobal(entity.clone()));
267 entity
268 });
269
270 this.update(cx, move |this, cx| {
271 if let Some(worktree) = worktree {
272 let worktree_id = worktree.read(cx).id();
273 this.license_detection_watchers
274 .entry(worktree_id)
275 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
276 }
277 });
278
279 this
280 }
281
282 pub fn clear_history(&mut self) {
283 self.events.clear();
284 }
285
286 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
287 self.user_store.read(cx).edit_prediction_usage()
288 }
289
290 fn new(
291 workspace: Option<Entity<Workspace>>,
292 client: Arc<Client>,
293 user_store: Entity<UserStore>,
294 cx: &mut Context<Self>,
295 ) -> Self {
296 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
297
298 let data_collection_choice = Self::load_data_collection_choices();
299 let data_collection_choice = cx.new(|_| data_collection_choice);
300
301 if let Some(workspace) = &workspace {
302 cx.subscribe(workspace, |this, _workspace, event, cx| match event {
303 workspace::Event::ActiveItemChanged => {
304 this.handle_active_workspace_item_changed(cx)
305 }
306 _ => {}
307 })
308 .detach();
309 }
310
311 Self {
312 workspace: workspace.map_or_else(
313 || WeakEntity::new_invalid(),
314 |workspace| workspace.downgrade(),
315 ),
316 client,
317 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
318 shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
319 rated_completions: HashSet::default(),
320 registered_buffers: HashMap::default(),
321 data_collection_choice,
322 llm_token: LlmApiToken::default(),
323 _llm_token_subscription: cx.subscribe(
324 &refresh_llm_token_listener,
325 |this, _listener, _event, cx| {
326 let client = this.client.clone();
327 let llm_token = this.llm_token.clone();
328 cx.spawn(async move |_this, _cx| {
329 llm_token.refresh(&client).await?;
330 anyhow::Ok(())
331 })
332 .detach_and_log_err(cx);
333 },
334 ),
335 update_required: false,
336 license_detection_watchers: HashMap::default(),
337 user_store,
338 recent_editors: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
339 }
340 }
341
342 fn push_event(&mut self, event: Event) {
343 if let Some(Event::BufferChange {
344 new_snapshot: last_new_snapshot,
345 timestamp: last_timestamp,
346 ..
347 }) = self.events.back_mut()
348 {
349 // Coalesce edits for the same buffer when they happen one after the other.
350 let Event::BufferChange {
351 old_snapshot,
352 new_snapshot,
353 timestamp,
354 } = &event;
355
356 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
357 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
358 && old_snapshot.version == last_new_snapshot.version
359 {
360 *last_new_snapshot = new_snapshot.clone();
361 *last_timestamp = *timestamp;
362 return;
363 }
364 }
365
366 if self.events.len() >= MAX_EVENT_COUNT {
367 // These are halved instead of popping to improve prompt caching.
368 self.events.drain(..MAX_EVENT_COUNT / 2);
369 }
370
371 self.events.push_back(event);
372 }
373
374 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
375 let buffer_id = buffer.entity_id();
376 let weak_buffer = buffer.downgrade();
377
378 if let std::collections::hash_map::Entry::Vacant(entry) =
379 self.registered_buffers.entry(buffer_id)
380 {
381 let snapshot = buffer.read(cx).snapshot();
382
383 entry.insert(RegisteredBuffer {
384 snapshot,
385 _subscriptions: [
386 cx.subscribe(buffer, move |this, buffer, event, cx| {
387 this.handle_buffer_event(buffer, event, cx);
388 }),
389 cx.observe_release(buffer, move |this, _buffer, _cx| {
390 this.registered_buffers.remove(&weak_buffer.entity_id());
391 }),
392 ],
393 });
394 };
395 }
396
397 fn handle_buffer_event(
398 &mut self,
399 buffer: Entity<Buffer>,
400 event: &language::BufferEvent,
401 cx: &mut Context<Self>,
402 ) {
403 if let language::BufferEvent::Edited = event {
404 self.report_changes_for_buffer(&buffer, cx);
405 }
406 }
407
    /// Shared implementation behind `request_completion` and `fake_completion`.
    ///
    /// The steps are:
    /// 1. Record any pending edits to `buffer` as a history event.
    /// 2. Gather the prompt context (events plus the excerpt around `cursor`) on
    ///    a background task.
    /// 3. Send it through `perform_predict_edits`.
    /// 4. Turn the response into an [`EditPrediction`] via
    ///    `process_completion_response`.
    ///
    /// Parameters:
    /// - `workspace`: used only to show an "Update Zed" notification when the
    ///   request fails with [`ZedUpdateRequiredError`].
    /// - `project`: forwarded to `gather_additional_context`.
    /// - `can_collect_data`: forwarded to `gather_context`. When it is
    ///   `CanCollectData(true)` and the buffer has a file, additional context is
    ///   also gathered concurrently. That context is only reported via
    ///   telemetry after the response; it is not part of the request.
    /// - `perform_predict_edits`: performs the request. `fake_completion`
    ///   substitutes a canned response.
    ///
    /// The task resolves to `Ok(None)` when the parsed edits cannot be
    /// interpolated onto the buffer's current contents. It resolves to `Err`
    /// when context gathering, the request, or response parsing fails.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency window used for telemetry and `EditPrediction::latency`.
        let buffer_snapshotted_at = Instant::now();
        // Also records any not-yet-reported edits to `buffer` as an event, so they
        // are part of the `self.events` clone taken just below.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Handed to `gather_context`, which runs on the background executor;
        // presumably the events prompt is rendered there. Confirm in `gather_context`.
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = cx.background_spawn(gather_context(
            full_path_str,
            snapshot.clone(),
            cursor_point,
            make_events_prompt,
            can_collect_data,
        ));

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
                && let Some(file) = snapshot.file()
                && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
            {
                // This is async to reduce latency of the edit predictions request. The downside is
                // that it will see a slightly later state than was used when gathering context.
                let snapshot = snapshot.clone();
                let this = this.clone();
                Some(cx.spawn(async move |cx| {
                    if let Ok(Some(task)) = this.update(cx, |this, cx| {
                        this.gather_additional_context(
                            cursor_point,
                            cursor_offset,
                            snapshot,
                            &buffer_snapshotted_at,
                            project_path,
                            project.as_ref(),
                            cx,
                        )
                    }) {
                        Some(task.await)
                    } else {
                        None
                    }
                }))
            } else {
                None
            };

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Kept so they can be stored on the resulting `EditPrediction`.
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // Remember that an update is required. If a workspace is
                    // available, also show an "Update Zed" prompt. The error is
                    // still returned to the caller either way.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let request_id = response.request_id.clone();
            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests
            // (`<= 2` matches 3 of the 256 possible `u8` values, i.e. ~1.2%).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            // Additional context is only reported via telemetry, keyed by the
            // request id, once it is ready. It never delays the prediction.
            if let Some(additional_context_task) = additional_context_task {
                cx.background_spawn(async move {
                    if let Some(additional_context) = additional_context_task.await {
                        telemetry::event!(
                            "Edit Prediction Additional Context",
                            request_id = request_id,
                            additional_context = additional_context
                        );
                    }
                })
                .detach();
            }

            edit_prediction
        })
    }
594
    // Generates several example completions of various states to fill the Zeta completion modal.
    //
    // Each fake completion runs through `fake_completion`, i.e. the real
    // processing path with a canned response. Afterwards, the edits of entries 2
    // and 3 in `shown_completions` are cleared to simulate empty predictions.
    //
    // NOTE(review): neither `request_completion_impl` nor
    // `process_completion_response` pushes onto `shown_completions` in this
    // chunk. The `get_mut(2).unwrap()` / `get_mut(3).unwrap()` below will panic
    // unless something else records the completions before this runs. Confirm
    // with the caller.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    "),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                    "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Test-support only: any failure here is a bug in the fixtures above.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Blank out two entries so the modal also shows predictions without edits.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
746
747 #[cfg(any(test, feature = "test-support"))]
748 pub fn fake_completion(
749 &mut self,
750 project: Option<Entity<Project>>,
751 buffer: &Entity<Buffer>,
752 position: language::Anchor,
753 response: PredictEditsResponse,
754 cx: &mut Context<Self>,
755 ) -> Task<Result<Option<EditPrediction>>> {
756 use std::future::ready;
757
758 self.request_completion_impl(
759 None,
760 project,
761 buffer,
762 position,
763 CanCollectData(false),
764 cx,
765 |_params| ready(Ok((response, None))),
766 )
767 }
768
769 pub fn request_completion(
770 &mut self,
771 project: Option<Entity<Project>>,
772 buffer: &Entity<Buffer>,
773 position: language::Anchor,
774 can_collect_data: CanCollectData,
775 cx: &mut Context<Self>,
776 ) -> Task<Result<Option<EditPrediction>>> {
777 self.request_completion_impl(
778 self.workspace.upgrade(),
779 project,
780 buffer,
781 position,
782 can_collect_data,
783 cx,
784 Self::perform_predict_edits,
785 )
786 }
787
788 pub fn perform_predict_edits(
789 params: PerformPredictEditsParams,
790 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
791 async move {
792 let PerformPredictEditsParams {
793 client,
794 llm_token,
795 app_version,
796 body,
797 ..
798 } = params;
799
800 let http_client = client.http_client();
801 let mut token = llm_token.acquire(&client).await?;
802 let mut did_retry = false;
803
804 loop {
805 let request_builder = http_client::Request::builder().method(Method::POST);
806 let request_builder =
807 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
808 request_builder.uri(predict_edits_url)
809 } else {
810 request_builder.uri(
811 http_client
812 .build_zed_llm_url("/predict_edits/v2", &[])?
813 .as_ref(),
814 )
815 };
816 let request = request_builder
817 .header("Content-Type", "application/json")
818 .header("Authorization", format!("Bearer {}", token))
819 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
820 .body(serde_json::to_string(&body)?.into())?;
821
822 let mut response = http_client.send(request).await?;
823
824 if let Some(minimum_required_version) = response
825 .headers()
826 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
827 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
828 {
829 anyhow::ensure!(
830 app_version >= minimum_required_version,
831 ZedUpdateRequiredError {
832 minimum_version: minimum_required_version
833 }
834 );
835 }
836
837 if response.status().is_success() {
838 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
839
840 let mut body = String::new();
841 response.body_mut().read_to_string(&mut body).await?;
842 return Ok((serde_json::from_str(&body)?, usage));
843 } else if !did_retry
844 && response
845 .headers()
846 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
847 .is_some()
848 {
849 did_retry = true;
850 token = llm_token.refresh(&client).await?;
851 } else {
852 let mut body = String::new();
853 response.body_mut().read_to_string(&mut body).await?;
854 anyhow::bail!(
855 "error predicting edits.\nStatus: {:?}\nBody: {}",
856 response.status(),
857 body
858 );
859 }
860 }
861 }
862 }
863
864 fn accept_edit_prediction(
865 &mut self,
866 request_id: EditPredictionId,
867 cx: &mut Context<Self>,
868 ) -> Task<Result<()>> {
869 let client = self.client.clone();
870 let llm_token = self.llm_token.clone();
871 let app_version = AppVersion::global(cx);
872 cx.spawn(async move |this, cx| {
873 let http_client = client.http_client();
874 let mut response = llm_token_retry(&llm_token, &client, |token| {
875 let request_builder = http_client::Request::builder().method(Method::POST);
876 let request_builder =
877 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
878 request_builder.uri(accept_prediction_url)
879 } else {
880 request_builder.uri(
881 http_client
882 .build_zed_llm_url("/predict_edits/accept", &[])?
883 .as_ref(),
884 )
885 };
886 Ok(request_builder
887 .header("Content-Type", "application/json")
888 .header("Authorization", format!("Bearer {}", token))
889 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
890 .body(
891 serde_json::to_string(&AcceptEditPredictionBody {
892 request_id: request_id.0,
893 })?
894 .into(),
895 )?)
896 })
897 .await?;
898
899 if let Some(minimum_required_version) = response
900 .headers()
901 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
902 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
903 && app_version < minimum_required_version
904 {
905 return Err(anyhow!(ZedUpdateRequiredError {
906 minimum_version: minimum_required_version
907 }));
908 }
909
910 if response.status().is_success() {
911 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
912 this.update(cx, |this, cx| {
913 this.user_store.update(cx, |user_store, cx| {
914 user_store.update_edit_prediction_usage(usage, cx);
915 });
916 })?;
917 }
918
919 Ok(())
920 } else {
921 let mut body = String::new();
922 response.body_mut().read_to_string(&mut body).await?;
923 Err(anyhow!(
924 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
925 response.status(),
926 body
927 ))
928 }
929 })
930 }
931
932 fn process_completion_response(
933 prediction_response: PredictEditsResponse,
934 buffer: Entity<Buffer>,
935 snapshot: &BufferSnapshot,
936 editable_range: Range<usize>,
937 cursor_offset: usize,
938 path: Arc<Path>,
939 input_events: String,
940 input_excerpt: String,
941 buffer_snapshotted_at: Instant,
942 cx: &AsyncApp,
943 ) -> Task<Result<Option<EditPrediction>>> {
944 let snapshot = snapshot.clone();
945 let request_id = prediction_response.request_id;
946 let output_excerpt = prediction_response.output_excerpt;
947 cx.spawn(async move |cx| {
948 let output_excerpt: Arc<str> = output_excerpt.into();
949
950 let edits: Arc<[(Range<Anchor>, String)]> = cx
951 .background_spawn({
952 let output_excerpt = output_excerpt.clone();
953 let editable_range = editable_range.clone();
954 let snapshot = snapshot.clone();
955 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
956 })
957 .await?
958 .into();
959
960 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
961 let edits = edits.clone();
962 |buffer, cx| {
963 let new_snapshot = buffer.snapshot();
964 let edits: Arc<[(Range<Anchor>, String)]> =
965 interpolate(&snapshot, &new_snapshot, edits)?.into();
966 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
967 }
968 })?
969 else {
970 return anyhow::Ok(None);
971 };
972
973 let edit_preview = edit_preview.await;
974
975 Ok(Some(EditPrediction {
976 id: EditPredictionId(request_id),
977 path,
978 excerpt_range: editable_range,
979 cursor_offset,
980 edits,
981 edit_preview,
982 snapshot,
983 input_events: input_events.into(),
984 input_excerpt: input_excerpt.into(),
985 output_excerpt,
986 buffer_snapshotted_at,
987 response_received_at: Instant::now(),
988 }))
989 })
990 }
991
992 fn parse_edits(
993 output_excerpt: Arc<str>,
994 editable_range: Range<usize>,
995 snapshot: &BufferSnapshot,
996 ) -> Result<Vec<(Range<Anchor>, String)>> {
997 let content = output_excerpt.replace(CURSOR_MARKER, "");
998
999 let start_markers = content
1000 .match_indices(EDITABLE_REGION_START_MARKER)
1001 .collect::<Vec<_>>();
1002 anyhow::ensure!(
1003 start_markers.len() == 1,
1004 "expected exactly one start marker, found {}",
1005 start_markers.len()
1006 );
1007
1008 let end_markers = content
1009 .match_indices(EDITABLE_REGION_END_MARKER)
1010 .collect::<Vec<_>>();
1011 anyhow::ensure!(
1012 end_markers.len() == 1,
1013 "expected exactly one end marker, found {}",
1014 end_markers.len()
1015 );
1016
1017 let sof_markers = content
1018 .match_indices(START_OF_FILE_MARKER)
1019 .collect::<Vec<_>>();
1020 anyhow::ensure!(
1021 sof_markers.len() <= 1,
1022 "expected at most one start-of-file marker, found {}",
1023 sof_markers.len()
1024 );
1025
1026 let codefence_start = start_markers[0].0;
1027 let content = &content[codefence_start..];
1028
1029 let newline_ix = content.find('\n').context("could not find newline")?;
1030 let content = &content[newline_ix + 1..];
1031
1032 let codefence_end = content
1033 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1034 .context("could not find end marker")?;
1035 let new_text = &content[..codefence_end];
1036
1037 let old_text = snapshot
1038 .text_for_range(editable_range.clone())
1039 .collect::<String>();
1040
1041 Ok(Self::compute_edits(
1042 old_text,
1043 new_text,
1044 editable_range.start,
1045 snapshot,
1046 ))
1047 }
1048
1049 pub fn compute_edits(
1050 old_text: String,
1051 new_text: &str,
1052 offset: usize,
1053 snapshot: &BufferSnapshot,
1054 ) -> Vec<(Range<Anchor>, String)> {
1055 text_diff(&old_text, new_text)
1056 .into_iter()
1057 .map(|(mut old_range, new_text)| {
1058 old_range.start += offset;
1059 old_range.end += offset;
1060
1061 let prefix_len = common_prefix(
1062 snapshot.chars_for_range(old_range.clone()),
1063 new_text.chars(),
1064 );
1065 old_range.start += prefix_len;
1066
1067 let suffix_len = common_prefix(
1068 snapshot.reversed_chars_for_range(old_range.clone()),
1069 new_text[prefix_len..].chars().rev(),
1070 );
1071 old_range.end = old_range.end.saturating_sub(suffix_len);
1072
1073 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1074 let range = if old_range.is_empty() {
1075 let anchor = snapshot.anchor_after(old_range.start);
1076 anchor..anchor
1077 } else {
1078 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1079 };
1080 (range, new_text)
1081 })
1082 .collect()
1083 }
1084
1085 pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1086 self.rated_completions.contains(&completion_id)
1087 }
1088
1089 pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1090 if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1091 let completion = self.shown_completions.pop_back().unwrap();
1092 self.rated_completions.remove(&completion.id);
1093 }
1094 self.shown_completions.push_front(completion.clone());
1095 cx.notify();
1096 }
1097
1098 pub fn rate_completion(
1099 &mut self,
1100 completion: &EditPrediction,
1101 rating: EditPredictionRating,
1102 feedback: String,
1103 cx: &mut Context<Self>,
1104 ) {
1105 self.rated_completions.insert(completion.id);
1106 telemetry::event!(
1107 "Edit Prediction Rated",
1108 rating,
1109 input_events = completion.input_events,
1110 input_excerpt = completion.input_excerpt,
1111 output_excerpt = completion.output_excerpt,
1112 feedback
1113 );
1114 self.client.telemetry().flush_events().detach();
1115 cx.notify();
1116 }
1117
1118 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1119 self.shown_completions.iter()
1120 }
1121
1122 pub fn shown_completions_len(&self) -> usize {
1123 self.shown_completions.len()
1124 }
1125
1126 fn report_changes_for_buffer(
1127 &mut self,
1128 buffer: &Entity<Buffer>,
1129 cx: &mut Context<Self>,
1130 ) -> BufferSnapshot {
1131 self.register_buffer(buffer, cx);
1132
1133 let registered_buffer = self
1134 .registered_buffers
1135 .get_mut(&buffer.entity_id())
1136 .unwrap();
1137 let new_snapshot = buffer.read(cx).snapshot();
1138
1139 if new_snapshot.version != registered_buffer.snapshot.version {
1140 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1141 self.push_event(Event::BufferChange {
1142 old_snapshot,
1143 new_snapshot: new_snapshot.clone(),
1144 timestamp: Instant::now(),
1145 });
1146 }
1147
1148 new_snapshot
1149 }
1150
1151 fn load_data_collection_choices() -> DataCollectionChoice {
1152 let choice = KEY_VALUE_STORE
1153 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1154 .log_err()
1155 .flatten();
1156
1157 match choice.as_deref() {
1158 Some("true") => DataCollectionChoice::Enabled,
1159 Some("false") => DataCollectionChoice::Disabled,
1160 Some(_) => {
1161 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1162 DataCollectionChoice::NotAnswered
1163 }
1164 None => DataCollectionChoice::NotAnswered,
1165 }
1166 }
1167
1168 fn gather_additional_context(
1169 &mut self,
1170 cursor_point: language::Point,
1171 cursor_offset: usize,
1172 snapshot: BufferSnapshot,
1173 buffer_snapshotted_at: &Instant,
1174 project_path: ProjectPath,
1175 project: Option<&Entity<Project>>,
1176 cx: &mut Context<Self>,
1177 ) -> Option<Task<PredictEditsAdditionalContext>> {
1178 let project = project?.read(cx);
1179 let entry = project.entry_for_path(&project_path, cx)?;
1180 if !worktree_entry_eligible_for_collection(&entry) {
1181 return None;
1182 }
1183
1184 let git_store = project.git_store().read(cx);
1185 let (repository, repo_path) =
1186 git_store.repository_and_path_for_project_path(&project_path, cx)?;
1187 let repo_path_string = repo_path.to_str()?.to_string();
1188
1189 let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
1190 snapshot
1191 .diagnostics
1192 .iter()
1193 .filter_map(|(language_server_id, diagnostics)| {
1194 let language_server =
1195 local_lsp_store.running_language_server_for_id(*language_server_id)?;
1196 Some((
1197 *language_server_id,
1198 language_server.name(),
1199 diagnostics.clone(),
1200 ))
1201 })
1202 .collect()
1203 } else {
1204 Vec::new()
1205 };
1206
1207 repository.update(cx, |repository, cx| {
1208 let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1209 let remote_origin_url = repository.remote_origin_url.clone();
1210 let remote_upstream_url = repository.remote_upstream_url.clone();
1211 let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1212
1213 // group, resolve, and select diagnostics on a background thread
1214 Some(cx.background_spawn(async move {
1215 let mut diagnostic_groups_with_name = Vec::new();
1216 for (language_server_id, language_server_name, diagnostics) in
1217 diagnostics.into_iter()
1218 {
1219 let mut groups = Vec::new();
1220 diagnostics.groups(language_server_id, &mut groups, &snapshot);
1221 diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
1222 (
1223 language_server_name.clone(),
1224 group.resolve::<usize>(&snapshot),
1225 )
1226 }));
1227 }
1228
1229 // sort by proximity to cursor
1230 diagnostic_groups_with_name.sort_by_key(|(_, group)| {
1231 let range = &group.entries[group.primary_ix].range;
1232 if range.start >= cursor_offset {
1233 range.start - cursor_offset
1234 } else if cursor_offset >= range.end {
1235 cursor_offset - range.end
1236 } else {
1237 (cursor_offset - range.start).min(range.end - cursor_offset)
1238 }
1239 });
1240
1241 let mut diagnostic_groups = Vec::new();
1242 let mut diagnostic_groups_truncated = false;
1243 let mut diagnostics_byte_count = 0;
1244 for (name, group) in diagnostic_groups_with_name {
1245 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1246 diagnostics_byte_count += name.0.len() + raw_value.get().len();
1247 if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
1248 diagnostic_groups_truncated = true;
1249 break;
1250 }
1251 diagnostic_groups.push((name.to_string(), raw_value));
1252 }
1253
1254 PredictEditsAdditionalContext {
1255 input_path: repo_path_string,
1256 cursor_point: to_cloud_llm_client_point(cursor_point),
1257 cursor_offset: cursor_offset,
1258 git_info: PredictEditsGitInfo {
1259 head_sha: Some(head_sha),
1260 remote_origin_url,
1261 remote_upstream_url,
1262 },
1263 diagnostic_groups,
1264 diagnostic_groups_truncated,
1265 recent_files,
1266 }
1267 }))
1268 })
1269 }
1270
1271 fn handle_active_workspace_item_changed(&mut self, cx: &Context<Self>) {
1272 if let Some(active_editor) = self
1273 .workspace
1274 .read_with(cx, |workspace, cx| {
1275 workspace
1276 .active_item(cx)
1277 .and_then(|item| item.act_as::<Editor>(cx))
1278 })
1279 .ok()
1280 .flatten()
1281 {
1282 let now = Instant::now();
1283 let new_recent = RecentEditor {
1284 editor: active_editor.downgrade(),
1285 last_active_at: now,
1286 };
1287 if let Some(existing_ix) = self
1288 .recent_editors
1289 .iter()
1290 .rposition(|recent| &recent.editor == &new_recent.editor)
1291 {
1292 self.recent_editors.remove(existing_ix);
1293 }
1294 // filter out rapid changes in active item, particularly since this can happen rapidly when
1295 // a workspace is loaded.
1296 if let Some(previous_recent) = self.recent_editors.back_mut()
1297 && now.duration_since(previous_recent.last_active_at)
1298 < MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES
1299 {
1300 *previous_recent = new_recent;
1301 return;
1302 }
1303 if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1304 self.recent_editors.pop_front();
1305 }
1306 self.recent_editors.push_back(new_recent);
1307 }
1308 }
1309
    /// Describes the recently active editors whose files live in `repository`.
    ///
    /// Walks `recent_editors` from newest (back) to oldest (front), so the returned list is
    /// newest first. Each entry reports the repo-relative path, the editor's cursor point,
    /// and the milliseconds elapsed between its last activation and `now`.
    ///
    /// Pruning: an entry is removed from `recent_editors` when its editor can no longer be
    /// updated, or when it has no cursor buffer, file, or worktree entry. It is also removed
    /// when the entry is not eligible for collection, the repo path is not UTF-8 or exceeds
    /// `MAX_RECENT_FILE_PATH_LENGTH`, or the elapsed milliseconds do not convert into the
    /// field's integer type. Entries whose file is simply outside `repository` are kept but
    /// not reported. Returns an empty list if the workspace cannot be read.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &mut App,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::with_capacity(self.recent_editors.len());
        // Iterate indices in reverse so removing `ix` never shifts an unvisited entry.
        for ix in (0..self.recent_editors.len()).rev() {
            let recent_editor = &self.recent_editors[ix];
            // `Some(())` keeps the entry; `None` (including a failed update) drops it.
            let keep_entry = recent_editor
                .editor
                .update(cx, |editor, cx| {
                    maybe!({
                        let (buffer, cursor_point, _) = editor.cursor_buffer_point(cx)?;
                        let file = buffer.read(cx).file()?;
                        let project_path = ProjectPath {
                            worktree_id: file.worktree_id(cx),
                            path: file.path().clone(),
                        };
                        let entry = project.read(cx).entry_for_path(&project_path, cx)?;
                        if !worktree_entry_eligible_for_collection(entry) {
                            return None;
                        }
                        let Some(repo_path) =
                            repository.project_path_to_repo_path(&project_path, cx)
                        else {
                            // entry not removed since later queries may involve other repositories
                            return Some(());
                        };
                        // paths may not be valid UTF-8
                        let repo_path_str = repo_path.to_str()?;
                        if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                            return None;
                        }
                        let active_to_now_ms = now
                            .duration_since(recent_editor.last_active_at)
                            .as_millis()
                            .try_into()
                            .ok()?;
                        results.push(PredictEditsRecentFile {
                            path: repo_path_str.to_string(),
                            cursor_point: to_cloud_llm_client_point(cursor_point),
                            active_to_now_ms,
                        });
                        Some(())
                    })
                })
                .ok()
                .flatten();
            if keep_entry.is_none() {
                self.recent_editors.remove(ix);
            }
        }
        results
    }
1371}
1372
1373fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1374 cloud_llm_client::Point {
1375 row: point.row,
1376 column: point.column,
1377 }
1378}
1379
1380fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1381 entry.is_file()
1382 && entry.is_created()
1383 && !entry.is_ignored
1384 && !entry.is_private
1385 && !entry.is_external
1386 && !entry.is_fifo
1387}
1388
/// Inputs needed to issue a single predict-edits request.
pub struct PerformPredictEditsParams {
    /// Client used for HTTP access.
    pub client: Arc<Client>,
    /// LLM API token used to authenticate the request.
    pub llm_token: LlmApiToken,
    /// Version of the running app; presumably sent to the server for version checks — confirm at the call site.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1395
/// Error signalling that this Zed build is too old to keep using edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version that may continue using edit predictions.
    minimum_version: SemanticVersion,
}
1403
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two char streams.
///
/// Comparison stops at the first mismatching pair or when either stream ends. Feeding
/// reversed iterators yields the common-suffix length instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1410
/// Result of `gather_context`: the request body plus the buffer span the model may rewrite.
pub struct GatherContextOutput {
    /// Request payload built from the excerpt and event history.
    pub body: PredictEditsBody,
    /// Editable region of the excerpt, as byte offsets into the snapshot.
    pub editable_range: Range<usize>,
}
1415
1416pub async fn gather_context(
1417 full_path_str: String,
1418 snapshot: BufferSnapshot,
1419 cursor_point: language::Point,
1420 make_events_prompt: impl FnOnce() -> String + Send + 'static,
1421 can_collect_data: CanCollectData,
1422) -> Result<GatherContextOutput> {
1423 let input_excerpt = excerpt_for_cursor_position(
1424 cursor_point,
1425 &full_path_str,
1426 &snapshot,
1427 MAX_REWRITE_TOKENS,
1428 MAX_CONTEXT_TOKENS,
1429 );
1430 let input_events = make_events_prompt();
1431 let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1432
1433 let body = PredictEditsBody {
1434 input_events,
1435 input_excerpt: input_excerpt.prompt,
1436 can_collect_data: can_collect_data.0,
1437 diagnostic_groups: None,
1438 git_info: None,
1439 };
1440
1441 Ok(GatherContextOutput {
1442 body,
1443 editable_range,
1444 })
1445}
1446
1447fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1448 let mut result = String::new();
1449 for event in events.iter().rev() {
1450 let event_string = event.to_prompt();
1451 let event_tokens = tokens_for_bytes(event_string.len());
1452 if event_tokens > remaining_tokens {
1453 break;
1454 }
1455
1456 if !result.is_empty() {
1457 result.insert_str(0, "\n\n");
1458 }
1459 result.insert_str(0, &event_string);
1460 remaining_tokens -= event_tokens;
1461 }
1462 result
1463}
1464
/// Per-buffer tracking state kept by Zeta.
struct RegisteredBuffer {
    /// Last snapshot recorded for the buffer; compared by version to detect changes.
    snapshot: BufferSnapshot,
    /// Never read; presumably held so the two buffer subscriptions stay alive — confirm in `register_buffer`.
    _subscriptions: [gpui::Subscription; 2],
}
1469
/// A user action recorded as prediction context.
#[derive(Clone)]
pub enum Event {
    /// A buffer moved from `old_snapshot` to `new_snapshot` (an edit and/or a path change).
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1478
1479impl Event {
1480 fn to_prompt(&self) -> String {
1481 match self {
1482 Event::BufferChange {
1483 old_snapshot,
1484 new_snapshot,
1485 ..
1486 } => {
1487 let mut prompt = String::new();
1488
1489 let old_path = old_snapshot
1490 .file()
1491 .map(|f| f.path().as_ref())
1492 .unwrap_or(Path::new("untitled"));
1493 let new_path = new_snapshot
1494 .file()
1495 .map(|f| f.path().as_ref())
1496 .unwrap_or(Path::new("untitled"));
1497 if old_path != new_path {
1498 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1499 }
1500
1501 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1502 if !diff.is_empty() {
1503 write!(
1504 prompt,
1505 "User edited {:?}:\n```diff\n{}\n```",
1506 new_path, diff
1507 )
1508 .unwrap();
1509 }
1510
1511 prompt
1512 }
1513 }
1514 }
1515}
1516
/// The prediction currently offered by the provider, tagged with its buffer.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1522
1523impl CurrentEditPrediction {
1524 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1525 if self.buffer_id != old_completion.buffer_id {
1526 return true;
1527 }
1528
1529 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1530 return true;
1531 };
1532 let Some(new_edits) = self.completion.interpolate(snapshot) else {
1533 return false;
1534 };
1535
1536 if old_edits.len() == 1 && new_edits.len() == 1 {
1537 let (old_range, old_text) = &old_edits[0];
1538 let (new_range, new_text) = &new_edits[0];
1539 new_range == old_range && new_text.starts_with(old_text)
1540 } else {
1541 true
1542 }
1543 }
1544}
1545
/// An in-flight prediction request owned by the provider.
struct PendingCompletion {
    /// Sequence number assigned by the provider when the request was started.
    id: usize,
    /// Never read; presumably held so the request runs until this entry is dropped (gpui `Task` semantics) — confirm.
    _task: Task<()>,
}
1550
/// The user's recorded answer about sharing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only for an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// `true` once the user has opted either in or out.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the flipped choice; an unanswered choice flips to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1590
/// Data-collection state for the buffer a provider serves.
///
/// Both fields are set together or both left `None` (see `ProviderDataCollection::new`).
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License watcher for the buffer's worktree; used to check whether the project is open source.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1596
/// Whether data from a prediction request may be collected (opted in and open-source project).
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1599
1600impl ProviderDataCollection {
1601 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1602 let choice_and_watcher = buffer.and_then(|buffer| {
1603 let file = buffer.read(cx).file()?;
1604
1605 if !file.is_local() || file.is_private() {
1606 return None;
1607 }
1608
1609 let zeta = zeta.read(cx);
1610 let choice = zeta.data_collection_choice.clone();
1611
1612 let license_detection_watcher = zeta
1613 .license_detection_watchers
1614 .get(&file.worktree_id(cx))
1615 .cloned()?;
1616
1617 Some((choice, license_detection_watcher))
1618 });
1619
1620 if let Some((choice, watcher)) = choice_and_watcher {
1621 ProviderDataCollection {
1622 choice: Some(choice),
1623 license_detection_watcher: Some(watcher),
1624 }
1625 } else {
1626 ProviderDataCollection {
1627 choice: None,
1628 license_detection_watcher: None,
1629 }
1630 }
1631 }
1632
1633 pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1634 CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1635 }
1636
1637 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1638 self.choice
1639 .as_ref()
1640 .is_some_and(|choice| choice.read(cx).is_enabled())
1641 }
1642
1643 fn is_project_open_source(&self) -> bool {
1644 self.license_detection_watcher
1645 .as_ref()
1646 .is_some_and(|watcher| watcher.is_project_open_source())
1647 }
1648
1649 pub fn toggle(&mut self, cx: &mut App) {
1650 if let Some(choice) = self.choice.as_mut() {
1651 let new_choice = choice.update(cx, |choice, _cx| {
1652 let new_choice = choice.toggle();
1653 *choice = new_choice;
1654 new_choice
1655 });
1656
1657 db::write_and_log(cx, move || {
1658 KEY_VALUE_STORE.write_kvp(
1659 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1660 new_choice.is_enabled().to_string(),
1661 )
1662 });
1663 }
1664 }
1665}
1666
1667async fn llm_token_retry(
1668 llm_token: &LlmApiToken,
1669 client: &Arc<Client>,
1670 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1671) -> Result<Response<AsyncBody>> {
1672 let mut did_retry = false;
1673 let http_client = client.http_client();
1674 let mut token = llm_token.acquire(client).await?;
1675 loop {
1676 let request = build_request(token.clone())?;
1677 let response = http_client.send(request).await?;
1678
1679 if !did_retry
1680 && !response.status().is_success()
1681 && response
1682 .headers()
1683 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1684 .is_some()
1685 {
1686 did_retry = true;
1687 token = llm_token.refresh(client).await?;
1688 continue;
1689 }
1690
1691 return Ok(response);
1692 }
1693}
1694
/// Edit-prediction provider backed by Zeta.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests, oldest first; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next `PendingCompletion`.
    next_pending_completion_id: usize,
    /// Prediction currently offered to the editor, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider's buffer. The field itself is always
    /// present; its inner choice is `None` when collection is impossible there.
    provider_data_collection: ProviderDataCollection,
    /// When the last request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
1704
1705impl ZetaEditPredictionProvider {
1706 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1707
1708 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1709 Self {
1710 zeta,
1711 pending_completions: ArrayVec::new(),
1712 next_pending_completion_id: 0,
1713 current_completion: None,
1714 provider_data_collection,
1715 last_request_timestamp: Instant::now(),
1716 }
1717 }
1718}
1719
// Adapter exposing Zeta predictions through the editor's `EditPredictionProvider` trait.
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    // Reports the opt-in state together with whether the project is open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    // Enabled for every buffer and position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    // Refreshing while any request is still pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    // Starts a new prediction request for `position` unless one is not needed.
    //
    // Skips when an app update is required, when the account is too young or has overdue
    // invoices, or when the current prediction still interpolates onto the buffer.
    // `_debounce` is ignored; instead, each request waits until THROTTLE_TIMEOUT has passed
    // since the previous one.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // The current prediction still applies to this buffer, so keep showing it.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out whatever remains of THROTTLE_TIMEOUT since the last request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Failure or no prediction: only clean up the pending queue.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // Drop only this request if it is the oldest; otherwise clear the whole
                    // queue, including the older request still ahead of this one.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same pending-queue cleanup as the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Adopt the new prediction if there is none yet, or if it should replace
                // the current one; either way it is recorded as shown.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    // Reports acceptance of the current prediction (if any) to Zeta and drops pending
    // requests. The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    // Drops pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    // Returns the group of edits nearest the cursor from the current prediction.
    //
    // The prediction is discarded if it belongs to another buffer or no longer
    // interpolates. Otherwise the edit whose start or end row is closest to the cursor row
    // is chosen. It is extended backwards and forwards over neighbouring edits that start or
    // end within one row of it.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards. The row subtraction assumes edits are ordered and
        // non-overlapping, so earlier edits end at or before the closest edit's start.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards under the same ordering assumption.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1988
/// Estimates how many model tokens a string of `bytes` bytes occupies (rounded down).
fn tokens_for_bytes(bytes: usize) -> usize {
    // Deliberately small bytes-per-token guess. It overestimates token counts, so
    // budgets computed from it err on the side of staying under model limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    let estimated_tokens = bytes.div_euclid(BYTES_PER_TOKEN_GUESS);
    estimated_tokens
}
1995
1996#[cfg(test)]
1997mod tests {
1998 use client::UserStore;
1999 use client::test::FakeServer;
2000 use clock::FakeSystemClock;
2001 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
2002 use gpui::TestAppContext;
2003 use http_client::FakeHttpClient;
2004 use indoc::indoc;
2005 use language::Point;
2006 use settings::SettingsStore;
2007
2008 use super::*;
2009
2010 #[gpui::test]
2011 async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
2012 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
2013 let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
2014 to_completion_edits(
2015 [(2..5, "REM".to_string()), (9..11, "".to_string())],
2016 &buffer,
2017 cx,
2018 )
2019 .into()
2020 });
2021
2022 let edit_preview = cx
2023 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
2024 .await;
2025
2026 let completion = EditPrediction {
2027 edits,
2028 edit_preview,
2029 path: Path::new("").into(),
2030 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
2031 id: EditPredictionId(Uuid::new_v4()),
2032 excerpt_range: 0..0,
2033 cursor_offset: 0,
2034 input_events: "".into(),
2035 input_excerpt: "".into(),
2036 output_excerpt: "".into(),
2037 buffer_snapshotted_at: Instant::now(),
2038 response_received_at: Instant::now(),
2039 };
2040
2041 cx.update(|cx| {
2042 assert_eq!(
2043 from_completion_edits(
2044 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2045 &buffer,
2046 cx
2047 ),
2048 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2049 );
2050
2051 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2052 assert_eq!(
2053 from_completion_edits(
2054 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2055 &buffer,
2056 cx
2057 ),
2058 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
2059 );
2060
2061 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2062 assert_eq!(
2063 from_completion_edits(
2064 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2065 &buffer,
2066 cx
2067 ),
2068 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2069 );
2070
2071 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2072 assert_eq!(
2073 from_completion_edits(
2074 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2075 &buffer,
2076 cx
2077 ),
2078 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2079 );
2080
2081 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2082 assert_eq!(
2083 from_completion_edits(
2084 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2085 &buffer,
2086 cx
2087 ),
2088 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2089 );
2090
2091 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2092 assert_eq!(
2093 from_completion_edits(
2094 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2095 &buffer,
2096 cx
2097 ),
2098 vec![(9..11, "".to_string())]
2099 );
2100
2101 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2102 assert_eq!(
2103 from_completion_edits(
2104 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2105 &buffer,
2106 cx
2107 ),
2108 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2109 );
2110
2111 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2112 assert_eq!(
2113 from_completion_edits(
2114 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2115 &buffer,
2116 cx
2117 ),
2118 vec![(4..4, "M".to_string())]
2119 );
2120
2121 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2122 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2123 })
2124 }
2125
2126 #[gpui::test]
2127 async fn test_clean_up_diff(cx: &mut TestAppContext) {
2128 cx.update(|cx| {
2129 let settings_store = SettingsStore::test(cx);
2130 cx.set_global(settings_store);
2131 client::init_settings(cx);
2132 });
2133
2134 let edits = edits_for_prediction(
2135 indoc! {"
2136 fn main() {
2137 let word_1 = \"lorem\";
2138 let range = word.len()..word.len();
2139 }
2140 "},
2141 indoc! {"
2142 <|editable_region_start|>
2143 fn main() {
2144 let word_1 = \"lorem\";
2145 let range = word_1.len()..word_1.len();
2146 }
2147
2148 <|editable_region_end|>
2149 "},
2150 cx,
2151 )
2152 .await;
2153 assert_eq!(
2154 edits,
2155 [
2156 (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2157 (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2158 ]
2159 );
2160
2161 let edits = edits_for_prediction(
2162 indoc! {"
2163 fn main() {
2164 let story = \"the quick\"
2165 }
2166 "},
2167 indoc! {"
2168 <|editable_region_start|>
2169 fn main() {
2170 let story = \"the quick brown fox jumps over the lazy dog\";
2171 }
2172
2173 <|editable_region_end|>
2174 "},
2175 cx,
2176 )
2177 .await;
2178 assert_eq!(
2179 edits,
2180 [
2181 (
2182 Point::new(1, 26)..Point::new(1, 26),
2183 " brown fox jumps over the lazy dog".to_string()
2184 ),
2185 (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2186 ]
2187 );
2188 }
2189
2190 #[gpui::test]
2191 async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2192 cx.update(|cx| {
2193 let settings_store = SettingsStore::test(cx);
2194 cx.set_global(settings_store);
2195 client::init_settings(cx);
2196 });
2197
2198 let buffer_content = "lorem\n";
2199 let completion_response = indoc! {"
2200 ```animals.js
2201 <|start_of_file|>
2202 <|editable_region_start|>
2203 lorem
2204 ipsum
2205 <|editable_region_end|>
2206 ```"};
2207
2208 let http_client = FakeHttpClient::create(move |req| async move {
2209 match (req.method(), req.uri().path()) {
2210 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2211 .status(200)
2212 .body(
2213 serde_json::to_string(&CreateLlmTokenResponse {
2214 token: LlmToken("the-llm-token".to_string()),
2215 })
2216 .unwrap()
2217 .into(),
2218 )
2219 .unwrap()),
2220 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2221 .status(200)
2222 .body(
2223 serde_json::to_string(&PredictEditsResponse {
2224 request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2225 .unwrap(),
2226 output_excerpt: completion_response.to_string(),
2227 })
2228 .unwrap()
2229 .into(),
2230 )
2231 .unwrap()),
2232 _ => Ok(http_client::Response::builder()
2233 .status(404)
2234 .body("Not Found".into())
2235 .unwrap()),
2236 }
2237 });
2238
2239 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2240 cx.update(|cx| {
2241 RefreshLlmTokenListener::register(client.clone(), cx);
2242 });
2243 // Construct the fake server to authenticate.
2244 let _server = FakeServer::for_client(42, &client, cx).await;
2245 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2246 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2247
2248 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2249 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2250 let completion_task = zeta.update(cx, |zeta, cx| {
2251 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2252 });
2253
2254 let completion = completion_task.await.unwrap().unwrap();
2255 buffer.update(cx, |buffer, cx| {
2256 buffer.edit(completion.edits.iter().cloned(), None, cx)
2257 });
2258 assert_eq!(
2259 buffer.read_with(cx, |buffer, _| buffer.text()),
2260 "lorem\nipsum"
2261 );
2262 }
2263
2264 async fn edits_for_prediction(
2265 buffer_content: &str,
2266 completion_response: &str,
2267 cx: &mut TestAppContext,
2268 ) -> Vec<(Range<Point>, String)> {
2269 let completion_response = completion_response.to_string();
2270 let http_client = FakeHttpClient::create(move |req| {
2271 let completion = completion_response.clone();
2272 async move {
2273 match (req.method(), req.uri().path()) {
2274 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2275 .status(200)
2276 .body(
2277 serde_json::to_string(&CreateLlmTokenResponse {
2278 token: LlmToken("the-llm-token".to_string()),
2279 })
2280 .unwrap()
2281 .into(),
2282 )
2283 .unwrap()),
2284 (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2285 .status(200)
2286 .body(
2287 serde_json::to_string(&PredictEditsResponse {
2288 request_id: Uuid::new_v4(),
2289 output_excerpt: completion,
2290 })
2291 .unwrap()
2292 .into(),
2293 )
2294 .unwrap()),
2295 _ => Ok(http_client::Response::builder()
2296 .status(404)
2297 .body("Not Found".into())
2298 .unwrap()),
2299 }
2300 }
2301 });
2302
2303 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2304 cx.update(|cx| {
2305 RefreshLlmTokenListener::register(client.clone(), cx);
2306 });
2307 // Construct the fake server to authenticate.
2308 let _server = FakeServer::for_client(42, &client, cx).await;
2309 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2310 let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2311
2312 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2313 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2314 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2315 let completion_task = zeta.update(cx, |zeta, cx| {
2316 zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2317 });
2318
2319 let completion = completion_task.await.unwrap().unwrap();
2320 completion
2321 .edits
2322 .iter()
2323 .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2324 .collect::<Vec<_>>()
2325 }
2326
2327 fn to_completion_edits(
2328 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2329 buffer: &Entity<Buffer>,
2330 cx: &App,
2331 ) -> Vec<(Range<Anchor>, String)> {
2332 let buffer = buffer.read(cx);
2333 iterator
2334 .into_iter()
2335 .map(|(range, text)| {
2336 (
2337 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2338 text,
2339 )
2340 })
2341 .collect()
2342 }
2343
2344 fn from_completion_edits(
2345 editor_edits: &[(Range<Anchor>, String)],
2346 buffer: &Entity<Buffer>,
2347 cx: &App,
2348 ) -> Vec<(Range<usize>, String)> {
2349 let buffer = buffer.read(cx);
2350 editor_edits
2351 .iter()
2352 .map(|(range, text)| {
2353 (
2354 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2355 text.clone(),
2356 )
2357 })
2358 .collect()
2359 }
2360
    // Sets up test logging by calling `zlog::init_test()`.
    // `#[ctor::ctor]` registers this function as a load-time constructor. It therefore runs
    // once when the test binary starts, before any test here executes, instead of each test
    // having to call it.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2365}