1mod completion_diff_element;
2mod rate_completion_modal;
3
4pub(crate) use completion_diff_element::*;
5pub use rate_completion_modal::*;
6
7use anyhow::{anyhow, Context as _, Result};
8use arrayvec::ArrayVec;
9use client::{Client, UserStore};
10use collections::{HashMap, HashSet, VecDeque};
11use futures::AsyncReadExt;
12use gpui::{
13 actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
14};
15use http_client::{HttpClient, Method};
16use language::{
17 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
18 Point, ToOffset, ToPoint,
19};
20use language_models::LlmApiToken;
21use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
22use std::{
23 borrow::Cow,
24 cmp,
25 fmt::Write,
26 future::Future,
27 mem,
28 ops::Range,
29 path::Path,
30 sync::Arc,
31 time::{Duration, Instant},
32};
33use telemetry_events::InlineCompletionRating;
34use util::ResultExt;
35use uuid::Uuid;
36
/// Marker written into the prompt excerpt at the cursor offset and stripped
/// from the model's response before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker written at the top of the prompt excerpt when the excerpt starts at
/// offset 0 of the buffer.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker that opens the editable region in both the prompt and the response.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker that closes the editable region in both the prompt and the response.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Buffer changes that directly follow each other within this interval are
/// coalesced into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
42
// Declares the `ClearHistory` action under the `edit_prediction` namespace.
// NOTE(review): presumably wired to `Zeta::clear_history`; the handler is not
// registered in this part of the file — confirm.
43actions!(edit_prediction, [ClearHistory]);
44
/// Identifier of a single [`InlineCompletion`], wrapping a [`Uuid`].
45#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
46pub struct InlineCompletionId(Uuid);
47
48impl From<InlineCompletionId> for gpui::ElementId {
49 fn from(value: InlineCompletionId) -> Self {
50 gpui::ElementId::Uuid(value.0)
51 }
52}
53
54impl std::fmt::Display for InlineCompletionId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl InlineCompletionId {
61 fn new() -> Self {
62 Self(Uuid::new_v4())
63 }
64}
65
/// Wrapper stored as a gpui global so the single `Zeta` entity can be looked
/// up through `Zeta::global`.
66#[derive(Clone)]
67struct ZetaGlobal(Entity<Zeta>);
68
// Marks `ZetaGlobal` as a gpui global so it can be used with `set_global` / `try_global`.
69impl Global for ZetaGlobal {}
70
/// One edit prediction produced for a buffer, together with the inputs that
/// were sent to the model and the raw output that came back.
71#[derive(Clone)]
72pub struct InlineCompletion {
 /// Unique id of this completion.
73 id: InlineCompletionId,
 /// Full path of the buffer's file, or "untitled" when the buffer has no file.
74 path: Arc<Path>,
 /// Offset range of the buffer excerpt the model was allowed to edit.
75 excerpt_range: Range<usize>,
 /// Cursor offset at the time of the request.
76 cursor_offset: usize,
 /// Predicted edits, anchored in `snapshot`.
77 edits: Arc<[(Range<Anchor>, String)]>,
 /// Buffer snapshot the prediction was computed against.
78 snapshot: BufferSnapshot,
 /// Outline prompt sent with the request.
79 input_outline: Arc<str>,
 /// Event-history prompt sent with the request.
80 input_events: Arc<str>,
 /// Excerpt prompt sent with the request.
81 input_excerpt: Arc<str>,
 /// Raw excerpt returned by the model.
82 output_excerpt: Arc<str>,
 /// When the request task started.
83 request_sent_at: Instant,
 /// When the response had been processed into this value.
84 response_received_at: Instant,
85}
86
87impl InlineCompletion {
88 fn latency(&self) -> Duration {
89 self.response_received_at
90 .duration_since(self.request_sent_at)
91 }
92
93 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
94 let mut edits = Vec::new();
95
96 let mut user_edits = new_snapshot
97 .edits_since::<usize>(&self.snapshot.version)
98 .peekable();
99 for (model_old_range, model_new_text) in self.edits.iter() {
100 let model_offset_range = model_old_range.to_offset(&self.snapshot);
101 while let Some(next_user_edit) = user_edits.peek() {
102 if next_user_edit.old.end < model_offset_range.start {
103 user_edits.next();
104 } else {
105 break;
106 }
107 }
108
109 if let Some(user_edit) = user_edits.peek() {
110 if user_edit.old.start > model_offset_range.end {
111 edits.push((model_old_range.clone(), model_new_text.clone()));
112 } else if user_edit.old == model_offset_range {
113 let user_new_text = new_snapshot
114 .text_for_range(user_edit.new.clone())
115 .collect::<String>();
116
117 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
118 if !model_suffix.is_empty() {
119 edits.push((
120 new_snapshot.anchor_after(user_edit.new.end)
121 ..new_snapshot.anchor_before(user_edit.new.end),
122 model_suffix.into(),
123 ));
124 }
125
126 user_edits.next();
127 } else {
128 return None;
129 }
130 } else {
131 return None;
132 }
133 } else {
134 edits.push((model_old_range.clone(), model_new_text.clone()));
135 }
136 }
137
138 if edits.is_empty() {
139 None
140 } else {
141 Some(edits)
142 }
143 }
144}
145
146impl std::fmt::Debug for InlineCompletion {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("InlineCompletion")
149 .field("id", &self.id)
150 .field("path", &self.path)
151 .field("edits", &self.edits)
152 .finish_non_exhaustive()
153 }
154}
155
/// State behind edit predictions: the client used for requests, the recent
/// buffer-change events fed into prompts, the registered buffers, and the
/// completions that have been shown and rated.
156pub struct Zeta {
157 client: Arc<Client>,
 // Recent buffer changes, rendered into the events section of the prompt.
158 events: VecDeque<Event>,
 // Buffers whose edits are tracked, keyed by entity id.
159 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 // Completions passed to `completion_shown`, most recent first.
160 shown_completions: VecDeque<InlineCompletion>,
 // Ids of completions passed to `rate_completion`.
161 rated_completions: HashSet<InlineCompletionId>,
162 llm_token: LlmApiToken,
163 _llm_token_subscription: Subscription,
164 tos_accepted: bool, // Terms of service accepted
165 _user_store_subscription: Subscription,
166}
167
168impl Zeta {
169 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
170 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
171 }
172
173 pub fn register(
174 client: Arc<Client>,
175 user_store: Entity<UserStore>,
176 cx: &mut App,
177 ) -> Entity<Self> {
178 Self::global(cx).unwrap_or_else(|| {
179 let model = cx.new(|cx| Self::new(client, user_store, cx));
180 cx.set_global(ZetaGlobal(model.clone()));
181 model
182 })
183 }
184
 /// Discards all recorded buffer-change events.
185 pub fn clear_history(&mut self) {
186 self.events.clear();
187 }
188
189 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
190 let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
191
192 Self {
193 client,
194 events: VecDeque::new(),
195 shown_completions: VecDeque::new(),
196 rated_completions: HashSet::default(),
197 registered_buffers: HashMap::default(),
198 llm_token: LlmApiToken::default(),
199 _llm_token_subscription: cx.subscribe(
200 &refresh_llm_token_listener,
201 |this, _listener, _event, cx| {
202 let client = this.client.clone();
203 let llm_token = this.llm_token.clone();
204 cx.spawn(|_this, _cx| async move {
205 llm_token.refresh(&client).await?;
206 anyhow::Ok(())
207 })
208 .detach_and_log_err(cx);
209 },
210 ),
211 tos_accepted: user_store
212 .read(cx)
213 .current_user_has_accepted_terms()
214 .unwrap_or(false),
215 _user_store_subscription: cx.subscribe(&user_store, |this, _, event, _| match event {
216 client::user::Event::TermsStatusUpdated { accepted } => {
217 this.tos_accepted = *accepted;
218 }
219 _ => {}
220 }),
221 }
222 }
223
224 fn push_event(&mut self, event: Event) {
225 const MAX_EVENT_COUNT: usize = 16;
226
227 if let Some(Event::BufferChange {
228 new_snapshot: last_new_snapshot,
229 timestamp: last_timestamp,
230 ..
231 }) = self.events.back_mut()
232 {
233 // Coalesce edits for the same buffer when they happen one after the other.
234 let Event::BufferChange {
235 old_snapshot,
236 new_snapshot,
237 timestamp,
238 } = &event;
239
240 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
241 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
242 && old_snapshot.version == last_new_snapshot.version
243 {
244 *last_new_snapshot = new_snapshot.clone();
245 *last_timestamp = *timestamp;
246 return;
247 }
248 }
249
250 self.events.push_back(event);
251 if self.events.len() >= MAX_EVENT_COUNT {
252 self.events.drain(..MAX_EVENT_COUNT / 2);
253 }
254 }
255
256 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
257 let buffer_id = buffer.entity_id();
258 let weak_buffer = buffer.downgrade();
259
260 if let std::collections::hash_map::Entry::Vacant(entry) =
261 self.registered_buffers.entry(buffer_id)
262 {
263 let snapshot = buffer.read(cx).snapshot();
264
265 entry.insert(RegisteredBuffer {
266 snapshot,
267 _subscriptions: [
268 cx.subscribe(buffer, move |this, buffer, event, cx| {
269 this.handle_buffer_event(buffer, event, cx);
270 }),
271 cx.observe_release(buffer, move |this, _buffer, _cx| {
272 this.registered_buffers.remove(&weak_buffer.entity_id());
273 }),
274 ],
275 });
276 };
277 }
278
279 fn handle_buffer_event(
280 &mut self,
281 buffer: Entity<Buffer>,
282 event: &language::BufferEvent,
283 cx: &mut Context<Self>,
284 ) {
285 match event {
286 language::BufferEvent::Edited => {
287 self.report_changes_for_buffer(&buffer, cx);
288 }
289 _ => {}
290 }
291 }
292
293 pub fn request_completion_impl<F, R>(
294 &mut self,
295 buffer: &Entity<Buffer>,
296 position: language::Anchor,
297 cx: &mut Context<Self>,
298 perform_predict_edits: F,
299 ) -> Task<Result<Option<InlineCompletion>>>
300 where
301 F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
302 R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
303 {
304 let snapshot = self.report_changes_for_buffer(buffer, cx);
305 let point = position.to_point(&snapshot);
306 let offset = point.to_offset(&snapshot);
307 let excerpt_range = excerpt_range_for_position(point, &snapshot);
308 let events = self.events.clone();
309 let path = snapshot
310 .file()
311 .map(|f| Arc::from(f.full_path(cx).as_path()))
312 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
313
314 let client = self.client.clone();
315 let llm_token = self.llm_token.clone();
316
317 cx.spawn(|_, cx| async move {
318 let request_sent_at = Instant::now();
319
320 let (input_events, input_excerpt, input_outline) = cx
321 .background_executor()
322 .spawn({
323 let snapshot = snapshot.clone();
324 let excerpt_range = excerpt_range.clone();
325 async move {
326 let mut input_events = String::new();
327 for event in events {
328 if !input_events.is_empty() {
329 input_events.push('\n');
330 input_events.push('\n');
331 }
332 input_events.push_str(&event.to_prompt());
333 }
334
335 let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
336 let input_outline = prompt_for_outline(&snapshot);
337
338 (input_events, input_excerpt, input_outline)
339 }
340 })
341 .await;
342
343 log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
344
345 let body = PredictEditsParams {
346 input_events: input_events.clone(),
347 input_excerpt: input_excerpt.clone(),
348 outline: Some(input_outline.clone()),
349 };
350
351 let response = perform_predict_edits(client, llm_token, body).await?;
352
353 let output_excerpt = response.output_excerpt;
354 log::debug!("completion response: {}", output_excerpt);
355
356 Self::process_completion_response(
357 output_excerpt,
358 &snapshot,
359 excerpt_range,
360 offset,
361 path,
362 input_outline,
363 input_events,
364 input_excerpt,
365 request_sent_at,
366 &cx,
367 )
368 .await
369 })
370 }
371
/// Generates several example completions of various states to fill the Zeta
/// completion modal.
///
/// Each fake completion is recorded through `completion_shown`, because
/// `request_completion_impl` does not add anything to `shown_completions`
/// itself. After all of them are recorded, the entries at indices 2 and 3
/// (when present) get their edits cleared so the list also contains
/// completions without edits.
///
/// # Panics
///
/// Panics if one of the fake completion requests fails.
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
    let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
        And maybe a short line

        Then a few lines

        and then another
        "#};

    let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
    let position = buffer.read(cx).anchor_before(Point::new(1, 0));

    let output_excerpts = [
        format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
        format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
        format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
        format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
        format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
        format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
        format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
    ];

    // Requests are issued in order, exactly one per output excerpt.
    let completion_tasks = output_excerpts
        .into_iter()
        .map(|output_excerpt| {
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse { output_excerpt },
                cx,
            )
        })
        .collect::<Vec<_>>();

    cx.spawn(|zeta, mut cx| async move {
        for task in completion_tasks {
            // Without this the modal's list would stay empty.
            if let Some(completion) = task.await.unwrap() {
                zeta.update(&mut cx, |zeta, cx| zeta.completion_shown(&completion, cx))
                    .ok();
            }
        }

        zeta.update(&mut cx, |zeta, _cx| {
            // Checked access: fewer than four entries must not panic.
            for ix in [2, 3] {
                if let Some(completion) = zeta.shown_completions.get_mut(ix) {
                    completion.edits = Arc::new([]);
                }
            }
        })
        .ok();
    })
}
507
/// Runs the normal completion pipeline for `buffer` at `position`, but with
/// `response` standing in for the server's answer.
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
    &mut self,
    buffer: &Entity<Buffer>,
    position: language::Anchor,
    response: PredictEditsResponse,
    cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
    self.request_completion_impl(buffer, position, cx, |_, _, _| {
        std::future::ready(Ok(response))
    })
}
520
 /// Requests an edit prediction for `buffer` at `position`, sending the
 /// request through `Self::perform_predict_edits`.
521 pub fn request_completion(
522 &mut self,
523 buffer: &Entity<Buffer>,
524 position: language::Anchor,
525 cx: &mut Context<Self>,
526 ) -> Task<Result<Option<InlineCompletion>>> {
527 self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
528 }
529
530 fn perform_predict_edits(
531 client: Arc<Client>,
532 llm_token: LlmApiToken,
533 body: PredictEditsParams,
534 ) -> impl Future<Output = Result<PredictEditsResponse>> {
535 async move {
536 let http_client = client.http_client();
537 let mut token = llm_token.acquire(&client).await?;
538 let mut did_retry = false;
539
540 loop {
541 let request_builder = http_client::Request::builder();
542 let request = request_builder
543 .method(Method::POST)
544 .uri(
545 http_client
546 .build_zed_llm_url("/predict_edits", &[])?
547 .as_ref(),
548 )
549 .header("Content-Type", "application/json")
550 .header("Authorization", format!("Bearer {}", token))
551 .body(serde_json::to_string(&body)?.into())?;
552
553 let mut response = http_client.send(request).await?;
554
555 if response.status().is_success() {
556 let mut body = String::new();
557 response.body_mut().read_to_string(&mut body).await?;
558 return Ok(serde_json::from_str(&body)?);
559 } else if !did_retry
560 && response
561 .headers()
562 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
563 .is_some()
564 {
565 did_retry = true;
566 token = llm_token.refresh(&client).await?;
567 } else {
568 let mut body = String::new();
569 response.body_mut().read_to_string(&mut body).await?;
570 return Err(anyhow!(
571 "error predicting edits.\nStatus: {:?}\nBody: {}",
572 response.status(),
573 body
574 ));
575 }
576 }
577 }
578 }
579
580 #[allow(clippy::too_many_arguments)]
581 fn process_completion_response(
582 output_excerpt: String,
583 snapshot: &BufferSnapshot,
584 excerpt_range: Range<usize>,
585 cursor_offset: usize,
586 path: Arc<Path>,
587 input_outline: String,
588 input_events: String,
589 input_excerpt: String,
590 request_sent_at: Instant,
591 cx: &AsyncApp,
592 ) -> Task<Result<Option<InlineCompletion>>> {
593 let snapshot = snapshot.clone();
594 cx.background_executor().spawn(async move {
595 let content = output_excerpt.replace(CURSOR_MARKER, "");
596
597 let start_markers = content
598 .match_indices(EDITABLE_REGION_START_MARKER)
599 .collect::<Vec<_>>();
600 anyhow::ensure!(
601 start_markers.len() == 1,
602 "expected exactly one start marker, found {}",
603 start_markers.len()
604 );
605
606 let codefence_start = start_markers[0].0;
607 let content = &content[codefence_start..];
608
609 let newline_ix = content.find('\n').context("could not find newline")?;
610 let content = &content[newline_ix + 1..];
611
612 let codefence_end = content
613 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
614 .context("could not find end marker")?;
615 let new_text = &content[..codefence_end];
616
617 let old_text = snapshot
618 .text_for_range(excerpt_range.clone())
619 .collect::<String>();
620
621 let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
622
623 Ok(Some(InlineCompletion {
624 id: InlineCompletionId::new(),
625 path,
626 excerpt_range,
627 cursor_offset,
628 edits: edits.into(),
629 snapshot: snapshot.clone(),
630 input_outline: input_outline.into(),
631 input_events: input_events.into(),
632 input_excerpt: input_excerpt.into(),
633 output_excerpt: output_excerpt.into(),
634 request_sent_at,
635 response_received_at: Instant::now(),
636 }))
637 })
638 }
639
640 pub fn compute_edits(
641 old_text: String,
642 new_text: &str,
643 offset: usize,
644 snapshot: &BufferSnapshot,
645 ) -> Vec<(Range<Anchor>, String)> {
646 let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
647
648 let mut edits: Vec<(Range<usize>, String)> = Vec::new();
649 let mut old_start = offset;
650 for change in diff.iter_all_changes() {
651 let value = change.value();
652 match change.tag() {
653 similar::ChangeTag::Equal => {
654 old_start += value.len();
655 }
656 similar::ChangeTag::Delete => {
657 let old_end = old_start + value.len();
658 if let Some((last_old_range, _)) = edits.last_mut() {
659 if last_old_range.end == old_start {
660 last_old_range.end = old_end;
661 } else {
662 edits.push((old_start..old_end, String::new()));
663 }
664 } else {
665 edits.push((old_start..old_end, String::new()));
666 }
667 old_start = old_end;
668 }
669 similar::ChangeTag::Insert => {
670 if let Some((last_old_range, last_new_text)) = edits.last_mut() {
671 if last_old_range.end == old_start {
672 last_new_text.push_str(value);
673 } else {
674 edits.push((old_start..old_start, value.into()));
675 }
676 } else {
677 edits.push((old_start..old_start, value.into()));
678 }
679 }
680 }
681 }
682
683 edits
684 .into_iter()
685 .map(|(mut old_range, new_text)| {
686 let prefix_len = common_prefix(
687 snapshot.chars_for_range(old_range.clone()),
688 new_text.chars(),
689 );
690 old_range.start += prefix_len;
691 let suffix_len = common_prefix(
692 snapshot.reversed_chars_for_range(old_range.clone()),
693 new_text[prefix_len..].chars().rev(),
694 );
695 old_range.end = old_range.end.saturating_sub(suffix_len);
696
697 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
698 (
699 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
700 new_text,
701 )
702 })
703 .collect()
704 }
705
 /// Returns whether `completion_id` is in the set of rated completions.
706 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
707 self.rated_completions.contains(&completion_id)
708 }
709
710 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
711 self.shown_completions.push_front(completion.clone());
712 if self.shown_completions.len() > 50 {
713 let completion = self.shown_completions.pop_back().unwrap();
714 self.rated_completions.remove(&completion.id);
715 }
716 cx.notify();
717 }
718
 /// Marks `completion` as rated and emits an "Inline Completion Rated"
 /// telemetry event carrying the rating, the completion's prompts and output,
 /// and the user's feedback. Telemetry events are flushed right away and
 /// observers are notified.
719 pub fn rate_completion(
720 &mut self,
721 completion: &InlineCompletion,
722 rating: InlineCompletionRating,
723 feedback: String,
724 cx: &mut Context<Self>,
725 ) {
726 self.rated_completions.insert(completion.id);
727 telemetry::event!(
728 "Inline Completion Rated",
729 rating,
730 input_events = completion.input_events,
731 input_excerpt = completion.input_excerpt,
732 input_outline = completion.input_outline,
733 output_excerpt = completion.output_excerpt,
734 feedback
735 );
736 self.client.telemetry().flush_events();
737 cx.notify();
738 }
739
 /// Iterates over the shown completions, front (most recently pushed) first.
740 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
741 self.shown_completions.iter()
742 }
743
 /// Number of shown completions currently stored.
744 pub fn shown_completions_len(&self) -> usize {
745 self.shown_completions.len()
746 }
747
748 fn report_changes_for_buffer(
749 &mut self,
750 buffer: &Entity<Buffer>,
751 cx: &mut Context<Self>,
752 ) -> BufferSnapshot {
753 self.register_buffer(buffer, cx);
754
755 let registered_buffer = self
756 .registered_buffers
757 .get_mut(&buffer.entity_id())
758 .unwrap();
759 let new_snapshot = buffer.read(cx).snapshot();
760
761 if new_snapshot.version != registered_buffer.snapshot.version {
762 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
763 self.push_event(Event::BufferChange {
764 old_snapshot,
765 new_snapshot: new_snapshot.clone(),
766 timestamp: Instant::now(),
767 });
768 }
769
770 new_snapshot
771 }
772}
773
/// Returns the length in UTF-8 bytes of the longest run of leading characters
/// on which the two character sequences agree.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
780
781fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
782 let mut input_outline = String::new();
783
784 writeln!(
785 input_outline,
786 "```{}",
787 snapshot
788 .file()
789 .map_or(Cow::Borrowed("untitled"), |file| file
790 .path()
791 .to_string_lossy())
792 )
793 .unwrap();
794
795 if let Some(outline) = snapshot.outline(None) {
796 let guess_size = outline.items.len() * 15;
797 input_outline.reserve(guess_size);
798 for item in outline.items.iter() {
799 let spacing = " ".repeat(item.depth);
800 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
801 }
802 }
803
804 writeln!(input_outline, "```").unwrap();
805
806 input_outline
807}
808
809fn prompt_for_excerpt(
810 snapshot: &BufferSnapshot,
811 excerpt_range: &Range<usize>,
812 offset: usize,
813) -> String {
814 let mut prompt_excerpt = String::new();
815 writeln!(
816 prompt_excerpt,
817 "```{}",
818 snapshot
819 .file()
820 .map_or(Cow::Borrowed("untitled"), |file| file
821 .path()
822 .to_string_lossy())
823 )
824 .unwrap();
825
826 if excerpt_range.start == 0 {
827 writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
828 }
829
830 let point_range = excerpt_range.to_point(snapshot);
831 if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
832 let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
833 for chunk in snapshot.text_for_range(extra_context_line_range) {
834 prompt_excerpt.push_str(chunk);
835 }
836 }
837 writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
838 for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
839 prompt_excerpt.push_str(chunk);
840 }
841 prompt_excerpt.push_str(CURSOR_MARKER);
842 for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
843 prompt_excerpt.push_str(chunk);
844 }
845 write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
846
847 if point_range.end.row < snapshot.max_point().row
848 && !snapshot.is_line_blank(point_range.end.row + 1)
849 {
850 let extra_context_line_range = point_range.end
851 ..Point::new(
852 point_range.end.row + 1,
853 snapshot.line_len(point_range.end.row + 1),
854 );
855 for chunk in snapshot.text_for_range(extra_context_line_range) {
856 prompt_excerpt.push_str(chunk);
857 }
858 }
859
860 write!(prompt_excerpt, "\n```").unwrap();
861 prompt_excerpt
862}
863
864fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
865 const CONTEXT_LINES: u32 = 32;
866
867 let mut context_lines_before = CONTEXT_LINES;
868 let mut context_lines_after = CONTEXT_LINES;
869 if point.row < CONTEXT_LINES {
870 context_lines_after += CONTEXT_LINES - point.row;
871 } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
872 context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
873 }
874
875 let excerpt_start_row = point.row.saturating_sub(context_lines_before);
876 let excerpt_start = Point::new(excerpt_start_row, 0);
877 let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
878 let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
879 excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
880}
881
/// Bookkeeping for a buffer registered with `Zeta`.
882struct RegisteredBuffer {
 // Last snapshot seen for this buffer; compared against the current one to detect changes.
883 snapshot: BufferSnapshot,
 // Buffer-event subscription and release observer; held so they stay alive.
884 _subscriptions: [gpui::Subscription; 2],
885}
886
/// An entry in the history that is rendered into the events section of the prompt.
887#[derive(Clone)]
888enum Event {
 /// A buffer went from `old_snapshot` to `new_snapshot` at `timestamp`.
889 BufferChange {
890 old_snapshot: BufferSnapshot,
891 new_snapshot: BufferSnapshot,
892 timestamp: Instant,
893 },
894}
895
896impl Event {
897 fn to_prompt(&self) -> String {
898 match self {
899 Event::BufferChange {
900 old_snapshot,
901 new_snapshot,
902 ..
903 } => {
904 let mut prompt = String::new();
905
906 let old_path = old_snapshot
907 .file()
908 .map(|f| f.path().as_ref())
909 .unwrap_or(Path::new("untitled"));
910 let new_path = new_snapshot
911 .file()
912 .map(|f| f.path().as_ref())
913 .unwrap_or(Path::new("untitled"));
914 if old_path != new_path {
915 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
916 }
917
918 let diff =
919 similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
920 .unified_diff()
921 .to_string();
922 if !diff.is_empty() {
923 write!(
924 prompt,
925 "User edited {:?}:\n```diff\n{}\n```",
926 new_path, diff
927 )
928 .unwrap();
929 }
930
931 prompt
932 }
933 }
934 }
935}
936
/// The completion currently held by the provider, tagged with the entity id
/// of the buffer it was requested for.
937#[derive(Debug, Clone)]
938struct CurrentInlineCompletion {
939 buffer_id: EntityId,
940 completion: InlineCompletion,
941}
942
943impl CurrentInlineCompletion {
944 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
945 if self.buffer_id != old_completion.buffer_id {
946 return true;
947 }
948
949 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
950 return true;
951 };
952 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
953 return false;
954 };
955
956 if old_edits.len() == 1 && new_edits.len() == 1 {
957 let (old_range, old_text) = &old_edits[0];
958 let (new_range, new_text) = &new_edits[0];
959 new_range == old_range && new_text.starts_with(old_text)
960 } else {
961 true
962 }
963 }
964}
965
/// An in-flight completion request owned by the provider.
966struct PendingCompletion {
 // Value of `next_pending_completion_id` at the time the request was started.
967 id: usize,
 // The spawned request task; stored here so it lives as long as this entry.
968 _task: Task<()>,
969}
970
/// Inline-completion provider backed by a `Zeta` entity.
971pub struct ZetaInlineCompletionProvider {
972 zeta: Entity<Zeta>,
 // At most two in-flight requests, oldest first.
973 pending_completions: ArrayVec<PendingCompletion, 2>,
 // Id handed to the next request started by `refresh`.
974 next_pending_completion_id: usize,
 // The completion currently offered through `suggest`, if any.
975 current_completion: Option<CurrentInlineCompletion>,
976}
977
978impl ZetaInlineCompletionProvider {
 /// Delay awaited before issuing a request when `refresh` is called with `debounce` set.
979 pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
980
 /// Creates a provider with no pending requests and no current completion.
981 pub fn new(zeta: Entity<Zeta>) -> Self {
982 Self {
983 zeta,
984 pending_completions: ArrayVec::new(),
985 next_pending_completion_id: 0,
986 current_completion: None,
987 }
988 }
989}
990
991impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
 /// Identifier string of this provider.
992 fn name() -> &'static str {
993 "zed-predict"
994 }
995
 /// Human-readable name of this provider.
996 fn display_name() -> &'static str {
997 "Zed Predict"
998 }
999
 /// Always `true` for this provider.
1000 fn show_completions_in_menu() -> bool {
1001 true
1002 }
1003
 /// Always `true` for this provider.
1004 fn show_completions_in_normal_mode() -> bool {
1005 true
1006 }
1007
 /// Always `true` for this provider.
1008 fn show_tab_accept_marker() -> bool {
1009 true
1010 }
1011
1012 fn is_enabled(
1013 &self,
1014 buffer: &Entity<Buffer>,
1015 cursor_position: language::Anchor,
1016 cx: &App,
1017 ) -> bool {
1018 let buffer = buffer.read(cx);
1019 let file = buffer.file();
1020 let language = buffer.language_at(cursor_position);
1021 let settings = all_language_settings(file, cx);
1022 settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1023 }
1024
 /// Returns `true` while `Zeta` has not recorded the terms of service as accepted.
1025 fn needs_terms_acceptance(&self, cx: &App) -> bool {
1026 !self.zeta.read(cx).tos_accepted
1027 }
1028
 /// Returns `true` while at least one completion request is pending.
1029 fn is_refreshing(&self) -> bool {
1030 !self.pending_completions.is_empty()
1031 }
1032
1033 fn refresh(
1034 &mut self,
1035 buffer: Entity<Buffer>,
1036 position: language::Anchor,
1037 debounce: bool,
1038 cx: &mut Context<Self>,
1039 ) {
1040 if !self.zeta.read(cx).tos_accepted {
1041 return;
1042 }
1043
1044 let pending_completion_id = self.next_pending_completion_id;
1045 self.next_pending_completion_id += 1;
1046
1047 let task = cx.spawn(|this, mut cx| async move {
1048 if debounce {
1049 cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1050 }
1051
1052 let completion_request = this.update(&mut cx, |this, cx| {
1053 this.zeta.update(cx, |zeta, cx| {
1054 zeta.request_completion(&buffer, position, cx)
1055 })
1056 });
1057
1058 let completion = match completion_request {
1059 Ok(completion_request) => {
1060 let completion_request = completion_request.await;
1061 completion_request.map(|c| {
1062 c.map(|completion| CurrentInlineCompletion {
1063 buffer_id: buffer.entity_id(),
1064 completion,
1065 })
1066 })
1067 }
1068 Err(error) => Err(error),
1069 };
1070 let Some(new_completion) = completion
1071 .context("edit prediction failed")
1072 .log_err()
1073 .flatten()
1074 else {
1075 return;
1076 };
1077
1078 this.update(&mut cx, |this, cx| {
1079 if this.pending_completions[0].id == pending_completion_id {
1080 this.pending_completions.remove(0);
1081 } else {
1082 this.pending_completions.clear();
1083 }
1084
1085 if let Some(old_completion) = this.current_completion.as_ref() {
1086 let snapshot = buffer.read(cx).snapshot();
1087 if new_completion.should_replace_completion(&old_completion, &snapshot) {
1088 this.zeta.update(cx, |zeta, cx| {
1089 zeta.completion_shown(&new_completion.completion, cx);
1090 });
1091 this.current_completion = Some(new_completion);
1092 }
1093 } else {
1094 this.zeta.update(cx, |zeta, cx| {
1095 zeta.completion_shown(&new_completion.completion, cx);
1096 });
1097 this.current_completion = Some(new_completion);
1098 }
1099
1100 cx.notify();
1101 })
1102 .ok();
1103 });
1104
1105 // We always maintain at most two pending completions. When we already
1106 // have two, we replace the newest one.
1107 if self.pending_completions.len() <= 1 {
1108 self.pending_completions.push(PendingCompletion {
1109 id: pending_completion_id,
1110 _task: task,
1111 });
1112 } else if self.pending_completions.len() == 2 {
1113 self.pending_completions.pop();
1114 self.pending_completions.push(PendingCompletion {
1115 id: pending_completion_id,
1116 _task: task,
1117 });
1118 }
1119 }
1120
 /// No-op: all arguments are ignored.
1121 fn cycle(
1122 &mut self,
1123 _buffer: Entity<Buffer>,
1124 _cursor_position: language::Anchor,
1125 _direction: inline_completion::Direction,
1126 _cx: &mut Context<Self>,
1127 ) {
1128 // Right now we don't support cycling.
1129 }
1130
 /// Drops all pending completion requests; the current completion is left in place.
1131 fn accept(&mut self, _cx: &mut Context<Self>) {
1132 self.pending_completions.clear();
1133 }
1134
 /// Drops all pending completion requests and the current completion.
1135 fn discard(&mut self, _cx: &mut Context<Self>) {
1136 self.pending_completions.clear();
1137 self.current_completion.take();
1138 }
1139
1140 fn suggest(
1141 &mut self,
1142 buffer: &Entity<Buffer>,
1143 cursor_position: language::Anchor,
1144 cx: &mut Context<Self>,
1145 ) -> Option<inline_completion::InlineCompletion> {
1146 let CurrentInlineCompletion {
1147 buffer_id,
1148 completion,
1149 ..
1150 } = self.current_completion.as_mut()?;
1151
1152 // Invalidate previous completion if it was generated for a different buffer.
1153 if *buffer_id != buffer.entity_id() {
1154 self.current_completion.take();
1155 return None;
1156 }
1157
1158 let buffer = buffer.read(cx);
1159 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1160 self.current_completion.take();
1161 return None;
1162 };
1163
1164 let cursor_row = cursor_position.to_point(buffer).row;
1165 let (closest_edit_ix, (closest_edit_range, _)) =
1166 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1167 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1168 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1169 cmp::min(distance_from_start, distance_from_end)
1170 })?;
1171
1172 let mut edit_start_ix = closest_edit_ix;
1173 for (range, _) in edits[..edit_start_ix].iter().rev() {
1174 let distance_from_closest_edit =
1175 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1176 if distance_from_closest_edit <= 1 {
1177 edit_start_ix -= 1;
1178 } else {
1179 break;
1180 }
1181 }
1182
1183 let mut edit_end_ix = closest_edit_ix + 1;
1184 for (range, _) in &edits[edit_end_ix..] {
1185 let distance_from_closest_edit =
1186 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1187 if distance_from_closest_edit <= 1 {
1188 edit_end_ix += 1;
1189 } else {
1190 break;
1191 }
1192 }
1193
1194 Some(inline_completion::InlineCompletion {
1195 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1196 })
1197 }
1198}
1199
1200#[cfg(test)]
1201mod tests {
1202 use client::test::FakeServer;
1203 use clock::FakeSystemClock;
1204 use gpui::TestAppContext;
1205 use http_client::FakeHttpClient;
1206 use indoc::indoc;
1207 use language_models::RefreshLlmTokenListener;
1208 use rpc::proto;
1209 use settings::SettingsStore;
1210
1211 use super::*;
1212
1213 #[gpui::test]
1214 fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
1215 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1216 let completion = InlineCompletion {
1217 edits: cx
1218 .read(|cx| {
1219 to_completion_edits(
1220 [(2..5, "REM".to_string()), (9..11, "".to_string())],
1221 &buffer,
1222 cx,
1223 )
1224 })
1225 .into(),
1226 path: Path::new("").into(),
1227 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1228 id: InlineCompletionId::new(),
1229 excerpt_range: 0..0,
1230 cursor_offset: 0,
1231 input_outline: "".into(),
1232 input_events: "".into(),
1233 input_excerpt: "".into(),
1234 output_excerpt: "".into(),
1235 request_sent_at: Instant::now(),
1236 response_received_at: Instant::now(),
1237 };
1238
1239 assert_eq!(
1240 cx.read(|cx| {
1241 from_completion_edits(
1242 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1243 &buffer,
1244 cx,
1245 )
1246 }),
1247 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1248 );
1249
1250 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1251 assert_eq!(
1252 cx.read(|cx| {
1253 from_completion_edits(
1254 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1255 &buffer,
1256 cx,
1257 )
1258 }),
1259 vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1260 );
1261
1262 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1263 assert_eq!(
1264 cx.read(|cx| {
1265 from_completion_edits(
1266 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1267 &buffer,
1268 cx,
1269 )
1270 }),
1271 vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1272 );
1273
1274 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1275 assert_eq!(
1276 cx.read(|cx| {
1277 from_completion_edits(
1278 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1279 &buffer,
1280 cx,
1281 )
1282 }),
1283 vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1284 );
1285
1286 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1287 assert_eq!(
1288 cx.read(|cx| {
1289 from_completion_edits(
1290 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1291 &buffer,
1292 cx,
1293 )
1294 }),
1295 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1296 );
1297
1298 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1299 assert_eq!(
1300 cx.read(|cx| {
1301 from_completion_edits(
1302 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1303 &buffer,
1304 cx,
1305 )
1306 }),
1307 vec![(9..11, "".to_string())]
1308 );
1309
1310 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1311 assert_eq!(
1312 cx.read(|cx| {
1313 from_completion_edits(
1314 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1315 &buffer,
1316 cx,
1317 )
1318 }),
1319 vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1320 );
1321
1322 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1323 assert_eq!(
1324 cx.read(|cx| {
1325 from_completion_edits(
1326 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1327 &buffer,
1328 cx,
1329 )
1330 }),
1331 vec![(4..4, "M".to_string())]
1332 );
1333
1334 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1335 assert_eq!(
1336 cx.read(|cx| completion.interpolate(&buffer.read(cx).snapshot())),
1337 None
1338 );
1339 }
1340
1341 #[gpui::test]
1342 async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
1343 cx.update(|cx| {
1344 let settings_store = SettingsStore::test(cx);
1345 cx.set_global(settings_store);
1346 client::init_settings(cx);
1347 });
1348
1349 let buffer_content = "lorem\n";
1350 let completion_response = indoc! {"
1351 ```animals.js
1352 <|start_of_file|>
1353 <|editable_region_start|>
1354 lorem
1355 ipsum
1356 <|editable_region_end|>
1357 ```"};
1358
1359 let http_client = FakeHttpClient::create(move |_| async move {
1360 Ok(http_client::Response::builder()
1361 .status(200)
1362 .body(
1363 serde_json::to_string(&PredictEditsResponse {
1364 output_excerpt: completion_response.to_string(),
1365 })
1366 .unwrap()
1367 .into(),
1368 )
1369 .unwrap())
1370 });
1371
1372 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1373 cx.update(|cx| {
1374 RefreshLlmTokenListener::register(client.clone(), cx);
1375 });
1376 let server = FakeServer::for_client(42, &client, cx).await;
1377 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1378 let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
1379
1380 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1381 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1382 let completion_task =
1383 zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));
1384
1385 let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
1386 server.respond(
1387 token_request.receipt(),
1388 proto::GetLlmTokenResponse { token: "".into() },
1389 );
1390
1391 let completion = completion_task.await.unwrap().unwrap();
1392 buffer.update(cx, |buffer, cx| {
1393 buffer.edit(completion.edits.iter().cloned(), None, cx)
1394 });
1395 assert_eq!(
1396 buffer.read_with(cx, |buffer, _| buffer.text()),
1397 "lorem\nipsum"
1398 );
1399 }
1400
1401 fn to_completion_edits(
1402 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1403 buffer: &Entity<Buffer>,
1404 cx: &App,
1405 ) -> Vec<(Range<Anchor>, String)> {
1406 let buffer = buffer.read(cx);
1407 iterator
1408 .into_iter()
1409 .map(|(range, text)| {
1410 (
1411 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1412 text,
1413 )
1414 })
1415 .collect()
1416 }
1417
1418 fn from_completion_edits(
1419 editor_edits: &[(Range<Anchor>, String)],
1420 buffer: &Entity<Buffer>,
1421 cx: &App,
1422 ) -> Vec<(Range<usize>, String)> {
1423 let buffer = buffer.read(cx);
1424 editor_edits
1425 .iter()
1426 .map(|(range, text)| {
1427 (
1428 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1429 text.clone(),
1430 )
1431 })
1432 .collect()
1433 }
1434
1435 #[ctor::ctor]
1436 fn init_logger() {
1437 if std::env::var("RUST_LOG").is_ok() {
1438 env_logger::init();
1439 }
1440 }
1441}