1mod completion_diff_element;
2mod persistence;
3mod rate_completion_modal;
4
5pub(crate) use completion_diff_element::*;
6use db::kvp::KEY_VALUE_STORE;
7use inline_completion::DataCollectionState;
8pub use rate_completion_modal::*;
9
10use anyhow::{anyhow, Context as _, Result};
11use arrayvec::ArrayVec;
12use client::{Client, UserStore};
13use collections::hash_map::Entry;
14use collections::{HashMap, HashSet, VecDeque};
15use futures::AsyncReadExt;
16use gpui::{
17 actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
18 WeakEntity,
19};
20use http_client::{HttpClient, Method};
21use language::{
22 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
23 Point, ToOffset, ToPoint,
24};
25use language_models::LlmApiToken;
26use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
27use std::{
28 borrow::Cow,
29 cmp, env,
30 fmt::Write,
31 future::Future,
32 mem,
33 ops::Range,
34 path::{Path, PathBuf},
35 sync::Arc,
36 time::{Duration, Instant},
37};
38use telemetry_events::InlineCompletionRating;
39use util::ResultExt;
40use uuid::Uuid;
41use workspace::{
42 notifications::{simple_message_notification::MessageNotification, NotificationId},
43 Workspace,
44};
45
/// Marker written into the prompt excerpt at the cursor offset; it is stripped
/// from the model output before the output is parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker written at the top of the prompt excerpt when the excerpt starts at offset 0.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker that opens the editable region in the prompt excerpt and in the model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker that closes the editable region in the prompt excerpt and in the model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Largest gap between two consecutive buffer changes for them to be merged
/// into a single recorded event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the persisted "never ask again" flag for data collection.
const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &str =
    "zed_predict_data_collection_never_ask_again";
53
// Declares the `ClearHistory` action through gpui's `actions!` macro.
// `edit_prediction` is the macro's first argument (presumably the action namespace — confirm).
actions!(edit_prediction, [ClearHistory]);
55
56#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
57pub struct InlineCompletionId(Uuid);
58
59impl From<InlineCompletionId> for gpui::ElementId {
60 fn from(value: InlineCompletionId) -> Self {
61 gpui::ElementId::Uuid(value.0)
62 }
63}
64
65impl std::fmt::Display for InlineCompletionId {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 write!(f, "{}", self.0)
68 }
69}
70
71impl InlineCompletionId {
72 fn new() -> Self {
73 Self(Uuid::new_v4())
74 }
75}
76
/// Newtype wrapper that stores the shared `Zeta` entity as an app global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

// Marker impl that lets `ZetaGlobal` be stored and looked up as a gpui global.
impl Global for ZetaGlobal {}
81
/// A prediction produced for a buffer, together with the inputs that were sent
/// to the model and the timing of the request.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Unique identifier generated when the completion is built.
    id: InlineCompletionId,
    /// Full path of the buffer's file, or `"untitled"` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the buffer excerpt that was sent as the editable region.
    excerpt_range: Range<usize>,
    /// Offset of the cursor when the completion was requested.
    cursor_offset: usize,
    /// Predicted edits: anchor ranges (resolved against `snapshot`) and their replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    /// Outline prompt that was sent with the request.
    input_outline: Arc<str>,
    /// Events prompt that was sent with the request.
    input_events: Arc<str>,
    /// Excerpt prompt that was sent with the request.
    input_excerpt: Arc<str>,
    /// Excerpt returned by the model, stored as received.
    output_excerpt: Arc<str>,
    /// Instant taken just before the prompts were built and the request was issued.
    request_sent_at: Instant,
    /// Instant taken once the response had been parsed into edits.
    response_received_at: Instant,
}
97
98impl InlineCompletion {
99 fn latency(&self) -> Duration {
100 self.response_received_at
101 .duration_since(self.request_sent_at)
102 }
103
104 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
105 let mut edits = Vec::new();
106
107 let mut user_edits = new_snapshot
108 .edits_since::<usize>(&self.snapshot.version)
109 .peekable();
110 for (model_old_range, model_new_text) in self.edits.iter() {
111 let model_offset_range = model_old_range.to_offset(&self.snapshot);
112 while let Some(next_user_edit) = user_edits.peek() {
113 if next_user_edit.old.end < model_offset_range.start {
114 user_edits.next();
115 } else {
116 break;
117 }
118 }
119
120 if let Some(user_edit) = user_edits.peek() {
121 if user_edit.old.start > model_offset_range.end {
122 edits.push((model_old_range.clone(), model_new_text.clone()));
123 } else if user_edit.old == model_offset_range {
124 let user_new_text = new_snapshot
125 .text_for_range(user_edit.new.clone())
126 .collect::<String>();
127
128 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
129 if !model_suffix.is_empty() {
130 edits.push((
131 new_snapshot.anchor_after(user_edit.new.end)
132 ..new_snapshot.anchor_before(user_edit.new.end),
133 model_suffix.into(),
134 ));
135 }
136
137 user_edits.next();
138 } else {
139 return None;
140 }
141 } else {
142 return None;
143 }
144 } else {
145 edits.push((model_old_range.clone(), model_new_text.clone()));
146 }
147 }
148
149 if edits.is_empty() {
150 None
151 } else {
152 Some(edits)
153 }
154 }
155}
156
157impl std::fmt::Debug for InlineCompletion {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("InlineCompletion")
160 .field("id", &self.id)
161 .field("path", &self.path)
162 .field("edits", &self.edits)
163 .finish_non_exhaustive()
164 }
165}
166
/// Shared state for edit predictions: recent buffer events, shown and rated
/// completions, data collection preferences, and the credentials used for requests.
pub struct Zeta {
    /// Client used for HTTP requests, token refresh and telemetry flushing.
    client: Arc<Client>,
    /// Recent buffer change events, sent to the model as part of the prompt.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by entity id, with their last seen snapshot.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions reported as shown, most recent first.
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of completions that have been rated.
    rated_completions: HashSet<InlineCompletionId>,
    /// Data collection preferences, loaded when `Zeta` is created.
    data_collection_preferences: DataCollectionPreferences,
    /// Token used to authorize prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh listener subscription alive.
    _llm_token_subscription: Subscription,
    /// Whether the current user has accepted the terms of service.
    tos_accepted: bool, // Terms of service accepted
    /// Keeps the user-store subscription (which updates `tos_accepted`) alive.
    _user_store_subscription: Subscription,
}
179
180impl Zeta {
181 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
182 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
183 }
184
185 pub fn register(
186 client: Arc<Client>,
187 user_store: Entity<UserStore>,
188 cx: &mut App,
189 ) -> Entity<Self> {
190 Self::global(cx).unwrap_or_else(|| {
191 let model = cx.new(|cx| Self::new(client, user_store, cx));
192 cx.set_global(ZetaGlobal(model.clone()));
193 model
194 })
195 }
196
    /// Discards all recorded buffer change events.
    pub fn clear_history(&mut self) {
        self.events.clear();
    }
200
201 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
202 let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
203 Self {
204 client,
205 events: VecDeque::new(),
206 shown_completions: VecDeque::new(),
207 rated_completions: HashSet::default(),
208 registered_buffers: HashMap::default(),
209 data_collection_preferences: Self::load_data_collection_preferences(cx),
210 llm_token: LlmApiToken::default(),
211 _llm_token_subscription: cx.subscribe(
212 &refresh_llm_token_listener,
213 |this, _listener, _event, cx| {
214 let client = this.client.clone();
215 let llm_token = this.llm_token.clone();
216 cx.spawn(|_this, _cx| async move {
217 llm_token.refresh(&client).await?;
218 anyhow::Ok(())
219 })
220 .detach_and_log_err(cx);
221 },
222 ),
223 tos_accepted: user_store
224 .read(cx)
225 .current_user_has_accepted_terms()
226 .unwrap_or(false),
227 _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
228 match event {
229 client::user::Event::PrivateUserInfoUpdated => {
230 this.tos_accepted = user_store
231 .read(cx)
232 .current_user_has_accepted_terms()
233 .unwrap_or(false);
234 }
235 _ => {}
236 }
237 }),
238 }
239 }
240
241 fn push_event(&mut self, event: Event) {
242 const MAX_EVENT_COUNT: usize = 16;
243
244 if let Some(Event::BufferChange {
245 new_snapshot: last_new_snapshot,
246 timestamp: last_timestamp,
247 ..
248 }) = self.events.back_mut()
249 {
250 // Coalesce edits for the same buffer when they happen one after the other.
251 let Event::BufferChange {
252 old_snapshot,
253 new_snapshot,
254 timestamp,
255 } = &event;
256
257 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
258 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
259 && old_snapshot.version == last_new_snapshot.version
260 {
261 *last_new_snapshot = new_snapshot.clone();
262 *last_timestamp = *timestamp;
263 return;
264 }
265 }
266
267 self.events.push_back(event);
268 if self.events.len() >= MAX_EVENT_COUNT {
269 self.events.drain(..MAX_EVENT_COUNT / 2);
270 }
271 }
272
273 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
274 let buffer_id = buffer.entity_id();
275 let weak_buffer = buffer.downgrade();
276
277 if let std::collections::hash_map::Entry::Vacant(entry) =
278 self.registered_buffers.entry(buffer_id)
279 {
280 let snapshot = buffer.read(cx).snapshot();
281
282 entry.insert(RegisteredBuffer {
283 snapshot,
284 _subscriptions: [
285 cx.subscribe(buffer, move |this, buffer, event, cx| {
286 this.handle_buffer_event(buffer, event, cx);
287 }),
288 cx.observe_release(buffer, move |this, _buffer, _cx| {
289 this.registered_buffers.remove(&weak_buffer.entity_id());
290 }),
291 ],
292 });
293 };
294 }
295
296 fn handle_buffer_event(
297 &mut self,
298 buffer: Entity<Buffer>,
299 event: &language::BufferEvent,
300 cx: &mut Context<Self>,
301 ) {
302 if let language::BufferEvent::Edited = event {
303 self.report_changes_for_buffer(&buffer, cx);
304 }
305 }
306
307 pub fn request_completion_impl<F, R>(
308 &mut self,
309 buffer: &Entity<Buffer>,
310 cursor: language::Anchor,
311 can_collect_data: bool,
312 cx: &mut Context<Self>,
313 perform_predict_edits: F,
314 ) -> Task<Result<Option<InlineCompletion>>>
315 where
316 F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
317 R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
318 {
319 let snapshot = self.report_changes_for_buffer(buffer, cx);
320 let point = cursor.to_point(&snapshot);
321 let offset = point.to_offset(&snapshot);
322 let excerpt_range = excerpt_range_for_position(point, &snapshot);
323 let events = self.events.clone();
324 let path = snapshot
325 .file()
326 .map(|f| Arc::from(f.full_path(cx).as_path()))
327 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
328
329 let client = self.client.clone();
330 let llm_token = self.llm_token.clone();
331
332 cx.spawn(|_, cx| async move {
333 let request_sent_at = Instant::now();
334
335 let (input_events, input_excerpt, input_outline) = cx
336 .background_executor()
337 .spawn({
338 let snapshot = snapshot.clone();
339 let excerpt_range = excerpt_range.clone();
340 async move {
341 let mut input_events = String::new();
342 for event in events {
343 if !input_events.is_empty() {
344 input_events.push('\n');
345 input_events.push('\n');
346 }
347 input_events.push_str(&event.to_prompt());
348 }
349
350 let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
351 let input_outline = prompt_for_outline(&snapshot);
352
353 (input_events, input_excerpt, input_outline)
354 }
355 })
356 .await;
357
358 log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
359
360 let body = PredictEditsParams {
361 input_events: input_events.clone(),
362 input_excerpt: input_excerpt.clone(),
363 outline: Some(input_outline.clone()),
364 can_collect_data,
365 };
366
367 let response = perform_predict_edits(client, llm_token, body).await?;
368
369 let output_excerpt = response.output_excerpt;
370 log::debug!("completion response: {}", output_excerpt);
371
372 Self::process_completion_response(
373 output_excerpt,
374 &snapshot,
375 excerpt_range,
376 offset,
377 path,
378 input_outline,
379 input_events,
380 input_excerpt,
381 request_sent_at,
382 &cx,
383 )
384 .await
385 })
386 }
387
388 // Generates several example completions of various states to fill the Zeta completion modal
389 #[cfg(any(test, feature = "test-support"))]
390 pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
391 let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
392 And maybe a short line
393
394 Then a few lines
395
396 and then another
397 "#};
398
399 let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
400 let position = buffer.read(cx).anchor_before(Point::new(1, 0));
401
402 let completion_tasks = vec![
403 self.fake_completion(
404 &buffer,
405 position,
406 PredictEditsResponse {
407 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
408a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
409[here's an edit]
410And maybe a short line
411Then a few lines
412and then another
413{EDITABLE_REGION_END_MARKER}
414 ", ),
415 },
416 cx,
417 ),
418 self.fake_completion(
419 &buffer,
420 position,
421 PredictEditsResponse {
422 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
423a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
424And maybe a short line
425[and another edit]
426Then a few lines
427and then another
428{EDITABLE_REGION_END_MARKER}
429 "#),
430 },
431 cx,
432 ),
433 self.fake_completion(
434 &buffer,
435 position,
436 PredictEditsResponse {
437 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
438a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
439And maybe a short line
440
441Then a few lines
442
443and then another
444{EDITABLE_REGION_END_MARKER}
445 "#),
446 },
447 cx,
448 ),
449 self.fake_completion(
450 &buffer,
451 position,
452 PredictEditsResponse {
453 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
454a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
455And maybe a short line
456
457Then a few lines
458
459and then another
460{EDITABLE_REGION_END_MARKER}
461 "#),
462 },
463 cx,
464 ),
465 self.fake_completion(
466 &buffer,
467 position,
468 PredictEditsResponse {
469 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
470a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
471And maybe a short line
472Then a few lines
473[a third completion]
474and then another
475{EDITABLE_REGION_END_MARKER}
476 "#),
477 },
478 cx,
479 ),
480 self.fake_completion(
481 &buffer,
482 position,
483 PredictEditsResponse {
484 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
485a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
486And maybe a short line
487and then another
488[fourth completion example]
489{EDITABLE_REGION_END_MARKER}
490 "#),
491 },
492 cx,
493 ),
494 self.fake_completion(
495 &buffer,
496 position,
497 PredictEditsResponse {
498 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
499a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
500And maybe a short line
501Then a few lines
502and then another
503[fifth and final completion]
504{EDITABLE_REGION_END_MARKER}
505 "#),
506 },
507 cx,
508 ),
509 ];
510
511 cx.spawn(|zeta, mut cx| async move {
512 for task in completion_tasks {
513 task.await.unwrap();
514 }
515
516 zeta.update(&mut cx, |zeta, _cx| {
517 zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
518 zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
519 })
520 .ok();
521 })
522 }
523
524 #[cfg(any(test, feature = "test-support"))]
525 pub fn fake_completion(
526 &mut self,
527 buffer: &Entity<Buffer>,
528 position: language::Anchor,
529 response: PredictEditsResponse,
530 cx: &mut Context<Self>,
531 ) -> Task<Result<Option<InlineCompletion>>> {
532 use std::future::ready;
533
534 self.request_completion_impl(buffer, position, false, cx, |_, _, _| ready(Ok(response)))
535 }
536
537 pub fn request_completion(
538 &mut self,
539 buffer: &Entity<Buffer>,
540 position: language::Anchor,
541 can_collect_data: bool,
542 cx: &mut Context<Self>,
543 ) -> Task<Result<Option<InlineCompletion>>> {
544 self.request_completion_impl(
545 buffer,
546 position,
547 can_collect_data,
548 cx,
549 Self::perform_predict_edits,
550 )
551 }
552
553 fn perform_predict_edits(
554 client: Arc<Client>,
555 llm_token: LlmApiToken,
556 body: PredictEditsParams,
557 ) -> impl Future<Output = Result<PredictEditsResponse>> {
558 async move {
559 let http_client = client.http_client();
560 let mut token = llm_token.acquire(&client).await?;
561 let mut did_retry = false;
562
563 loop {
564 let request_builder = http_client::Request::builder();
565 let request = request_builder
566 .method(Method::POST)
567 .uri(
568 http_client
569 .build_zed_llm_url("/predict_edits", &[])?
570 .as_ref(),
571 )
572 .header("Content-Type", "application/json")
573 .header("Authorization", format!("Bearer {}", token))
574 .body(serde_json::to_string(&body)?.into())?;
575
576 let mut response = http_client.send(request).await?;
577
578 if response.status().is_success() {
579 let mut body = String::new();
580 response.body_mut().read_to_string(&mut body).await?;
581 return Ok(serde_json::from_str(&body)?);
582 } else if !did_retry
583 && response
584 .headers()
585 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
586 .is_some()
587 {
588 did_retry = true;
589 token = llm_token.refresh(&client).await?;
590 } else {
591 let mut body = String::new();
592 response.body_mut().read_to_string(&mut body).await?;
593 return Err(anyhow!(
594 "error predicting edits.\nStatus: {:?}\nBody: {}",
595 response.status(),
596 body
597 ));
598 }
599 }
600 }
601 }
602
603 #[allow(clippy::too_many_arguments)]
604 fn process_completion_response(
605 output_excerpt: String,
606 snapshot: &BufferSnapshot,
607 excerpt_range: Range<usize>,
608 cursor_offset: usize,
609 path: Arc<Path>,
610 input_outline: String,
611 input_events: String,
612 input_excerpt: String,
613 request_sent_at: Instant,
614 cx: &AsyncApp,
615 ) -> Task<Result<Option<InlineCompletion>>> {
616 let snapshot = snapshot.clone();
617 cx.background_executor().spawn(async move {
618 let content = output_excerpt.replace(CURSOR_MARKER, "");
619
620 let start_markers = content
621 .match_indices(EDITABLE_REGION_START_MARKER)
622 .collect::<Vec<_>>();
623 anyhow::ensure!(
624 start_markers.len() == 1,
625 "expected exactly one start marker, found {}",
626 start_markers.len()
627 );
628
629 let codefence_start = start_markers[0].0;
630 let content = &content[codefence_start..];
631
632 let newline_ix = content.find('\n').context("could not find newline")?;
633 let content = &content[newline_ix + 1..];
634
635 let codefence_end = content
636 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
637 .context("could not find end marker")?;
638 let new_text = &content[..codefence_end];
639
640 let old_text = snapshot
641 .text_for_range(excerpt_range.clone())
642 .collect::<String>();
643
644 let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
645
646 Ok(Some(InlineCompletion {
647 id: InlineCompletionId::new(),
648 path,
649 excerpt_range,
650 cursor_offset,
651 edits: edits.into(),
652 snapshot: snapshot.clone(),
653 input_outline: input_outline.into(),
654 input_events: input_events.into(),
655 input_excerpt: input_excerpt.into(),
656 output_excerpt: output_excerpt.into(),
657 request_sent_at,
658 response_received_at: Instant::now(),
659 }))
660 })
661 }
662
663 pub fn compute_edits(
664 old_text: String,
665 new_text: &str,
666 offset: usize,
667 snapshot: &BufferSnapshot,
668 ) -> Vec<(Range<Anchor>, String)> {
669 let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
670
671 let mut edits: Vec<(Range<usize>, String)> = Vec::new();
672 let mut old_start = offset;
673 for change in diff.iter_all_changes() {
674 let value = change.value();
675 match change.tag() {
676 similar::ChangeTag::Equal => {
677 old_start += value.len();
678 }
679 similar::ChangeTag::Delete => {
680 let old_end = old_start + value.len();
681 if let Some((last_old_range, _)) = edits.last_mut() {
682 if last_old_range.end == old_start {
683 last_old_range.end = old_end;
684 } else {
685 edits.push((old_start..old_end, String::new()));
686 }
687 } else {
688 edits.push((old_start..old_end, String::new()));
689 }
690 old_start = old_end;
691 }
692 similar::ChangeTag::Insert => {
693 if let Some((last_old_range, last_new_text)) = edits.last_mut() {
694 if last_old_range.end == old_start {
695 last_new_text.push_str(value);
696 } else {
697 edits.push((old_start..old_start, value.into()));
698 }
699 } else {
700 edits.push((old_start..old_start, value.into()));
701 }
702 }
703 }
704 }
705
706 edits
707 .into_iter()
708 .map(|(mut old_range, new_text)| {
709 let prefix_len = common_prefix(
710 snapshot.chars_for_range(old_range.clone()),
711 new_text.chars(),
712 );
713 old_range.start += prefix_len;
714 let suffix_len = common_prefix(
715 snapshot.reversed_chars_for_range(old_range.clone()),
716 new_text[prefix_len..].chars().rev(),
717 );
718 old_range.end = old_range.end.saturating_sub(suffix_len);
719
720 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
721 (
722 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
723 new_text,
724 )
725 })
726 .collect()
727 }
728
    /// Returns `true` if the completion with `completion_id` has been rated.
    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
732
733 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
734 self.shown_completions.push_front(completion.clone());
735 if self.shown_completions.len() > 50 {
736 let completion = self.shown_completions.pop_back().unwrap();
737 self.rated_completions.remove(&completion.id);
738 }
739 cx.notify();
740 }
741
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// The telemetry event carries the rating, the prompts that were sent, the
    /// model output and the user's `feedback`. Telemetry events are flushed
    /// right away and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Inline Completion Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush immediately rather than waiting for the next scheduled flush.
        self.client.telemetry().flush_events();
        cx.notify();
    }
762
    /// Iterates over the shown completions, most recently shown first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
        self.shown_completions.iter()
    }
766
    /// Number of completions currently recorded as shown.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
770
771 fn report_changes_for_buffer(
772 &mut self,
773 buffer: &Entity<Buffer>,
774 cx: &mut Context<Self>,
775 ) -> BufferSnapshot {
776 self.register_buffer(buffer, cx);
777
778 let registered_buffer = self
779 .registered_buffers
780 .get_mut(&buffer.entity_id())
781 .unwrap();
782 let new_snapshot = buffer.read(cx).snapshot();
783
784 if new_snapshot.version != registered_buffer.snapshot.version {
785 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
786 self.push_event(Event::BufferChange {
787 old_snapshot,
788 new_snapshot: new_snapshot.clone(),
789 timestamp: Instant::now(),
790 });
791 }
792
793 new_snapshot
794 }
795
796 /// Creates a `Entity<DataCollectionChoice>` for each unique worktree abs path it sees.
797 pub fn data_collection_choice_at(
798 &mut self,
799 worktree_abs_path: PathBuf,
800 cx: &mut Context<Self>,
801 ) -> Entity<DataCollectionChoice> {
802 match self
803 .data_collection_preferences
804 .per_worktree
805 .entry(worktree_abs_path)
806 {
807 Entry::Vacant(entry) => {
808 let choice = cx.new(|_| DataCollectionChoice::NotAnswered);
809 entry.insert(choice.clone());
810 choice
811 }
812 Entry::Occupied(entry) => entry.get().clone(),
813 }
814 }
815
816 fn set_never_ask_again_for_data_collection(&mut self, cx: &mut Context<Self>) {
817 self.data_collection_preferences.never_ask_again = true;
818
819 // persist choice
820 db::write_and_log(cx, move || {
821 KEY_VALUE_STORE.write_kvp(
822 ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into(),
823 "true".to_string(),
824 )
825 });
826 }
827
828 fn load_data_collection_preferences(cx: &mut Context<Self>) -> DataCollectionPreferences {
829 if env::var("ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES").is_ok() {
830 db::write_and_log(cx, move || async move {
831 KEY_VALUE_STORE
832 .delete_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into())
833 .await
834 .log_err();
835
836 persistence::DB.clear_all_zeta_preferences().await
837 });
838 return DataCollectionPreferences::default();
839 }
840
841 let never_ask_again = KEY_VALUE_STORE
842 .read_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY)
843 .log_err()
844 .flatten()
845 .map(|value| value == "true")
846 .unwrap_or(false);
847
848 let preferences_per_worktree = persistence::DB
849 .get_all_data_collection_preferences()
850 .log_err()
851 .into_iter()
852 .flatten()
853 .map(|(path, choice)| {
854 let choice = cx.new(|_| DataCollectionChoice::from(choice));
855 (path, choice)
856 })
857 .collect();
858
859 DataCollectionPreferences {
860 never_ask_again,
861 per_worktree: preferences_per_worktree,
862 }
863 }
864}
865
/// In-memory view of the user's data collection preferences.
#[derive(Default, Debug)]
struct DataCollectionPreferences {
    /// Set when a user clicks on "Never Ask Again", can never be unset.
    never_ask_again: bool,
    /// The choices for each worktree, keyed by the worktree's absolute path.
    ///
    /// This is filled when loading from database, or when querying if no matching path is found.
    per_worktree: HashMap<PathBuf, Entity<DataCollectionChoice>>,
}
875
/// Returns the length in UTF-8 bytes of the longest run of equal characters at
/// the start of both iterators.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
882
883fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
884 let mut input_outline = String::new();
885
886 writeln!(
887 input_outline,
888 "```{}",
889 snapshot
890 .file()
891 .map_or(Cow::Borrowed("untitled"), |file| file
892 .path()
893 .to_string_lossy())
894 )
895 .unwrap();
896
897 if let Some(outline) = snapshot.outline(None) {
898 let guess_size = outline.items.len() * 15;
899 input_outline.reserve(guess_size);
900 for item in outline.items.iter() {
901 let spacing = " ".repeat(item.depth);
902 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
903 }
904 }
905
906 writeln!(input_outline, "```").unwrap();
907
908 input_outline
909}
910
911fn prompt_for_excerpt(
912 snapshot: &BufferSnapshot,
913 excerpt_range: &Range<usize>,
914 offset: usize,
915) -> String {
916 let mut prompt_excerpt = String::new();
917 writeln!(
918 prompt_excerpt,
919 "```{}",
920 snapshot
921 .file()
922 .map_or(Cow::Borrowed("untitled"), |file| file
923 .path()
924 .to_string_lossy())
925 )
926 .unwrap();
927
928 if excerpt_range.start == 0 {
929 writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
930 }
931
932 let point_range = excerpt_range.to_point(snapshot);
933 if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
934 let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
935 for chunk in snapshot.text_for_range(extra_context_line_range) {
936 prompt_excerpt.push_str(chunk);
937 }
938 }
939 writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
940 for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
941 prompt_excerpt.push_str(chunk);
942 }
943 prompt_excerpt.push_str(CURSOR_MARKER);
944 for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
945 prompt_excerpt.push_str(chunk);
946 }
947 write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
948
949 if point_range.end.row < snapshot.max_point().row
950 && !snapshot.is_line_blank(point_range.end.row + 1)
951 {
952 let extra_context_line_range = point_range.end
953 ..Point::new(
954 point_range.end.row + 1,
955 snapshot.line_len(point_range.end.row + 1),
956 );
957 for chunk in snapshot.text_for_range(extra_context_line_range) {
958 prompt_excerpt.push_str(chunk);
959 }
960 }
961
962 write!(prompt_excerpt, "\n```").unwrap();
963 prompt_excerpt
964}
965
966fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
967 const CONTEXT_LINES: u32 = 32;
968
969 let mut context_lines_before = CONTEXT_LINES;
970 let mut context_lines_after = CONTEXT_LINES;
971 if point.row < CONTEXT_LINES {
972 context_lines_after += CONTEXT_LINES - point.row;
973 } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
974 context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
975 }
976
977 let excerpt_start_row = point.row.saturating_sub(context_lines_before);
978 let excerpt_start = Point::new(excerpt_start_row, 0);
979 let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
980 let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
981 excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
982}
983
/// Tracking state for a buffer registered with `Zeta`.
struct RegisteredBuffer {
    /// Last snapshot seen for the buffer; compared against the current one to detect changes.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's events and to its release, held to keep them alive.
    _subscriptions: [gpui::Subscription; 2],
}
988
/// Something the user did that is described to the model as part of the prompt.
#[derive(Clone)]
enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded; used to decide whether to merge consecutive changes.
        timestamp: Instant,
    },
}
997
998impl Event {
999 fn to_prompt(&self) -> String {
1000 match self {
1001 Event::BufferChange {
1002 old_snapshot,
1003 new_snapshot,
1004 ..
1005 } => {
1006 let mut prompt = String::new();
1007
1008 let old_path = old_snapshot
1009 .file()
1010 .map(|f| f.path().as_ref())
1011 .unwrap_or(Path::new("untitled"));
1012 let new_path = new_snapshot
1013 .file()
1014 .map(|f| f.path().as_ref())
1015 .unwrap_or(Path::new("untitled"));
1016 if old_path != new_path {
1017 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1018 }
1019
1020 let diff =
1021 similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
1022 .unified_diff()
1023 .to_string();
1024 if !diff.is_empty() {
1025 write!(
1026 prompt,
1027 "User edited {:?}:\n```diff\n{}\n```",
1028 new_path, diff
1029 )
1030 .unwrap();
1031 }
1032
1033 prompt
1034 }
1035 }
1036 }
1037}
1038
1039#[derive(Debug, Clone)]
1040struct CurrentInlineCompletion {
1041 buffer_id: EntityId,
1042 completion: InlineCompletion,
1043}
1044
1045impl CurrentInlineCompletion {
1046 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1047 if self.buffer_id != old_completion.buffer_id {
1048 return true;
1049 }
1050
1051 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1052 return true;
1053 };
1054 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1055 return false;
1056 };
1057
1058 if old_edits.len() == 1 && new_edits.len() == 1 {
1059 let (old_range, old_text) = &old_edits[0];
1060 let (new_range, new_text) = &new_edits[0];
1061 new_range == old_range && new_text.starts_with(old_text)
1062 } else {
1063 true
1064 }
1065 }
1066}
1067
/// An in-flight completion request.
struct PendingCompletion {
    /// Identifier of the pending request.
    id: usize,
    /// Task doing the work; held here so the task stays owned by this struct.
    _task: Task<()>,
}
1072
/// The user's answer to the data collection question.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been given yet.
    NotAnswered,
    /// Data collection is allowed.
    Enabled,
    /// Data collection is not allowed.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for `Enabled`.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has picked either `Enabled` or `Disabled`.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Flips the choice: `Enabled` becomes `Disabled`, anything else becomes `Enabled`.
    pub fn toggle(self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1112
/// Inline completion provider backed by a shared `Zeta` entity.
pub struct ZetaInlineCompletionProvider {
    /// Shared prediction state.
    zeta: Entity<Zeta>,
    /// In-flight completion requests; the container holds at most two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Counter for pending-completion ids; starts at 0.
    next_pending_completion_id: usize,
    /// The current completion, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data collection state for this provider, when one was supplied at construction.
    data_collection: Option<ProviderDataCollection>,
}
1120
/// Data collection state tied to one worktree.
pub struct ProviderDataCollection {
    /// Weak handle to the workspace the state was created from.
    workspace: WeakEntity<Workspace>,
    /// Absolute path of the worktree; the key under which the choice is persisted.
    worktree_root_path: PathBuf,
    /// The data collection choice entity for this worktree.
    choice: Entity<DataCollectionChoice>,
}
1126
1127impl ProviderDataCollection {
1128 pub fn new(
1129 zeta: Entity<Zeta>,
1130 workspace: Option<Entity<Workspace>>,
1131 buffer: Option<Entity<Buffer>>,
1132 cx: &mut App,
1133 ) -> Option<ProviderDataCollection> {
1134 let workspace = workspace?;
1135
1136 let worktree_root_path = buffer?.update(cx, |buffer, cx| {
1137 let file = buffer.file()?;
1138
1139 if !file.is_local() || file.is_private() {
1140 return None;
1141 }
1142
1143 workspace.update(cx, |workspace, cx| {
1144 Some(
1145 workspace
1146 .absolute_path_of_worktree(file.worktree_id(cx), cx)?
1147 .to_path_buf(),
1148 )
1149 })
1150 })?;
1151
1152 let choice = zeta.update(cx, |zeta, cx| {
1153 zeta.data_collection_choice_at(worktree_root_path.clone(), cx)
1154 });
1155
1156 Some(ProviderDataCollection {
1157 workspace: workspace.downgrade(),
1158 worktree_root_path,
1159 choice,
1160 })
1161 }
1162
1163 fn set_choice(&mut self, choice: DataCollectionChoice, cx: &mut App) {
1164 self.choice.update(cx, |this, _| *this = choice);
1165
1166 let worktree_root_path = self.worktree_root_path.clone();
1167
1168 db::write_and_log(cx, move || {
1169 persistence::DB.save_data_collection_choice(worktree_root_path, choice.is_enabled())
1170 });
1171 }
1172
1173 fn toggle_choice(&mut self, cx: &mut App) {
1174 self.set_choice(self.choice.read(cx).toggle(), cx);
1175 }
1176}
1177
1178impl ZetaInlineCompletionProvider {
1179 pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
1180
1181 pub fn new(zeta: Entity<Zeta>, data_collection: Option<ProviderDataCollection>) -> Self {
1182 Self {
1183 zeta,
1184 pending_completions: ArrayVec::new(),
1185 next_pending_completion_id: 0,
1186 current_completion: None,
1187 data_collection,
1188 }
1189 }
1190
1191 fn set_data_collection_choice(&mut self, choice: DataCollectionChoice, cx: &mut App) {
1192 if let Some(data_collection) = self.data_collection.as_mut() {
1193 data_collection.set_choice(choice, cx);
1194 }
1195 }
1196}
1197
1198impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
    /// Identifier string of this provider.
    fn name() -> &'static str {
        "zed-predict"
    }
1202
    /// Human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }
1206
    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
1210
    /// Always `true` for this provider.
    fn show_completions_in_normal_mode() -> bool {
        true
    }
1214
    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
1218
1219 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1220 let Some(data_collection) = self.data_collection.as_ref() else {
1221 return DataCollectionState::Unknown;
1222 };
1223
1224 if data_collection.choice.read(cx).is_enabled() {
1225 DataCollectionState::Enabled
1226 } else {
1227 DataCollectionState::Disabled
1228 }
1229 }
1230
1231 fn toggle_data_collection(&mut self, cx: &mut App) {
1232 if let Some(data_collection) = self.data_collection.as_mut() {
1233 data_collection.toggle_choice(cx);
1234 }
1235 }
1236
1237 fn is_enabled(
1238 &self,
1239 buffer: &Entity<Buffer>,
1240 cursor_position: language::Anchor,
1241 cx: &App,
1242 ) -> bool {
1243 let buffer = buffer.read(cx);
1244 let file = buffer.file();
1245 let language = buffer.language_at(cursor_position);
1246 let settings = all_language_settings(file, cx);
1247 settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1248 }
1249
1250 fn needs_terms_acceptance(&self, cx: &App) -> bool {
1251 !self.zeta.read(cx).tos_accepted
1252 }
1253
1254 fn is_refreshing(&self) -> bool {
1255 !self.pending_completions.is_empty()
1256 }
1257
1258 fn refresh(
1259 &mut self,
1260 buffer: Entity<Buffer>,
1261 position: language::Anchor,
1262 debounce: bool,
1263 cx: &mut Context<Self>,
1264 ) {
1265 if !self.zeta.read(cx).tos_accepted {
1266 return;
1267 }
1268
1269 let pending_completion_id = self.next_pending_completion_id;
1270 self.next_pending_completion_id += 1;
1271 let can_collect_data = self
1272 .data_collection
1273 .as_ref()
1274 .map_or(false, |data_collection| {
1275 data_collection.choice.read(cx).is_enabled()
1276 });
1277
1278 let task = cx.spawn(|this, mut cx| async move {
1279 if debounce {
1280 cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1281 }
1282
1283 let completion_request = this.update(&mut cx, |this, cx| {
1284 this.zeta.update(cx, |zeta, cx| {
1285 zeta.request_completion(&buffer, position, can_collect_data, cx)
1286 })
1287 });
1288
1289 let completion = match completion_request {
1290 Ok(completion_request) => {
1291 let completion_request = completion_request.await;
1292 completion_request.map(|c| {
1293 c.map(|completion| CurrentInlineCompletion {
1294 buffer_id: buffer.entity_id(),
1295 completion,
1296 })
1297 })
1298 }
1299 Err(error) => Err(error),
1300 };
1301 let Some(new_completion) = completion
1302 .context("edit prediction failed")
1303 .log_err()
1304 .flatten()
1305 else {
1306 return;
1307 };
1308
1309 this.update(&mut cx, |this, cx| {
1310 if this.pending_completions[0].id == pending_completion_id {
1311 this.pending_completions.remove(0);
1312 } else {
1313 this.pending_completions.clear();
1314 }
1315
1316 if let Some(old_completion) = this.current_completion.as_ref() {
1317 let snapshot = buffer.read(cx).snapshot();
1318 if new_completion.should_replace_completion(&old_completion, &snapshot) {
1319 this.zeta.update(cx, |zeta, cx| {
1320 zeta.completion_shown(&new_completion.completion, cx);
1321 });
1322 this.current_completion = Some(new_completion);
1323 }
1324 } else {
1325 this.zeta.update(cx, |zeta, cx| {
1326 zeta.completion_shown(&new_completion.completion, cx);
1327 });
1328 this.current_completion = Some(new_completion);
1329 }
1330
1331 cx.notify();
1332 })
1333 .ok();
1334 });
1335
1336 // We always maintain at most two pending completions. When we already
1337 // have two, we replace the newest one.
1338 if self.pending_completions.len() <= 1 {
1339 self.pending_completions.push(PendingCompletion {
1340 id: pending_completion_id,
1341 _task: task,
1342 });
1343 } else if self.pending_completions.len() == 2 {
1344 self.pending_completions.pop();
1345 self.pending_completions.push(PendingCompletion {
1346 id: pending_completion_id,
1347 _task: task,
1348 });
1349 }
1350 }
1351
    /// Intentionally a no-op: every argument is ignored because this provider
    /// does not support cycling between alternatives.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: inline_completion::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }
1361
1362 fn accept(&mut self, cx: &mut Context<Self>) {
1363 self.pending_completions.clear();
1364
1365 let Some(data_collection) = self.data_collection.as_mut() else {
1366 return;
1367 };
1368
1369 if data_collection.choice.read(cx).is_answered()
1370 || self
1371 .zeta
1372 .read(cx)
1373 .data_collection_preferences
1374 .never_ask_again
1375 {
1376 return;
1377 }
1378
1379 struct ZetaDataCollectionNotification;
1380 let notification_id = NotificationId::unique::<ZetaDataCollectionNotification>();
1381
1382 const DATA_COLLECTION_INFO_URL: &str = "https://zed.dev/terms-of-service"; // TODO: Replace for a link that's dedicated to Edit Predictions data collection
1383
1384 let this = cx.entity();
1385 data_collection
1386 .workspace
1387 .update(cx, |workspace, cx| {
1388 workspace.show_notification(notification_id, cx, |cx| {
1389 let zeta = self.zeta.clone();
1390
1391 cx.new(move |_cx| {
1392 let message =
1393 "To allow Zed to suggest better edits, turn on data collection. You \
1394 can turn off at any time via the status bar menu.";
1395 MessageNotification::new(message)
1396 .with_title("Per-Project Data Collection Program")
1397 .show_close_button(false)
1398 .with_click_message("Turn On")
1399 .on_click({
1400 let this = this.clone();
1401 move |_window, cx| {
1402 this.update(cx, |this, cx| {
1403 this.set_data_collection_choice(
1404 DataCollectionChoice::Enabled,
1405 cx,
1406 )
1407 });
1408 }
1409 })
1410 .with_secondary_click_message("Turn Off")
1411 .on_secondary_click({
1412 move |_window, cx| {
1413 this.update(cx, |this, cx| {
1414 this.set_data_collection_choice(
1415 DataCollectionChoice::Disabled,
1416 cx,
1417 )
1418 });
1419 }
1420 })
1421 .with_tertiary_click_message("Never Ask Again")
1422 .on_tertiary_click({
1423 move |_window, cx| {
1424 zeta.update(cx, |zeta, cx| {
1425 zeta.set_never_ask_again_for_data_collection(cx);
1426 });
1427 }
1428 })
1429 .more_info_message("Learn More")
1430 .more_info_url(DATA_COLLECTION_INFO_URL)
1431 })
1432 });
1433 })
1434 .log_err();
1435 }
1436
1437 fn discard(&mut self, _cx: &mut Context<Self>) {
1438 self.pending_completions.clear();
1439 self.current_completion.take();
1440 }
1441
1442 fn suggest(
1443 &mut self,
1444 buffer: &Entity<Buffer>,
1445 cursor_position: language::Anchor,
1446 cx: &mut Context<Self>,
1447 ) -> Option<inline_completion::InlineCompletion> {
1448 let CurrentInlineCompletion {
1449 buffer_id,
1450 completion,
1451 ..
1452 } = self.current_completion.as_mut()?;
1453
1454 // Invalidate previous completion if it was generated for a different buffer.
1455 if *buffer_id != buffer.entity_id() {
1456 self.current_completion.take();
1457 return None;
1458 }
1459
1460 let buffer = buffer.read(cx);
1461 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1462 self.current_completion.take();
1463 return None;
1464 };
1465
1466 let cursor_row = cursor_position.to_point(buffer).row;
1467 let (closest_edit_ix, (closest_edit_range, _)) =
1468 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1469 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1470 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1471 cmp::min(distance_from_start, distance_from_end)
1472 })?;
1473
1474 let mut edit_start_ix = closest_edit_ix;
1475 for (range, _) in edits[..edit_start_ix].iter().rev() {
1476 let distance_from_closest_edit =
1477 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1478 if distance_from_closest_edit <= 1 {
1479 edit_start_ix -= 1;
1480 } else {
1481 break;
1482 }
1483 }
1484
1485 let mut edit_end_ix = closest_edit_ix + 1;
1486 for (range, _) in &edits[edit_end_ix..] {
1487 let distance_from_closest_edit =
1488 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1489 if distance_from_closest_edit <= 1 {
1490 edit_end_ix += 1;
1491 } else {
1492 break;
1493 }
1494 }
1495
1496 Some(inline_completion::InlineCompletion {
1497 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1498 })
1499 }
1500}
1501
1502#[cfg(test)]
1503mod tests {
1504 use client::test::FakeServer;
1505 use clock::FakeSystemClock;
1506 use gpui::TestAppContext;
1507 use http_client::FakeHttpClient;
1508 use indoc::indoc;
1509 use language_models::RefreshLlmTokenListener;
1510 use rpc::proto;
1511 use settings::SettingsStore;
1512
1513 use super::*;
1514
1515 #[gpui::test]
1516 fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
1517 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1518 let completion = InlineCompletion {
1519 edits: cx
1520 .read(|cx| {
1521 to_completion_edits(
1522 [(2..5, "REM".to_string()), (9..11, "".to_string())],
1523 &buffer,
1524 cx,
1525 )
1526 })
1527 .into(),
1528 path: Path::new("").into(),
1529 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1530 id: InlineCompletionId::new(),
1531 excerpt_range: 0..0,
1532 cursor_offset: 0,
1533 input_outline: "".into(),
1534 input_events: "".into(),
1535 input_excerpt: "".into(),
1536 output_excerpt: "".into(),
1537 request_sent_at: Instant::now(),
1538 response_received_at: Instant::now(),
1539 };
1540
1541 assert_eq!(
1542 cx.read(|cx| {
1543 from_completion_edits(
1544 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1545 &buffer,
1546 cx,
1547 )
1548 }),
1549 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1550 );
1551
1552 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1553 assert_eq!(
1554 cx.read(|cx| {
1555 from_completion_edits(
1556 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1557 &buffer,
1558 cx,
1559 )
1560 }),
1561 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1562 );
1563
1564 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1565 assert_eq!(
1566 cx.read(|cx| {
1567 from_completion_edits(
1568 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1569 &buffer,
1570 cx,
1571 )
1572 }),
1573 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1574 );
1575
1576 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1577 assert_eq!(
1578 cx.read(|cx| {
1579 from_completion_edits(
1580 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1581 &buffer,
1582 cx,
1583 )
1584 }),
1585 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1586 );
1587
1588 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1589 assert_eq!(
1590 cx.read(|cx| {
1591 from_completion_edits(
1592 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1593 &buffer,
1594 cx,
1595 )
1596 }),
1597 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1598 );
1599
1600 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1601 assert_eq!(
1602 cx.read(|cx| {
1603 from_completion_edits(
1604 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1605 &buffer,
1606 cx,
1607 )
1608 }),
1609 vec![(9..11, "".to_string())]
1610 );
1611
1612 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1613 assert_eq!(
1614 cx.read(|cx| {
1615 from_completion_edits(
1616 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1617 &buffer,
1618 cx,
1619 )
1620 }),
1621 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1622 );
1623
1624 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1625 assert_eq!(
1626 cx.read(|cx| {
1627 from_completion_edits(
1628 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1629 &buffer,
1630 cx,
1631 )
1632 }),
1633 vec![(4..4, "M".to_string())]
1634 );
1635
1636 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1637 assert_eq!(
1638 cx.read(|cx| completion.interpolate(&buffer.read(cx).snapshot())),
1639 None
1640 );
1641 }
1642
    /// Requests a completion with the cursor on the row after the buffer's
    /// only line and checks that applying the returned edits appends the
    /// predicted text.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        // Canned model output: the editable region holds the existing line
        // followed by one new line.
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // The fake HTTP client ignores the request and always answers with
        // the canned response.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Cursor at row 1, column 0, i.e. just after the trailing newline.
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(&buffer, cursor, false, cx)
        });

        // The fake server receives an LLM token request; answer it so the
        // completion task can finish.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        // Apply the predicted edits and compare the resulting buffer text.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }
1703
1704 fn to_completion_edits(
1705 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1706 buffer: &Entity<Buffer>,
1707 cx: &App,
1708 ) -> Vec<(Range<Anchor>, String)> {
1709 let buffer = buffer.read(cx);
1710 iterator
1711 .into_iter()
1712 .map(|(range, text)| {
1713 (
1714 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1715 text,
1716 )
1717 })
1718 .collect()
1719 }
1720
1721 fn from_completion_edits(
1722 editor_edits: &[(Range<Anchor>, String)],
1723 buffer: &Entity<Buffer>,
1724 cx: &App,
1725 ) -> Vec<(Range<usize>, String)> {
1726 let buffer = buffer.read(cx);
1727 editor_edits
1728 .iter()
1729 .map(|(range, text)| {
1730 (
1731 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1732 text.clone(),
1733 )
1734 })
1735 .collect()
1736 }
1737
1738 #[ctor::ctor]
1739 fn init_logger() {
1740 if std::env::var("RUST_LOG").is_ok() {
1741 env_logger::init();
1742 }
1743 }
1744}