1mod completion_diff_element;
2mod persistence;
3mod rate_completion_modal;
4
5pub(crate) use completion_diff_element::*;
6use db::kvp::KEY_VALUE_STORE;
7use inline_completion::DataCollectionState;
8pub use rate_completion_modal::*;
9
10use anyhow::{anyhow, Context as _, Result};
11use arrayvec::ArrayVec;
12use client::{Client, UserStore};
13use collections::{HashMap, HashSet, VecDeque};
14use futures::AsyncReadExt;
15use gpui::{
16 actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
17 WeakEntity,
18};
19use http_client::{HttpClient, Method};
20use language::{
21 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
22 Point, ToOffset, ToPoint,
23};
24use language_models::LlmApiToken;
25use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
26use serde::{Deserialize, Serialize};
27use std::{
28 borrow::Cow,
29 cmp, env,
30 fmt::Write,
31 future::Future,
32 mem,
33 ops::Range,
34 path::{Path, PathBuf},
35 sync::Arc,
36 time::{Duration, Instant},
37};
38use telemetry_events::InlineCompletionRating;
39use util::ResultExt;
40use uuid::Uuid;
41use workspace::{
42 notifications::{simple_message_notification::MessageNotification, NotificationId},
43 Workspace,
44};
45
/// Inserted into the prompt excerpt at the cursor position.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Emitted on its own line when the excerpt starts at offset 0 of the buffer.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the region of the excerpt that the model may rewrite.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the region of the excerpt that the model may rewrite.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer within this interval are coalesced into one event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value-store key recording that the user chose "never ask again" for data collection.
const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &str =
    "zed_predict_data_collection_never_ask_again";
53
// Declares the `ClearHistory` action in the `edit_prediction` namespace (gpui `actions!` macro).
actions!(edit_prediction, [ClearHistory]);
55
56#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
57pub struct InlineCompletionId(Uuid);
58
59impl From<InlineCompletionId> for gpui::ElementId {
60 fn from(value: InlineCompletionId) -> Self {
61 gpui::ElementId::Uuid(value.0)
62 }
63}
64
65impl std::fmt::Display for InlineCompletionId {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 write!(f, "{}", self.0)
68 }
69}
70
71impl InlineCompletionId {
72 fn new() -> Self {
73 Self(Uuid::new_v4())
74 }
75}
76
/// Newtype that stores the `Zeta` entity as a gpui global.
/// It is set by `Zeta::register` and read by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
81
/// A completion returned by the prediction endpoint, together with the inputs that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Unique id of this completion (used for rating and as a UI element id).
    id: InlineCompletionId,
    /// Full path of the buffer's file, or "untitled" when the buffer has no file.
    path: Arc<Path>,
    /// Offset range in `snapshot` that was sent to the model as the editable region.
    excerpt_range: Range<usize>,
    /// Cursor offset in `snapshot` at request time.
    cursor_offset: usize,
    /// Proposed edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the request was computed against.
    snapshot: BufferSnapshot,
    /// Outline section of the prompt that was sent.
    input_outline: Arc<str>,
    /// Recent-edit-events section of the prompt that was sent.
    input_events: Arc<str>,
    /// Excerpt section of the prompt that was sent.
    input_excerpt: Arc<str>,
    /// Raw excerpt text returned by the model.
    output_excerpt: Arc<str>,
    /// Taken just before the prompt was assembled.
    request_sent_at: Instant,
    /// Taken once the response had been parsed into edits.
    response_received_at: Instant,
}
97
impl InlineCompletion {
    /// Time elapsed between sending the request and receiving the parsed response.
    fn latency(&self) -> Duration {
        self.response_received_at
            .duration_since(self.request_sent_at)
    }

    /// Rebases this completion's edits onto `new_snapshot`, accounting for any edits the
    /// user made since the completion was requested (i.e. since `self.snapshot`).
    ///
    /// - A model edit with no overlapping user edit is kept unchanged.
    /// - A user edit that replaced exactly the same range as a model edit is accepted only
    ///   if the user's inserted text is a prefix of the model's text. The remaining suffix,
    ///   if any, is kept as an empty-range edit at the end of the user's text.
    ///
    /// Returns `None` when a user edit diverges from or partially overlaps a model edit,
    /// or when no edits remain.
    ///
    /// NOTE(review): the single forward pass assumes both the model edits and the user
    /// edits are ordered by offset.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
        let mut edits = Vec::new();

        // User edits since the completion's snapshot, with `old`/`new` offset ranges.
        let mut user_edits = new_snapshot
            .edits_since::<usize>(&self.snapshot.version)
            .peekable();
        for (model_old_range, model_new_text) in self.edits.iter() {
            let model_offset_range = model_old_range.to_offset(&self.snapshot);
            // Skip user edits that end strictly before this model edit begins.
            while let Some(next_user_edit) = user_edits.peek() {
                if next_user_edit.old.end < model_offset_range.start {
                    user_edits.next();
                } else {
                    break;
                }
            }

            if let Some(user_edit) = user_edits.peek() {
                if user_edit.old.start > model_offset_range.end {
                    // The next user edit is entirely after this model edit: keep it as-is.
                    edits.push((model_old_range.clone(), model_new_text.clone()));
                } else if user_edit.old == model_offset_range {
                    // The user replaced exactly the model's range; see what they typed.
                    let user_new_text = new_snapshot
                        .text_for_range(user_edit.new.clone())
                        .collect::<String>();

                    if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                        // Keep only what the user has not typed yet, as an empty-range edit
                        // positioned at the end of the user's inserted text.
                        if !model_suffix.is_empty() {
                            edits.push((
                                new_snapshot.anchor_after(user_edit.new.end)
                                    ..new_snapshot.anchor_before(user_edit.new.end),
                                model_suffix.into(),
                            ));
                        }

                        user_edits.next();
                    } else {
                        // The user typed something the model did not predict.
                        return None;
                    }
                } else {
                    // Partial overlap between a user edit and a model edit.
                    return None;
                }
            } else {
                // No user edits remain: keep the model edit unchanged.
                edits.push((model_old_range.clone(), model_new_text.clone()));
            }
        }

        if edits.is_empty() {
            None
        } else {
            Some(edits)
        }
    }
}
156
157impl std::fmt::Debug for InlineCompletion {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("InlineCompletion")
160 .field("id", &self.id)
161 .field("path", &self.path)
162 .field("edits", &self.edits)
163 .finish_non_exhaustive()
164 }
165}
166
/// Edit-prediction state shared across buffers. It holds:
/// - recent edit history;
/// - the completions shown to the user and their ratings;
/// - the data-collection preferences;
/// - the LLM API token used for requests.
pub struct Zeta {
    client: Arc<Client>,
    /// Recent buffer-change events included in prompts; capped in `push_event`.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by entity id, with the last snapshot reported for each.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions that were shown, newest first; capped at 50 in `completion_shown`.
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of shown completions the user has rated.
    rated_completions: HashSet<InlineCompletionId>,
    data_collection_preferences: DataCollectionPreferences,
    llm_token: LlmApiToken,
    /// Held so the token-refresh subscription stays active (presumably dropping it unsubscribes).
    _llm_token_subscription: Subscription,
    tos_accepted: bool, // Terms of service accepted
    /// Held so `tos_accepted` keeps tracking user-store updates.
    _user_store_subscription: Subscription,
}
179
180impl Zeta {
181 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
182 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
183 }
184
185 pub fn register(
186 client: Arc<Client>,
187 user_store: Entity<UserStore>,
188 cx: &mut App,
189 ) -> Entity<Self> {
190 Self::global(cx).unwrap_or_else(|| {
191 let model = cx.new(|cx| Self::new(client, user_store, cx));
192 cx.set_global(ZetaGlobal(model.clone()));
193 model
194 })
195 }
196
197 pub fn clear_history(&mut self) {
198 self.events.clear();
199 }
200
201 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
202 let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
203 Self {
204 client,
205 events: VecDeque::new(),
206 shown_completions: VecDeque::new(),
207 rated_completions: HashSet::default(),
208 registered_buffers: HashMap::default(),
209 data_collection_preferences: Self::load_data_collection_preferences(cx),
210 llm_token: LlmApiToken::default(),
211 _llm_token_subscription: cx.subscribe(
212 &refresh_llm_token_listener,
213 |this, _listener, _event, cx| {
214 let client = this.client.clone();
215 let llm_token = this.llm_token.clone();
216 cx.spawn(|_this, _cx| async move {
217 llm_token.refresh(&client).await?;
218 anyhow::Ok(())
219 })
220 .detach_and_log_err(cx);
221 },
222 ),
223 tos_accepted: user_store
224 .read(cx)
225 .current_user_has_accepted_terms()
226 .unwrap_or(false),
227 _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
228 match event {
229 client::user::Event::PrivateUserInfoUpdated => {
230 this.tos_accepted = user_store
231 .read(cx)
232 .current_user_has_accepted_terms()
233 .unwrap_or(false);
234 }
235 _ => {}
236 }
237 }),
238 }
239 }
240
241 fn push_event(&mut self, event: Event) {
242 const MAX_EVENT_COUNT: usize = 16;
243
244 if let Some(Event::BufferChange {
245 new_snapshot: last_new_snapshot,
246 timestamp: last_timestamp,
247 ..
248 }) = self.events.back_mut()
249 {
250 // Coalesce edits for the same buffer when they happen one after the other.
251 let Event::BufferChange {
252 old_snapshot,
253 new_snapshot,
254 timestamp,
255 } = &event;
256
257 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
258 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
259 && old_snapshot.version == last_new_snapshot.version
260 {
261 *last_new_snapshot = new_snapshot.clone();
262 *last_timestamp = *timestamp;
263 return;
264 }
265 }
266
267 self.events.push_back(event);
268 if self.events.len() >= MAX_EVENT_COUNT {
269 self.events.drain(..MAX_EVENT_COUNT / 2);
270 }
271 }
272
273 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
274 let buffer_id = buffer.entity_id();
275 let weak_buffer = buffer.downgrade();
276
277 if let std::collections::hash_map::Entry::Vacant(entry) =
278 self.registered_buffers.entry(buffer_id)
279 {
280 let snapshot = buffer.read(cx).snapshot();
281
282 entry.insert(RegisteredBuffer {
283 snapshot,
284 _subscriptions: [
285 cx.subscribe(buffer, move |this, buffer, event, cx| {
286 this.handle_buffer_event(buffer, event, cx);
287 }),
288 cx.observe_release(buffer, move |this, _buffer, _cx| {
289 this.registered_buffers.remove(&weak_buffer.entity_id());
290 }),
291 ],
292 });
293 };
294 }
295
296 fn handle_buffer_event(
297 &mut self,
298 buffer: Entity<Buffer>,
299 event: &language::BufferEvent,
300 cx: &mut Context<Self>,
301 ) {
302 if let language::BufferEvent::Edited = event {
303 self.report_changes_for_buffer(&buffer, cx);
304 }
305 }
306
    /// Requests a prediction for `buffer` at `cursor` and parses the result into an
    /// [`InlineCompletion`].
    ///
    /// Any unreported changes to `buffer` are first recorded as events. The prompt pieces
    /// (recent events, the excerpt around the cursor, and the file outline) are assembled on
    /// the background executor and passed to `perform_predict_edits`. Its response is then
    /// turned into a completion by `process_completion_response`.
    ///
    /// `perform_predict_edits` is injectable so tests can supply canned responses
    /// (see `fake_completion`).
    pub fn request_completion_impl<F, R>(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<InlineCompletion>>>
    where
        F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
    {
        let snapshot = self.report_changes_for_buffer(buffer, cx);
        let point = cursor.to_point(&snapshot);
        let offset = point.to_offset(&snapshot);
        let excerpt_range = excerpt_range_for_position(point, &snapshot);
        let events = self.events.clone();
        // Use the file's full path; buffers without a file are labelled "untitled".
        let path = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();

        cx.spawn(|_, cx| async move {
            // Latency (see `InlineCompletion::latency`) is measured from here, so it
            // includes prompt assembly.
            let request_sent_at = Instant::now();

            let (input_events, input_excerpt, input_outline) = cx
                .background_executor()
                .spawn({
                    let snapshot = snapshot.clone();
                    let excerpt_range = excerpt_range.clone();
                    async move {
                        // Events are joined with a blank line between them.
                        let mut input_events = String::new();
                        for event in events {
                            if !input_events.is_empty() {
                                input_events.push('\n');
                                input_events.push('\n');
                            }
                            input_events.push_str(&event.to_prompt());
                        }

                        let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
                        let input_outline = prompt_for_outline(&snapshot);

                        (input_events, input_excerpt, input_outline)
                    }
                })
                .await;

            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);

            // The prompt strings are cloned into the request because the originals are
            // stored on the resulting completion.
            let body = PredictEditsParams {
                input_events: input_events.clone(),
                input_excerpt: input_excerpt.clone(),
                outline: Some(input_outline.clone()),
                can_collect_data,
            };

            let response = perform_predict_edits(client, llm_token, body).await?;

            let output_excerpt = response.output_excerpt;
            log::debug!("completion response: {}", output_excerpt);

            Self::process_completion_response(
                output_excerpt,
                &snapshot,
                excerpt_range,
                offset,
                path,
                input_outline,
                input_events,
                input_excerpt,
                request_sent_at,
                &cx,
            )
            .await
        })
    }
387
388 // Generates several example completions of various states to fill the Zeta completion modal
389 #[cfg(any(test, feature = "test-support"))]
390 pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
391 let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
392 And maybe a short line
393
394 Then a few lines
395
396 and then another
397 "#};
398
399 let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
400 let position = buffer.read(cx).anchor_before(Point::new(1, 0));
401
402 let completion_tasks = vec![
403 self.fake_completion(
404 &buffer,
405 position,
406 PredictEditsResponse {
407 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
408a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
409[here's an edit]
410And maybe a short line
411Then a few lines
412and then another
413{EDITABLE_REGION_END_MARKER}
414 ", ),
415 },
416 cx,
417 ),
418 self.fake_completion(
419 &buffer,
420 position,
421 PredictEditsResponse {
422 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
423a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
424And maybe a short line
425[and another edit]
426Then a few lines
427and then another
428{EDITABLE_REGION_END_MARKER}
429 "#),
430 },
431 cx,
432 ),
433 self.fake_completion(
434 &buffer,
435 position,
436 PredictEditsResponse {
437 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
438a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
439And maybe a short line
440
441Then a few lines
442
443and then another
444{EDITABLE_REGION_END_MARKER}
445 "#),
446 },
447 cx,
448 ),
449 self.fake_completion(
450 &buffer,
451 position,
452 PredictEditsResponse {
453 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
454a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
455And maybe a short line
456
457Then a few lines
458
459and then another
460{EDITABLE_REGION_END_MARKER}
461 "#),
462 },
463 cx,
464 ),
465 self.fake_completion(
466 &buffer,
467 position,
468 PredictEditsResponse {
469 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
470a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
471And maybe a short line
472Then a few lines
473[a third completion]
474and then another
475{EDITABLE_REGION_END_MARKER}
476 "#),
477 },
478 cx,
479 ),
480 self.fake_completion(
481 &buffer,
482 position,
483 PredictEditsResponse {
484 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
485a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
486And maybe a short line
487and then another
488[fourth completion example]
489{EDITABLE_REGION_END_MARKER}
490 "#),
491 },
492 cx,
493 ),
494 self.fake_completion(
495 &buffer,
496 position,
497 PredictEditsResponse {
498 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
499a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
500And maybe a short line
501Then a few lines
502and then another
503[fifth and final completion]
504{EDITABLE_REGION_END_MARKER}
505 "#),
506 },
507 cx,
508 ),
509 ];
510
511 cx.spawn(|zeta, mut cx| async move {
512 for task in completion_tasks {
513 task.await.unwrap();
514 }
515
516 zeta.update(&mut cx, |zeta, _cx| {
517 zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
518 zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
519 })
520 .ok();
521 })
522 }
523
524 #[cfg(any(test, feature = "test-support"))]
525 pub fn fake_completion(
526 &mut self,
527 buffer: &Entity<Buffer>,
528 position: language::Anchor,
529 response: PredictEditsResponse,
530 cx: &mut Context<Self>,
531 ) -> Task<Result<Option<InlineCompletion>>> {
532 use std::future::ready;
533
534 self.request_completion_impl(buffer, position, false, cx, |_, _, _| ready(Ok(response)))
535 }
536
537 pub fn request_completion(
538 &mut self,
539 buffer: &Entity<Buffer>,
540 position: language::Anchor,
541 can_collect_data: bool,
542 cx: &mut Context<Self>,
543 ) -> Task<Result<Option<InlineCompletion>>> {
544 self.request_completion_impl(
545 buffer,
546 position,
547 can_collect_data,
548 cx,
549 Self::perform_predict_edits,
550 )
551 }
552
553 fn perform_predict_edits(
554 client: Arc<Client>,
555 llm_token: LlmApiToken,
556 body: PredictEditsParams,
557 ) -> impl Future<Output = Result<PredictEditsResponse>> {
558 async move {
559 let http_client = client.http_client();
560 let mut token = llm_token.acquire(&client).await?;
561 let mut did_retry = false;
562
563 loop {
564 let request_builder = http_client::Request::builder();
565 let request = request_builder
566 .method(Method::POST)
567 .uri(
568 http_client
569 .build_zed_llm_url("/predict_edits", &[])?
570 .as_ref(),
571 )
572 .header("Content-Type", "application/json")
573 .header("Authorization", format!("Bearer {}", token))
574 .body(serde_json::to_string(&body)?.into())?;
575
576 let mut response = http_client.send(request).await?;
577
578 if response.status().is_success() {
579 let mut body = String::new();
580 response.body_mut().read_to_string(&mut body).await?;
581 return Ok(serde_json::from_str(&body)?);
582 } else if !did_retry
583 && response
584 .headers()
585 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
586 .is_some()
587 {
588 did_retry = true;
589 token = llm_token.refresh(&client).await?;
590 } else {
591 let mut body = String::new();
592 response.body_mut().read_to_string(&mut body).await?;
593 return Err(anyhow!(
594 "error predicting edits.\nStatus: {:?}\nBody: {}",
595 response.status(),
596 body
597 ));
598 }
599 }
600 }
601 }
602
603 #[allow(clippy::too_many_arguments)]
604 fn process_completion_response(
605 output_excerpt: String,
606 snapshot: &BufferSnapshot,
607 excerpt_range: Range<usize>,
608 cursor_offset: usize,
609 path: Arc<Path>,
610 input_outline: String,
611 input_events: String,
612 input_excerpt: String,
613 request_sent_at: Instant,
614 cx: &AsyncApp,
615 ) -> Task<Result<Option<InlineCompletion>>> {
616 let snapshot = snapshot.clone();
617 cx.background_executor().spawn(async move {
618 let content = output_excerpt.replace(CURSOR_MARKER, "");
619
620 let start_markers = content
621 .match_indices(EDITABLE_REGION_START_MARKER)
622 .collect::<Vec<_>>();
623 anyhow::ensure!(
624 start_markers.len() == 1,
625 "expected exactly one start marker, found {}",
626 start_markers.len()
627 );
628
629 let codefence_start = start_markers[0].0;
630 let content = &content[codefence_start..];
631
632 let newline_ix = content.find('\n').context("could not find newline")?;
633 let content = &content[newline_ix + 1..];
634
635 let codefence_end = content
636 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
637 .context("could not find end marker")?;
638 let new_text = &content[..codefence_end];
639
640 let old_text = snapshot
641 .text_for_range(excerpt_range.clone())
642 .collect::<String>();
643
644 let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
645
646 Ok(Some(InlineCompletion {
647 id: InlineCompletionId::new(),
648 path,
649 excerpt_range,
650 cursor_offset,
651 edits: edits.into(),
652 snapshot: snapshot.clone(),
653 input_outline: input_outline.into(),
654 input_events: input_events.into(),
655 input_excerpt: input_excerpt.into(),
656 output_excerpt: output_excerpt.into(),
657 request_sent_at,
658 response_received_at: Instant::now(),
659 }))
660 })
661 }
662
663 pub fn compute_edits(
664 old_text: String,
665 new_text: &str,
666 offset: usize,
667 snapshot: &BufferSnapshot,
668 ) -> Vec<(Range<Anchor>, String)> {
669 let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
670
671 let mut edits: Vec<(Range<usize>, String)> = Vec::new();
672 let mut old_start = offset;
673 for change in diff.iter_all_changes() {
674 let value = change.value();
675 match change.tag() {
676 similar::ChangeTag::Equal => {
677 old_start += value.len();
678 }
679 similar::ChangeTag::Delete => {
680 let old_end = old_start + value.len();
681 if let Some((last_old_range, _)) = edits.last_mut() {
682 if last_old_range.end == old_start {
683 last_old_range.end = old_end;
684 } else {
685 edits.push((old_start..old_end, String::new()));
686 }
687 } else {
688 edits.push((old_start..old_end, String::new()));
689 }
690 old_start = old_end;
691 }
692 similar::ChangeTag::Insert => {
693 if let Some((last_old_range, last_new_text)) = edits.last_mut() {
694 if last_old_range.end == old_start {
695 last_new_text.push_str(value);
696 } else {
697 edits.push((old_start..old_start, value.into()));
698 }
699 } else {
700 edits.push((old_start..old_start, value.into()));
701 }
702 }
703 }
704 }
705
706 edits
707 .into_iter()
708 .map(|(mut old_range, new_text)| {
709 let prefix_len = common_prefix(
710 snapshot.chars_for_range(old_range.clone()),
711 new_text.chars(),
712 );
713 old_range.start += prefix_len;
714 let suffix_len = common_prefix(
715 snapshot.reversed_chars_for_range(old_range.clone()),
716 new_text[prefix_len..].chars().rev(),
717 );
718 old_range.end = old_range.end.saturating_sub(suffix_len);
719
720 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
721 (
722 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
723 new_text,
724 )
725 })
726 .collect()
727 }
728
729 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
730 self.rated_completions.contains(&completion_id)
731 }
732
733 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
734 self.shown_completions.push_front(completion.clone());
735 if self.shown_completions.len() > 50 {
736 let completion = self.shown_completions.pop_back().unwrap();
737 self.rated_completions.remove(&completion.id);
738 }
739 cx.notify();
740 }
741
742 pub fn rate_completion(
743 &mut self,
744 completion: &InlineCompletion,
745 rating: InlineCompletionRating,
746 feedback: String,
747 cx: &mut Context<Self>,
748 ) {
749 self.rated_completions.insert(completion.id);
750 telemetry::event!(
751 "Inline Completion Rated",
752 rating,
753 input_events = completion.input_events,
754 input_excerpt = completion.input_excerpt,
755 input_outline = completion.input_outline,
756 output_excerpt = completion.output_excerpt,
757 feedback
758 );
759 self.client.telemetry().flush_events();
760 cx.notify();
761 }
762
763 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
764 self.shown_completions.iter()
765 }
766
767 pub fn shown_completions_len(&self) -> usize {
768 self.shown_completions.len()
769 }
770
771 fn report_changes_for_buffer(
772 &mut self,
773 buffer: &Entity<Buffer>,
774 cx: &mut Context<Self>,
775 ) -> BufferSnapshot {
776 self.register_buffer(buffer, cx);
777
778 let registered_buffer = self
779 .registered_buffers
780 .get_mut(&buffer.entity_id())
781 .unwrap();
782 let new_snapshot = buffer.read(cx).snapshot();
783
784 if new_snapshot.version != registered_buffer.snapshot.version {
785 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
786 self.push_event(Event::BufferChange {
787 old_snapshot,
788 new_snapshot: new_snapshot.clone(),
789 timestamp: Instant::now(),
790 });
791 }
792
793 new_snapshot
794 }
795
796 pub fn data_collection_choice_at(&self, path: &Path) -> DataCollectionChoice {
797 match self.data_collection_preferences.per_worktree.get(path) {
798 Some(true) => DataCollectionChoice::Enabled,
799 Some(false) => DataCollectionChoice::Disabled,
800 None => DataCollectionChoice::NotAnswered,
801 }
802 }
803
804 fn update_data_collection_choice_for_worktree(
805 &mut self,
806 absolute_path_of_project_worktree: PathBuf,
807 can_collect_data: bool,
808 cx: &mut Context<Self>,
809 ) {
810 self.data_collection_preferences
811 .per_worktree
812 .insert(absolute_path_of_project_worktree.clone(), can_collect_data);
813
814 db::write_and_log(cx, move || {
815 persistence::DB
816 .save_accepted_data_collection(absolute_path_of_project_worktree, can_collect_data)
817 });
818 }
819
820 fn set_never_ask_again_for_data_collection(&mut self, cx: &mut Context<Self>) {
821 self.data_collection_preferences.never_ask_again = true;
822
823 // persist choice
824 db::write_and_log(cx, move || {
825 KEY_VALUE_STORE.write_kvp(
826 ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into(),
827 "true".to_string(),
828 )
829 });
830 }
831
832 fn load_data_collection_preferences(cx: &mut Context<Self>) -> DataCollectionPreferences {
833 if env::var("ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES").is_ok() {
834 db::write_and_log(cx, move || async move {
835 KEY_VALUE_STORE
836 .delete_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into())
837 .await
838 .log_err();
839
840 persistence::DB.clear_all_zeta_preferences().await
841 });
842 return DataCollectionPreferences::default();
843 }
844
845 let never_ask_again = KEY_VALUE_STORE
846 .read_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY)
847 .log_err()
848 .flatten()
849 .map(|value| value == "true")
850 .unwrap_or(false);
851
852 let preferences_per_project = persistence::DB
853 .get_all_zeta_preferences()
854 .log_err()
855 .unwrap_or_else(HashMap::default);
856
857 DataCollectionPreferences {
858 never_ask_again,
859 per_worktree: preferences_per_project,
860 }
861 }
862}
863
/// Persisted data-collection answers, loaded by `Zeta::load_data_collection_preferences`.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct DataCollectionPreferences {
    /// Set when a user clicks on "Never Ask Again", can never be unset.
    /// (Only the `ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES` env var resets it.)
    never_ask_again: bool,
    /// Per-worktree answer (`true` = collection allowed), keyed by absolute worktree root path.
    per_worktree: HashMap<PathBuf, bool>,
}
870
/// Returns the length in UTF-8 bytes of the longest common prefix of two `char` sequences.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
877
878fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
879 let mut input_outline = String::new();
880
881 writeln!(
882 input_outline,
883 "```{}",
884 snapshot
885 .file()
886 .map_or(Cow::Borrowed("untitled"), |file| file
887 .path()
888 .to_string_lossy())
889 )
890 .unwrap();
891
892 if let Some(outline) = snapshot.outline(None) {
893 let guess_size = outline.items.len() * 15;
894 input_outline.reserve(guess_size);
895 for item in outline.items.iter() {
896 let spacing = " ".repeat(item.depth);
897 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
898 }
899 }
900
901 writeln!(input_outline, "```").unwrap();
902
903 input_outline
904}
905
906fn prompt_for_excerpt(
907 snapshot: &BufferSnapshot,
908 excerpt_range: &Range<usize>,
909 offset: usize,
910) -> String {
911 let mut prompt_excerpt = String::new();
912 writeln!(
913 prompt_excerpt,
914 "```{}",
915 snapshot
916 .file()
917 .map_or(Cow::Borrowed("untitled"), |file| file
918 .path()
919 .to_string_lossy())
920 )
921 .unwrap();
922
923 if excerpt_range.start == 0 {
924 writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
925 }
926
927 let point_range = excerpt_range.to_point(snapshot);
928 if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
929 let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
930 for chunk in snapshot.text_for_range(extra_context_line_range) {
931 prompt_excerpt.push_str(chunk);
932 }
933 }
934 writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
935 for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
936 prompt_excerpt.push_str(chunk);
937 }
938 prompt_excerpt.push_str(CURSOR_MARKER);
939 for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
940 prompt_excerpt.push_str(chunk);
941 }
942 write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
943
944 if point_range.end.row < snapshot.max_point().row
945 && !snapshot.is_line_blank(point_range.end.row + 1)
946 {
947 let extra_context_line_range = point_range.end
948 ..Point::new(
949 point_range.end.row + 1,
950 snapshot.line_len(point_range.end.row + 1),
951 );
952 for chunk in snapshot.text_for_range(extra_context_line_range) {
953 prompt_excerpt.push_str(chunk);
954 }
955 }
956
957 write!(prompt_excerpt, "\n```").unwrap();
958 prompt_excerpt
959}
960
961fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
962 const CONTEXT_LINES: u32 = 32;
963
964 let mut context_lines_before = CONTEXT_LINES;
965 let mut context_lines_after = CONTEXT_LINES;
966 if point.row < CONTEXT_LINES {
967 context_lines_after += CONTEXT_LINES - point.row;
968 } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
969 context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
970 }
971
972 let excerpt_start_row = point.row.saturating_sub(context_lines_before);
973 let excerpt_start = Point::new(excerpt_start_row, 0);
974 let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
975 let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
976 excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
977}
978
/// Tracking state for a buffer registered with `Zeta::register_buffer`.
struct RegisteredBuffer {
    /// Last snapshot reported; `report_changes_for_buffer` compares new versions against it.
    snapshot: BufferSnapshot,
    /// The buffer-event subscription and the release observer, held for the entry's lifetime.
    _subscriptions: [gpui::Subscription; 2],
}
983
/// An entry in Zeta's edit history, rendered into the prompt by `Event::to_prompt`.
#[derive(Clone)]
enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. `timestamp` is when the change was
    /// recorded; it is advanced when later changes are coalesced into this event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
992
993impl Event {
994 fn to_prompt(&self) -> String {
995 match self {
996 Event::BufferChange {
997 old_snapshot,
998 new_snapshot,
999 ..
1000 } => {
1001 let mut prompt = String::new();
1002
1003 let old_path = old_snapshot
1004 .file()
1005 .map(|f| f.path().as_ref())
1006 .unwrap_or(Path::new("untitled"));
1007 let new_path = new_snapshot
1008 .file()
1009 .map(|f| f.path().as_ref())
1010 .unwrap_or(Path::new("untitled"));
1011 if old_path != new_path {
1012 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1013 }
1014
1015 let diff =
1016 similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
1017 .unified_diff()
1018 .to_string();
1019 if !diff.is_empty() {
1020 write!(
1021 prompt,
1022 "User edited {:?}:\n```diff\n{}\n```",
1023 new_path, diff
1024 )
1025 .unwrap();
1026 }
1027
1028 prompt
1029 }
1030 }
1031 }
1032}
1033
/// A completion paired with the buffer entity it was requested for.
/// NOTE(review): presumably the provider's currently displayed completion; confirm at use sites.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    buffer_id: EntityId,
    completion: InlineCompletion,
}
1039
1040impl CurrentInlineCompletion {
1041 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1042 if self.buffer_id != old_completion.buffer_id {
1043 return true;
1044 }
1045
1046 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1047 return true;
1048 };
1049 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1050 return false;
1051 };
1052
1053 if old_edits.len() == 1 && new_edits.len() == 1 {
1054 let (old_range, old_text) = &old_edits[0];
1055 let (new_range, new_text) = &new_edits[0];
1056 new_range == old_range && new_text.starts_with(old_text)
1057 } else {
1058 true
1059 }
1060 }
1061}
1062
/// An in-flight completion request.
struct PendingCompletion {
    /// NOTE(review): presumably taken from `next_pending_completion_id`; confirm at use site.
    id: usize,
    /// The request task, held so it keeps running (gpui tasks are presumably cancelled on drop).
    _task: Task<()>,
}
1067
/// A user's answer about whether edit-prediction data may be collected for a worktree.
#[derive(Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer is stored for the worktree yet.
    NotAnswered,
    /// Collection is allowed.
    Enabled,
    /// Collection is not allowed.
    Disabled,
}

impl DataCollectionChoice {
    /// Whether collection is allowed; an unanswered choice counts as not allowed.
    pub fn is_enabled(&self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has given an explicit answer.
    pub fn is_answered(&self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// Flips an explicit answer; an unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}
1098
/// Inline-completion provider backed by the global [`Zeta`] entity.
pub struct ZetaInlineCompletionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests; at most two are held (`ArrayVec` capacity).
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// NOTE(review): presumably the id assigned to the next `PendingCompletion`.
    next_pending_completion_id: usize,
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection context for the current buffer, when one could be determined.
    data_collection: Option<ProviderDataCollection>,
}
1106
/// Data-collection context for the worktree containing the provider's buffer.
pub struct ProviderDataCollection {
    workspace: WeakEntity<Workspace>,
    /// Absolute path of the worktree root; the key under which the choice is stored.
    worktree_root_path: PathBuf,
    /// Cached copy of the choice stored in `Zeta` for this worktree.
    choice: DataCollectionChoice,
}
1112
1113impl ProviderDataCollection {
1114 pub fn new(
1115 zeta: Entity<Zeta>,
1116 workspace: Option<Entity<Workspace>>,
1117 buffer: Option<Entity<Buffer>>,
1118 cx: &mut App,
1119 ) -> Option<ProviderDataCollection> {
1120 let workspace = workspace?;
1121
1122 let worktree_root_path = buffer?.update(cx, |buffer, cx| {
1123 let file = buffer.file()?;
1124
1125 if !file.is_local() || file.is_private() {
1126 return None;
1127 }
1128
1129 workspace.update(cx, |workspace, cx| {
1130 Some(
1131 workspace
1132 .absolute_path_of_worktree(file.worktree_id(cx), cx)?
1133 .to_path_buf(),
1134 )
1135 })
1136 })?;
1137
1138 let choice = zeta.read(cx).data_collection_choice_at(&worktree_root_path);
1139
1140 Some(ProviderDataCollection {
1141 workspace: workspace.downgrade(),
1142 worktree_root_path,
1143 choice,
1144 })
1145 }
1146
1147 fn set_choice(&mut self, choice: DataCollectionChoice, zeta: &Entity<Zeta>, cx: &mut App) {
1148 self.choice = choice;
1149
1150 let worktree_root_path = self.worktree_root_path.clone();
1151
1152 zeta.update(cx, |zeta, cx| {
1153 zeta.update_data_collection_choice_for_worktree(
1154 worktree_root_path,
1155 choice.is_enabled(),
1156 cx,
1157 )
1158 });
1159 }
1160
1161 fn toggle_choice(&mut self, zeta: &Entity<Zeta>, cx: &mut App) {
1162 self.set_choice(self.choice.toggle(), zeta, cx);
1163 }
1164}
1165
1166impl ZetaInlineCompletionProvider {
1167 pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
1168
1169 pub fn new(zeta: Entity<Zeta>, data_collection: Option<ProviderDataCollection>) -> Self {
1170 Self {
1171 zeta,
1172 pending_completions: ArrayVec::new(),
1173 next_pending_completion_id: 0,
1174 current_completion: None,
1175 data_collection,
1176 }
1177 }
1178
1179 fn set_data_collection_choice(&mut self, choice: DataCollectionChoice, cx: &mut App) {
1180 if let Some(data_collection) = self.data_collection.as_mut() {
1181 data_collection.set_choice(choice, &self.zeta, cx);
1182 }
1183 }
1184}
1185
1186impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
1187 fn name() -> &'static str {
1188 "zed-predict"
1189 }
1190
1191 fn display_name() -> &'static str {
1192 "Zed's Edit Predictions"
1193 }
1194
1195 fn show_completions_in_menu() -> bool {
1196 true
1197 }
1198
1199 fn show_completions_in_normal_mode() -> bool {
1200 true
1201 }
1202
1203 fn show_tab_accept_marker() -> bool {
1204 true
1205 }
1206
1207 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
1208 let Some(data_collection) = self.data_collection.as_ref() else {
1209 return DataCollectionState::Unknown;
1210 };
1211
1212 if data_collection.choice.is_enabled() {
1213 DataCollectionState::Enabled
1214 } else {
1215 DataCollectionState::Disabled
1216 }
1217 }
1218
1219 fn toggle_data_collection(&mut self, cx: &mut App) {
1220 if let Some(data_collection) = self.data_collection.as_mut() {
1221 data_collection.toggle_choice(&self.zeta, cx);
1222 }
1223 }
1224
1225 fn is_enabled(
1226 &self,
1227 buffer: &Entity<Buffer>,
1228 cursor_position: language::Anchor,
1229 cx: &App,
1230 ) -> bool {
1231 let buffer = buffer.read(cx);
1232 let file = buffer.file();
1233 let language = buffer.language_at(cursor_position);
1234 let settings = all_language_settings(file, cx);
1235 settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1236 }
1237
1238 fn needs_terms_acceptance(&self, cx: &App) -> bool {
1239 !self.zeta.read(cx).tos_accepted
1240 }
1241
1242 fn is_refreshing(&self) -> bool {
1243 !self.pending_completions.is_empty()
1244 }
1245
1246 fn refresh(
1247 &mut self,
1248 buffer: Entity<Buffer>,
1249 position: language::Anchor,
1250 debounce: bool,
1251 cx: &mut Context<Self>,
1252 ) {
1253 if !self.zeta.read(cx).tos_accepted {
1254 return;
1255 }
1256
1257 let pending_completion_id = self.next_pending_completion_id;
1258 self.next_pending_completion_id += 1;
1259 let can_collect_data = self
1260 .data_collection
1261 .as_ref()
1262 .map_or(false, |data_collection| data_collection.choice.is_enabled());
1263
1264 let task = cx.spawn(|this, mut cx| async move {
1265 if debounce {
1266 cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1267 }
1268
1269 let completion_request = this.update(&mut cx, |this, cx| {
1270 this.zeta.update(cx, |zeta, cx| {
1271 zeta.request_completion(&buffer, position, can_collect_data, cx)
1272 })
1273 });
1274
1275 let completion = match completion_request {
1276 Ok(completion_request) => {
1277 let completion_request = completion_request.await;
1278 completion_request.map(|c| {
1279 c.map(|completion| CurrentInlineCompletion {
1280 buffer_id: buffer.entity_id(),
1281 completion,
1282 })
1283 })
1284 }
1285 Err(error) => Err(error),
1286 };
1287 let Some(new_completion) = completion
1288 .context("edit prediction failed")
1289 .log_err()
1290 .flatten()
1291 else {
1292 return;
1293 };
1294
1295 this.update(&mut cx, |this, cx| {
1296 if this.pending_completions[0].id == pending_completion_id {
1297 this.pending_completions.remove(0);
1298 } else {
1299 this.pending_completions.clear();
1300 }
1301
1302 if let Some(old_completion) = this.current_completion.as_ref() {
1303 let snapshot = buffer.read(cx).snapshot();
1304 if new_completion.should_replace_completion(&old_completion, &snapshot) {
1305 this.zeta.update(cx, |zeta, cx| {
1306 zeta.completion_shown(&new_completion.completion, cx);
1307 });
1308 this.current_completion = Some(new_completion);
1309 }
1310 } else {
1311 this.zeta.update(cx, |zeta, cx| {
1312 zeta.completion_shown(&new_completion.completion, cx);
1313 });
1314 this.current_completion = Some(new_completion);
1315 }
1316
1317 cx.notify();
1318 })
1319 .ok();
1320 });
1321
1322 // We always maintain at most two pending completions. When we already
1323 // have two, we replace the newest one.
1324 if self.pending_completions.len() <= 1 {
1325 self.pending_completions.push(PendingCompletion {
1326 id: pending_completion_id,
1327 _task: task,
1328 });
1329 } else if self.pending_completions.len() == 2 {
1330 self.pending_completions.pop();
1331 self.pending_completions.push(PendingCompletion {
1332 id: pending_completion_id,
1333 _task: task,
1334 });
1335 }
1336 }
1337
1338 fn cycle(
1339 &mut self,
1340 _buffer: Entity<Buffer>,
1341 _cursor_position: language::Anchor,
1342 _direction: inline_completion::Direction,
1343 _cx: &mut Context<Self>,
1344 ) {
1345 // Right now we don't support cycling.
1346 }
1347
1348 fn accept(&mut self, cx: &mut Context<Self>) {
1349 self.pending_completions.clear();
1350
1351 let Some(data_collection) = self.data_collection.as_mut() else {
1352 return;
1353 };
1354
1355 if data_collection.choice.is_answered()
1356 || self
1357 .zeta
1358 .read(cx)
1359 .data_collection_preferences
1360 .never_ask_again
1361 {
1362 return;
1363 }
1364
1365 struct ZetaDataCollectionNotification;
1366 let notification_id = NotificationId::unique::<ZetaDataCollectionNotification>();
1367
1368 const DATA_COLLECTION_INFO_URL: &str = "https://zed.dev/terms-of-service"; // TODO: Replace for a link that's dedicated to Edit Predictions data collection
1369
1370 let this = cx.entity();
1371 data_collection
1372 .workspace
1373 .update(cx, |workspace, cx| {
1374 workspace.show_notification(notification_id, cx, |cx| {
1375 let zeta = self.zeta.clone();
1376
1377 cx.new(move |_cx| {
1378 let message =
1379 "To allow Zed to suggest better edits, turn on data collection. You \
1380 can turn off at any time via the status bar menu.";
1381 MessageNotification::new(message)
1382 .with_title("Per-Project Data Collection Program")
1383 .show_close_button(false)
1384 .with_click_message("Turn On")
1385 .on_click({
1386 let this = this.clone();
1387 move |_window, cx| {
1388 this.update(cx, |this, cx| {
1389 this.set_data_collection_choice(
1390 DataCollectionChoice::Enabled,
1391 cx,
1392 )
1393 });
1394 }
1395 })
1396 .with_secondary_click_message("Turn Off")
1397 .on_secondary_click({
1398 move |_window, cx| {
1399 this.update(cx, |this, cx| {
1400 this.set_data_collection_choice(
1401 DataCollectionChoice::Disabled,
1402 cx,
1403 )
1404 });
1405 }
1406 })
1407 .with_tertiary_click_message("Never Ask Again")
1408 .on_tertiary_click({
1409 let zeta = zeta.clone();
1410 move |_window, cx| {
1411 zeta.update(cx, |zeta, cx| {
1412 zeta.set_never_ask_again_for_data_collection(cx);
1413 });
1414 }
1415 })
1416 .more_info_message("Learn More")
1417 .more_info_url(DATA_COLLECTION_INFO_URL)
1418 })
1419 });
1420 })
1421 .log_err();
1422 }
1423
1424 fn discard(&mut self, _cx: &mut Context<Self>) {
1425 self.pending_completions.clear();
1426 self.current_completion.take();
1427 }
1428
1429 fn suggest(
1430 &mut self,
1431 buffer: &Entity<Buffer>,
1432 cursor_position: language::Anchor,
1433 cx: &mut Context<Self>,
1434 ) -> Option<inline_completion::InlineCompletion> {
1435 let CurrentInlineCompletion {
1436 buffer_id,
1437 completion,
1438 ..
1439 } = self.current_completion.as_mut()?;
1440
1441 // Invalidate previous completion if it was generated for a different buffer.
1442 if *buffer_id != buffer.entity_id() {
1443 self.current_completion.take();
1444 return None;
1445 }
1446
1447 let buffer = buffer.read(cx);
1448 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1449 self.current_completion.take();
1450 return None;
1451 };
1452
1453 let cursor_row = cursor_position.to_point(buffer).row;
1454 let (closest_edit_ix, (closest_edit_range, _)) =
1455 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1456 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1457 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1458 cmp::min(distance_from_start, distance_from_end)
1459 })?;
1460
1461 let mut edit_start_ix = closest_edit_ix;
1462 for (range, _) in edits[..edit_start_ix].iter().rev() {
1463 let distance_from_closest_edit =
1464 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1465 if distance_from_closest_edit <= 1 {
1466 edit_start_ix -= 1;
1467 } else {
1468 break;
1469 }
1470 }
1471
1472 let mut edit_end_ix = closest_edit_ix + 1;
1473 for (range, _) in &edits[edit_end_ix..] {
1474 let distance_from_closest_edit =
1475 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1476 if distance_from_closest_edit <= 1 {
1477 edit_end_ix += 1;
1478 } else {
1479 break;
1480 }
1481 }
1482
1483 Some(inline_completion::InlineCompletion {
1484 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1485 })
1486 }
1487}
1488
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Checks how a fixed prediction (replace 2..5 with "REM", delete 9..11)
    /// is interpolated as the user edits the buffer underneath it.
    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let completion = InlineCompletion {
            edits: cx
                .read(|cx| {
                    to_completion_edits(
                        [(2..5, "REM".to_string()), (9..11, "".to_string())],
                        &buffer,
                        cx,
                    )
                })
                .into(),
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: InlineCompletionId::new(),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Unchanged buffer: the predicted edits come back as-is.
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // The user deletes 2..5: "REM" is now a pure insertion at 2 and the
        // later deletion shifts left by three.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
        );

        // Undo restores the original buffer and the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // The user types "R" over 2..5: only the remaining "EM" is suggested.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
        );

        // The user types "E": only "M" remains to be inserted.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // The user types "M": the first edit is fully applied and only the
        // deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(9..11, "".to_string())]
        );

        // Deleting the typed "M" brings its insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // The user performs the predicted deletion themselves: only the "M"
        // insertion is left.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(4..4, "M".to_string())]
        );

        // An edit that diverges from the prediction makes interpolation fail.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(
            cx.read(|cx| completion.interpolate(&buffer.read(cx).snapshot())),
            None
        );
    }

    /// A prediction whose editable region ends at the end of the buffer can
    /// append text after the final newline.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake prediction endpoint that always answers with `completion_response`.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(&buffer, cursor, false, cx)
        });

        // The request first fetches an LLM token over RPC; answer it so the
        // HTTP prediction request can proceed.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Converts offset-range edits into anchor-range edits on `buffer`
    /// (start via `anchor_after`, end via `anchor_before`).
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-range edits back to offsets in `buffer`'s current
    /// contents so they can be compared with plain `Range<usize>` values.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    // Runs before the tests; enables `env_logger` only when RUST_LOG is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}