1mod completion_diff_element;
2mod rate_completion_modal;
3
4pub(crate) use completion_diff_element::*;
5pub use rate_completion_modal::*;
6
7use anyhow::{anyhow, Context as _, Result};
8use arrayvec::ArrayVec;
9use client::{Client, UserStore};
10use collections::{HashMap, HashSet, VecDeque};
11use futures::AsyncReadExt;
12use gpui::{
13 actions, AppContext, AsyncAppContext, Context, EntityId, Global, Model, ModelContext,
14 Subscription, Task,
15};
16use http_client::{HttpClient, Method};
17use language::{
18 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
19 Point, ToOffset, ToPoint,
20};
21use language_models::LlmApiToken;
22use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
23use std::{
24 borrow::Cow,
25 cmp,
26 fmt::Write,
27 future::Future,
28 mem,
29 ops::Range,
30 path::Path,
31 sync::Arc,
32 time::{Duration, Instant},
33};
34use telemetry_events::InlineCompletionRating;
35use util::ResultExt;
36use uuid::Uuid;
37
/// Marker inserted into the prompt at the user's cursor position (stripped from responses).
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker emitted before the editable region when the excerpt starts at offset 0.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker opening the part of the excerpt the model may rewrite.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker closing the part of the excerpt the model may rewrite.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer that arrive within this interval are coalesced
/// into a single history event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
43
// Declares the `ClearHistory` action in the `zeta` namespace.
// NOTE(review): the handler is presumably wired to `Zeta::clear_history` elsewhere — confirm.
actions!(zeta, [ClearHistory]);
45
46#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
47pub struct InlineCompletionId(Uuid);
48
49impl From<InlineCompletionId> for gpui::ElementId {
50 fn from(value: InlineCompletionId) -> Self {
51 gpui::ElementId::Uuid(value.0)
52 }
53}
54
55impl std::fmt::Display for InlineCompletionId {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(f, "{}", self.0)
58 }
59}
60
61impl InlineCompletionId {
62 fn new() -> Self {
63 Self(Uuid::new_v4())
64 }
65}
66
/// App-global slot holding the single `Zeta` model (installed by `Zeta::register`,
/// read by `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Model<Zeta>);

impl Global for ZetaGlobal {}
71
/// An edit prediction returned by the model, together with the inputs that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Unique id; used to track whether the completion has been rated.
    id: InlineCompletionId,
    /// Path reported for the buffer (its full path, or "untitled" when it has no file).
    path: Arc<Path>,
    /// Byte range of the editable excerpt within `snapshot`.
    excerpt_range: Range<usize>,
    /// Byte offset of the cursor within `snapshot` when the completion was requested.
    cursor_offset: usize,
    /// Predicted edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    // Prompt pieces and raw model output, kept so ratings can report them.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    request_sent_at: Instant,
    response_received_at: Instant,
}

impl InlineCompletion {
    /// Time between sending the request and receiving the response.
    fn latency(&self) -> Duration {
        self.response_received_at
            .duration_since(self.request_sent_at)
    }

    /// Re-bases the model's edits onto `new_snapshot`, taking into account edits the user
    /// made since `self.snapshot`.
    ///
    /// Walks model edits and user edits in order:
    /// - a model edit that ends before the next overlapping user edit (or with no user edit
    ///   left) is kept unchanged;
    /// - if a user edit replaced exactly the model edit's range and the text the user typed
    ///   is a prefix of the model's text, the remaining suffix is kept as an insertion at the
    ///   end of the user's edit (the user is "typing through" the suggestion);
    /// - any other overlap invalidates the whole completion.
    ///
    /// Returns `None` when the completion is invalidated or when no edits remain.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
        let mut edits = Vec::new();

        let mut user_edits = new_snapshot
            .edits_since::<usize>(&self.snapshot.version)
            .peekable();
        for (model_old_range, model_new_text) in self.edits.iter() {
            let model_offset_range = model_old_range.to_offset(&self.snapshot);
            // Skip user edits that end before this model edit starts; they cannot conflict.
            while let Some(next_user_edit) = user_edits.peek() {
                if next_user_edit.old.end < model_offset_range.start {
                    user_edits.next();
                } else {
                    break;
                }
            }

            if let Some(user_edit) = user_edits.peek() {
                if user_edit.old.start > model_offset_range.end {
                    // The next user edit starts after this model edit: no conflict.
                    edits.push((model_old_range.clone(), model_new_text.clone()));
                } else if user_edit.old == model_offset_range {
                    let user_new_text = new_snapshot
                        .text_for_range(user_edit.new.clone())
                        .collect::<String>();

                    // The user typed a prefix of the suggestion: keep only what is still missing.
                    if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                        if !model_suffix.is_empty() {
                            edits.push((
                                new_snapshot.anchor_after(user_edit.new.end)
                                    ..new_snapshot.anchor_before(user_edit.new.end),
                                model_suffix.into(),
                            ));
                        }

                        user_edits.next();
                    } else {
                        // The user diverged from the suggestion.
                        return None;
                    }
                } else {
                    // Partial overlap between a user edit and a model edit.
                    return None;
                }
            } else {
                edits.push((model_old_range.clone(), model_new_text.clone()));
            }
        }

        if edits.is_empty() {
            None
        } else {
            Some(edits)
        }
    }
}

// Only the identifying fields are printed; snapshots and prompt text are omitted.
impl std::fmt::Debug for InlineCompletion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InlineCompletion")
            .field("id", &self.id)
            .field("path", &self.path)
            .field("edits", &self.edits)
            .finish_non_exhaustive()
    }
}
156
/// Model that tracks recent buffer edits and requests edit predictions from the Zed LLM
/// service.
pub struct Zeta {
    client: Arc<Client>,
    /// Recent buffer-change history sent to the model as context (bounded in `push_event`).
    events: VecDeque<Event>,
    /// Tracked buffers, keyed by entity id, with the last snapshot reported for each.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions shown to the user, most recent first (bounded in `completion_shown`).
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of shown completions the user has rated.
    rated_completions: HashSet<InlineCompletionId>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    tos_accepted: bool, // Terms of service accepted
    _user_store_subscription: Subscription,
}
168
169impl Zeta {
170 pub fn global(cx: &mut AppContext) -> Option<Model<Self>> {
171 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
172 }
173
174 pub fn register(
175 client: Arc<Client>,
176 user_store: Model<UserStore>,
177 cx: &mut AppContext,
178 ) -> Model<Self> {
179 Self::global(cx).unwrap_or_else(|| {
180 let model = cx.new_model(|cx| Self::new(client, user_store, cx));
181 cx.set_global(ZetaGlobal(model.clone()));
182 model
183 })
184 }
185
    /// Forgets all recorded buffer-change events, so later prompts carry no edit history.
    pub fn clear_history(&mut self) {
        self.events.clear();
    }
189
190 fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
191 let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
192
193 Self {
194 client,
195 events: VecDeque::new(),
196 shown_completions: VecDeque::new(),
197 rated_completions: HashSet::default(),
198 registered_buffers: HashMap::default(),
199 llm_token: LlmApiToken::default(),
200 _llm_token_subscription: cx.subscribe(
201 &refresh_llm_token_listener,
202 |this, _listener, _event, cx| {
203 let client = this.client.clone();
204 let llm_token = this.llm_token.clone();
205 cx.spawn(|_this, _cx| async move {
206 llm_token.refresh(&client).await?;
207 anyhow::Ok(())
208 })
209 .detach_and_log_err(cx);
210 },
211 ),
212 tos_accepted: user_store
213 .read(cx)
214 .current_user_has_accepted_terms()
215 .unwrap_or(false),
216 _user_store_subscription: cx.subscribe(&user_store, |this, _, event, _| match event {
217 client::user::Event::TermsStatusUpdated { accepted } => {
218 this.tos_accepted = *accepted;
219 }
220 _ => {}
221 }),
222 }
223 }
224
225 fn push_event(&mut self, event: Event) {
226 const MAX_EVENT_COUNT: usize = 16;
227
228 if let Some(Event::BufferChange {
229 new_snapshot: last_new_snapshot,
230 timestamp: last_timestamp,
231 ..
232 }) = self.events.back_mut()
233 {
234 // Coalesce edits for the same buffer when they happen one after the other.
235 let Event::BufferChange {
236 old_snapshot,
237 new_snapshot,
238 timestamp,
239 } = &event;
240
241 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
242 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
243 && old_snapshot.version == last_new_snapshot.version
244 {
245 *last_new_snapshot = new_snapshot.clone();
246 *last_timestamp = *timestamp;
247 return;
248 }
249 }
250
251 self.events.push_back(event);
252 if self.events.len() >= MAX_EVENT_COUNT {
253 self.events.drain(..MAX_EVENT_COUNT / 2);
254 }
255 }
256
257 pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
258 let buffer_id = buffer.entity_id();
259 let weak_buffer = buffer.downgrade();
260
261 if let std::collections::hash_map::Entry::Vacant(entry) =
262 self.registered_buffers.entry(buffer_id)
263 {
264 let snapshot = buffer.read(cx).snapshot();
265
266 entry.insert(RegisteredBuffer {
267 snapshot,
268 _subscriptions: [
269 cx.subscribe(buffer, move |this, buffer, event, cx| {
270 this.handle_buffer_event(buffer, event, cx);
271 }),
272 cx.observe_release(buffer, move |this, _buffer, _cx| {
273 this.registered_buffers.remove(&weak_buffer.entity_id());
274 }),
275 ],
276 });
277 };
278 }
279
280 fn handle_buffer_event(
281 &mut self,
282 buffer: Model<Buffer>,
283 event: &language::BufferEvent,
284 cx: &mut ModelContext<Self>,
285 ) {
286 match event {
287 language::BufferEvent::Edited => {
288 self.report_changes_for_buffer(&buffer, cx);
289 }
290 _ => {}
291 }
292 }
293
    /// Shared implementation of completion requests.
    ///
    /// Steps:
    /// 1. Reports pending edits of `buffer` into the event history.
    /// 2. Builds the prompt (event history, excerpt around `position`, file outline) on the
    ///    background executor.
    /// 3. Sends the prompt through `perform_predict_edits`.
    /// 4. Turns the response into an [`InlineCompletion`].
    ///
    /// The transport is a parameter so tests can substitute canned responses
    /// (see `fake_completion`).
    pub fn request_completion_impl<F, R>(
        &mut self,
        buffer: &Model<Buffer>,
        position: language::Anchor,
        cx: &mut ModelContext<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<InlineCompletion>>
    where
        F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
    {
        // Flush unreported edits into the history and get the snapshot the request is based on.
        let snapshot = self.report_changes_for_buffer(buffer, cx);
        let point = position.to_point(&snapshot);
        let offset = point.to_offset(&snapshot);
        let excerpt_range = excerpt_range_for_position(point, &snapshot);
        let events = self.events.clone();
        let path = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();

        cx.spawn(|_, cx| async move {
            let request_sent_at = Instant::now();

            // Prompt construction (whole-buffer diffs per event, outline rendering) runs off
            // the main thread.
            let (input_events, input_excerpt, input_outline) = cx
                .background_executor()
                .spawn({
                    let snapshot = snapshot.clone();
                    let excerpt_range = excerpt_range.clone();
                    async move {
                        let mut input_events = String::new();
                        for event in events {
                            // Separate events with a blank line, but only once some
                            // non-empty text has been emitted.
                            if !input_events.is_empty() {
                                input_events.push('\n');
                                input_events.push('\n');
                            }
                            input_events.push_str(&event.to_prompt());
                        }

                        let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
                        let input_outline = prompt_for_outline(&snapshot);

                        (input_events, input_excerpt, input_outline)
                    }
                })
                .await;

            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);

            let body = PredictEditsParams {
                input_events: input_events.clone(),
                input_excerpt: input_excerpt.clone(),
                outline: Some(input_outline.clone()),
            };

            let response = perform_predict_edits(client, llm_token, body).await?;

            let output_excerpt = response.output_excerpt;
            log::debug!("completion response: {}", output_excerpt);

            Self::process_completion_response(
                output_excerpt,
                &snapshot,
                excerpt_range,
                offset,
                path,
                input_outline,
                input_events,
                input_excerpt,
                request_sent_at,
                &cx,
            )
            .await
        })
    }
372
373 // Generates several example completions of various states to fill the Zeta completion modal
374 #[cfg(any(test, feature = "test-support"))]
375 pub fn fill_with_fake_completions(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
376 let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
377 And maybe a short line
378
379 Then a few lines
380
381 and then another
382 "#};
383
384 let buffer = cx.new_model(|cx| Buffer::local(test_buffer_text, cx));
385 let position = buffer.read(cx).anchor_before(Point::new(1, 0));
386
387 let completion_tasks = vec![
388 self.fake_completion(
389 &buffer,
390 position,
391 PredictEditsResponse {
392 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
393a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
394[here's an edit]
395And maybe a short line
396Then a few lines
397and then another
398{EDITABLE_REGION_END_MARKER}
399 ", ),
400 },
401 cx,
402 ),
403 self.fake_completion(
404 &buffer,
405 position,
406 PredictEditsResponse {
407 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
408a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
409And maybe a short line
410[and another edit]
411Then a few lines
412and then another
413{EDITABLE_REGION_END_MARKER}
414 "#),
415 },
416 cx,
417 ),
418 self.fake_completion(
419 &buffer,
420 position,
421 PredictEditsResponse {
422 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
423a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
424And maybe a short line
425
426Then a few lines
427
428and then another
429{EDITABLE_REGION_END_MARKER}
430 "#),
431 },
432 cx,
433 ),
434 self.fake_completion(
435 &buffer,
436 position,
437 PredictEditsResponse {
438 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
439a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
440And maybe a short line
441
442Then a few lines
443
444and then another
445{EDITABLE_REGION_END_MARKER}
446 "#),
447 },
448 cx,
449 ),
450 self.fake_completion(
451 &buffer,
452 position,
453 PredictEditsResponse {
454 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
455a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
456And maybe a short line
457Then a few lines
458[a third completion]
459and then another
460{EDITABLE_REGION_END_MARKER}
461 "#),
462 },
463 cx,
464 ),
465 self.fake_completion(
466 &buffer,
467 position,
468 PredictEditsResponse {
469 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
470a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
471And maybe a short line
472and then another
473[fourth completion example]
474{EDITABLE_REGION_END_MARKER}
475 "#),
476 },
477 cx,
478 ),
479 self.fake_completion(
480 &buffer,
481 position,
482 PredictEditsResponse {
483 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
484a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
485And maybe a short line
486Then a few lines
487and then another
488[fifth and final completion]
489{EDITABLE_REGION_END_MARKER}
490 "#),
491 },
492 cx,
493 ),
494 ];
495
496 cx.spawn(|zeta, mut cx| async move {
497 for task in completion_tasks {
498 task.await.unwrap();
499 }
500
501 zeta.update(&mut cx, |zeta, _cx| {
502 zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
503 zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
504 })
505 .ok();
506 })
507 }
508
509 #[cfg(any(test, feature = "test-support"))]
510 pub fn fake_completion(
511 &mut self,
512 buffer: &Model<Buffer>,
513 position: language::Anchor,
514 response: PredictEditsResponse,
515 cx: &mut ModelContext<Self>,
516 ) -> Task<Result<InlineCompletion>> {
517 use std::future::ready;
518
519 self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response)))
520 }
521
522 pub fn request_completion(
523 &mut self,
524 buffer: &Model<Buffer>,
525 position: language::Anchor,
526 cx: &mut ModelContext<Self>,
527 ) -> Task<Result<InlineCompletion>> {
528 self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
529 }
530
531 fn perform_predict_edits(
532 client: Arc<Client>,
533 llm_token: LlmApiToken,
534 body: PredictEditsParams,
535 ) -> impl Future<Output = Result<PredictEditsResponse>> {
536 async move {
537 let http_client = client.http_client();
538 let mut token = llm_token.acquire(&client).await?;
539 let mut did_retry = false;
540
541 loop {
542 let request_builder = http_client::Request::builder();
543 let request = request_builder
544 .method(Method::POST)
545 .uri(
546 http_client
547 .build_zed_llm_url("/predict_edits", &[])?
548 .as_ref(),
549 )
550 .header("Content-Type", "application/json")
551 .header("Authorization", format!("Bearer {}", token))
552 .body(serde_json::to_string(&body)?.into())?;
553
554 let mut response = http_client.send(request).await?;
555
556 if response.status().is_success() {
557 let mut body = String::new();
558 response.body_mut().read_to_string(&mut body).await?;
559 return Ok(serde_json::from_str(&body)?);
560 } else if !did_retry
561 && response
562 .headers()
563 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
564 .is_some()
565 {
566 did_retry = true;
567 token = llm_token.refresh(&client).await?;
568 } else {
569 let mut body = String::new();
570 response.body_mut().read_to_string(&mut body).await?;
571 return Err(anyhow!(
572 "error predicting edits.\nStatus: {:?}\nBody: {}",
573 response.status(),
574 body
575 ));
576 }
577 }
578 }
579 }
580
581 #[allow(clippy::too_many_arguments)]
582 fn process_completion_response(
583 output_excerpt: String,
584 snapshot: &BufferSnapshot,
585 excerpt_range: Range<usize>,
586 cursor_offset: usize,
587 path: Arc<Path>,
588 input_outline: String,
589 input_events: String,
590 input_excerpt: String,
591 request_sent_at: Instant,
592 cx: &AsyncAppContext,
593 ) -> Task<Result<InlineCompletion>> {
594 let snapshot = snapshot.clone();
595 cx.background_executor().spawn(async move {
596 let content = output_excerpt.replace(CURSOR_MARKER, "");
597
598 let start_markers = content
599 .match_indices(EDITABLE_REGION_START_MARKER)
600 .collect::<Vec<_>>();
601 anyhow::ensure!(
602 start_markers.len() == 1,
603 "expected exactly one start marker, found {}",
604 start_markers.len()
605 );
606
607 let end_markers = content
608 .match_indices(EDITABLE_REGION_END_MARKER)
609 .collect::<Vec<_>>();
610 anyhow::ensure!(
611 end_markers.len() == 1,
612 "expected exactly one end marker, found {}",
613 end_markers.len()
614 );
615
616 let sof_markers = content
617 .match_indices(START_OF_FILE_MARKER)
618 .collect::<Vec<_>>();
619 anyhow::ensure!(
620 sof_markers.len() <= 1,
621 "expected at most one start-of-file marker, found {}",
622 sof_markers.len()
623 );
624
625 let codefence_start = start_markers[0].0;
626 let content = &content[codefence_start..];
627
628 let newline_ix = content.find('\n').context("could not find newline")?;
629 let content = &content[newline_ix + 1..];
630
631 let codefence_end = content
632 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
633 .context("could not find end marker")?;
634 let new_text = &content[..codefence_end];
635
636 let old_text = snapshot
637 .text_for_range(excerpt_range.clone())
638 .collect::<String>();
639
640 let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
641
642 Ok(InlineCompletion {
643 id: InlineCompletionId::new(),
644 path,
645 excerpt_range,
646 cursor_offset,
647 edits: edits.into(),
648 snapshot: snapshot.clone(),
649 input_outline: input_outline.into(),
650 input_events: input_events.into(),
651 input_excerpt: input_excerpt.into(),
652 output_excerpt: output_excerpt.into(),
653 request_sent_at,
654 response_received_at: Instant::now(),
655 })
656 })
657 }
658
659 pub fn compute_edits(
660 old_text: String,
661 new_text: &str,
662 offset: usize,
663 snapshot: &BufferSnapshot,
664 ) -> Vec<(Range<Anchor>, String)> {
665 let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
666
667 let mut edits: Vec<(Range<usize>, String)> = Vec::new();
668 let mut old_start = offset;
669 for change in diff.iter_all_changes() {
670 let value = change.value();
671 match change.tag() {
672 similar::ChangeTag::Equal => {
673 old_start += value.len();
674 }
675 similar::ChangeTag::Delete => {
676 let old_end = old_start + value.len();
677 if let Some((last_old_range, _)) = edits.last_mut() {
678 if last_old_range.end == old_start {
679 last_old_range.end = old_end;
680 } else {
681 edits.push((old_start..old_end, String::new()));
682 }
683 } else {
684 edits.push((old_start..old_end, String::new()));
685 }
686 old_start = old_end;
687 }
688 similar::ChangeTag::Insert => {
689 if let Some((last_old_range, last_new_text)) = edits.last_mut() {
690 if last_old_range.end == old_start {
691 last_new_text.push_str(value);
692 } else {
693 edits.push((old_start..old_start, value.into()));
694 }
695 } else {
696 edits.push((old_start..old_start, value.into()));
697 }
698 }
699 }
700 }
701
702 edits
703 .into_iter()
704 .map(|(mut old_range, new_text)| {
705 let prefix_len = common_prefix(
706 snapshot.chars_for_range(old_range.clone()),
707 new_text.chars(),
708 );
709 old_range.start += prefix_len;
710 let suffix_len = common_prefix(
711 snapshot.reversed_chars_for_range(old_range.clone()),
712 new_text[prefix_len..].chars().rev(),
713 );
714 old_range.end = old_range.end.saturating_sub(suffix_len);
715
716 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
717 (
718 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
719 new_text,
720 )
721 })
722 .collect()
723 }
724
725 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
726 self.rated_completions.contains(&completion_id)
727 }
728
729 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut ModelContext<Self>) {
730 self.shown_completions.push_front(completion.clone());
731 if self.shown_completions.len() > 50 {
732 let completion = self.shown_completions.pop_back().unwrap();
733 self.rated_completions.remove(&completion.id);
734 }
735 cx.notify();
736 }
737
    /// Marks `completion` as rated and emits an "Inline Completion Rated" telemetry event.
    ///
    /// The event carries the rating, the completion's prompt inputs and model output, and
    /// the user's free-form `feedback`. Telemetry is then flushed and observers notified.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut ModelContext<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Inline Completion Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Send right away rather than waiting for the next batch.
        self.client.telemetry().flush_events();
        cx.notify();
    }
758
    /// Iterates over shown completions, most recently shown first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
        self.shown_completions.iter()
    }

    /// Number of shown completions currently retained.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
766
767 fn report_changes_for_buffer(
768 &mut self,
769 buffer: &Model<Buffer>,
770 cx: &mut ModelContext<Self>,
771 ) -> BufferSnapshot {
772 self.register_buffer(buffer, cx);
773
774 let registered_buffer = self
775 .registered_buffers
776 .get_mut(&buffer.entity_id())
777 .unwrap();
778 let new_snapshot = buffer.read(cx).snapshot();
779
780 if new_snapshot.version != registered_buffer.snapshot.version {
781 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
782 self.push_event(Event::BufferChange {
783 old_snapshot,
784 new_snapshot: new_snapshot.clone(),
785 timestamp: Instant::now(),
786 });
787 }
788
789 new_snapshot
790 }
791}
792
793fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
794 a.zip(b)
795 .take_while(|(a, b)| a == b)
796 .map(|(a, _)| a.len_utf8())
797 .sum()
798}
799
800fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
801 let mut input_outline = String::new();
802
803 writeln!(
804 input_outline,
805 "```{}",
806 snapshot
807 .file()
808 .map_or(Cow::Borrowed("untitled"), |file| file
809 .path()
810 .to_string_lossy())
811 )
812 .unwrap();
813
814 if let Some(outline) = snapshot.outline(None) {
815 let guess_size = outline.items.len() * 15;
816 input_outline.reserve(guess_size);
817 for item in outline.items.iter() {
818 let spacing = " ".repeat(item.depth);
819 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
820 }
821 }
822
823 writeln!(input_outline, "```").unwrap();
824
825 input_outline
826}
827
828fn prompt_for_excerpt(
829 snapshot: &BufferSnapshot,
830 excerpt_range: &Range<usize>,
831 offset: usize,
832) -> String {
833 let mut prompt_excerpt = String::new();
834 writeln!(
835 prompt_excerpt,
836 "```{}",
837 snapshot
838 .file()
839 .map_or(Cow::Borrowed("untitled"), |file| file
840 .path()
841 .to_string_lossy())
842 )
843 .unwrap();
844
845 if excerpt_range.start == 0 {
846 writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
847 }
848
849 let point_range = excerpt_range.to_point(snapshot);
850 if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
851 let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
852 for chunk in snapshot.text_for_range(extra_context_line_range) {
853 prompt_excerpt.push_str(chunk);
854 }
855 }
856 writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
857 for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
858 prompt_excerpt.push_str(chunk);
859 }
860 prompt_excerpt.push_str(CURSOR_MARKER);
861 for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
862 prompt_excerpt.push_str(chunk);
863 }
864 write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
865
866 if point_range.end.row < snapshot.max_point().row
867 && !snapshot.is_line_blank(point_range.end.row + 1)
868 {
869 let extra_context_line_range = point_range.end
870 ..Point::new(
871 point_range.end.row + 1,
872 snapshot.line_len(point_range.end.row + 1),
873 );
874 for chunk in snapshot.text_for_range(extra_context_line_range) {
875 prompt_excerpt.push_str(chunk);
876 }
877 }
878
879 write!(prompt_excerpt, "\n```").unwrap();
880 prompt_excerpt
881}
882
883fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
884 const CONTEXT_LINES: u32 = 32;
885
886 let mut context_lines_before = CONTEXT_LINES;
887 let mut context_lines_after = CONTEXT_LINES;
888 if point.row < CONTEXT_LINES {
889 context_lines_after += CONTEXT_LINES - point.row;
890 } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
891 context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
892 }
893
894 let excerpt_start_row = point.row.saturating_sub(context_lines_before);
895 let excerpt_start = Point::new(excerpt_start_row, 0);
896 let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
897 let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
898 excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
899}
900
/// Bookkeeping for a buffer tracked by `Zeta`.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change.
    snapshot: BufferSnapshot,
    /// Buffer-event and buffer-release subscriptions; kept alive for as long as the entry.
    _subscriptions: [gpui::Subscription; 2],
}
905
906#[derive(Clone)]
907enum Event {
908 BufferChange {
909 old_snapshot: BufferSnapshot,
910 new_snapshot: BufferSnapshot,
911 timestamp: Instant,
912 },
913}
914
915impl Event {
916 fn to_prompt(&self) -> String {
917 match self {
918 Event::BufferChange {
919 old_snapshot,
920 new_snapshot,
921 ..
922 } => {
923 let mut prompt = String::new();
924
925 let old_path = old_snapshot
926 .file()
927 .map(|f| f.path().as_ref())
928 .unwrap_or(Path::new("untitled"));
929 let new_path = new_snapshot
930 .file()
931 .map(|f| f.path().as_ref())
932 .unwrap_or(Path::new("untitled"));
933 if old_path != new_path {
934 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
935 }
936
937 let diff =
938 similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
939 .unified_diff()
940 .to_string();
941 if !diff.is_empty() {
942 write!(
943 prompt,
944 "User edited {:?}:\n```diff\n{}\n```",
945 new_path, diff
946 )
947 .unwrap();
948 }
949
950 prompt
951 }
952 }
953 }
954}
955
956#[derive(Debug, Clone)]
957struct CurrentInlineCompletion {
958 buffer_id: EntityId,
959 completion: InlineCompletion,
960}
961
962impl CurrentInlineCompletion {
963 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
964 if self.buffer_id != old_completion.buffer_id {
965 return true;
966 }
967
968 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
969 return true;
970 };
971 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
972 return false;
973 };
974
975 if old_edits.len() == 1 && new_edits.len() == 1 {
976 let (old_range, old_text) = &old_edits[0];
977 let (new_range, new_text) = &new_edits[0];
978 new_range == old_range && new_text.starts_with(old_text)
979 } else {
980 true
981 }
982 }
983}
984
/// An in-flight `refresh` request.
struct PendingCompletion {
    /// Sequence number assigned by `refresh`, used to tell which request finished.
    id: usize,
    /// Task performing the request; owned here so removing the entry drops it.
    _task: Task<()>,
}
989
990pub struct ZetaInlineCompletionProvider {
991 zeta: Model<Zeta>,
992 pending_completions: ArrayVec<PendingCompletion, 2>,
993 next_pending_completion_id: usize,
994 current_completion: Option<CurrentInlineCompletion>,
995}
996
997impl ZetaInlineCompletionProvider {
998 pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
999
1000 pub fn new(zeta: Model<Zeta>) -> Self {
1001 Self {
1002 zeta,
1003 pending_completions: ArrayVec::new(),
1004 next_pending_completion_id: 0,
1005 current_completion: None,
1006 }
1007 }
1008}
1009
1010impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
1011 fn name() -> &'static str {
1012 "zed-predict"
1013 }
1014
1015 fn display_name() -> &'static str {
1016 "Zed Predict"
1017 }
1018
1019 fn show_completions_in_menu() -> bool {
1020 true
1021 }
1022
1023 fn show_completions_in_normal_mode() -> bool {
1024 true
1025 }
1026
1027 fn show_tab_accept_marker() -> bool {
1028 true
1029 }
1030
1031 fn is_enabled(
1032 &self,
1033 buffer: &Model<Buffer>,
1034 cursor_position: language::Anchor,
1035 cx: &AppContext,
1036 ) -> bool {
1037 let buffer = buffer.read(cx);
1038 let file = buffer.file();
1039 let language = buffer.language_at(cursor_position);
1040 let settings = all_language_settings(file, cx);
1041 settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1042 }
1043
1044 fn needs_terms_acceptance(&self, cx: &AppContext) -> bool {
1045 !self.zeta.read(cx).tos_accepted
1046 }
1047
1048 fn is_refreshing(&self) -> bool {
1049 !self.pending_completions.is_empty()
1050 }
1051
    /// Requests a new completion for `buffer` at `position` and decides whether it should
    /// become the current one.
    ///
    /// When `debounce` is set, the request is issued after `DEBOUNCE_TIMEOUT`. Nothing
    /// happens until the user has accepted the terms of service. At most two requests are
    /// kept pending: the oldest one plus the newest, so a burst of refreshes keeps replacing
    /// only the second slot.
    fn refresh(
        &mut self,
        buffer: Model<Buffer>,
        position: language::Anchor,
        debounce: bool,
        cx: &mut ModelContext<Self>,
    ) {
        if !self.zeta.read(cx).tos_accepted {
            return;
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;

        let task = cx.spawn(|this, mut cx| async move {
            if debounce {
                cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
            }

            let completion_request = this.update(&mut cx, |this, cx| {
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(&buffer, position, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|completion| CurrentInlineCompletion {
                        buffer_id: buffer.entity_id(),
                        completion,
                    })
                }
                Err(error) => Err(error),
            };

            this.update(&mut cx, |this, cx| {
                // If this is the oldest pending request, only it is removed. Otherwise the
                // newest request finished first and both pending entries are dropped
                // (presumably because the older result would be stale).
                // NOTE(review): `[0]` is not bounds-checked; this assumes the running task is
                // still stored in `pending_completions` — confirm.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                if let Some(new_completion) = completion.context("zeta prediction failed").log_err()
                {
                    // A new completion is shown when there is no current one, or when it
                    // supersedes the current one.
                    if let Some(old_completion) = this.current_completion.as_ref() {
                        let snapshot = buffer.read(cx).snapshot();
                        if new_completion.should_replace_completion(&old_completion, &snapshot) {
                            this.zeta.update(cx, |zeta, cx| {
                                zeta.completion_shown(&new_completion.completion, cx);
                            });
                            this.current_completion = Some(new_completion);
                        }
                    } else {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }
1133
1134 fn cycle(
1135 &mut self,
1136 _buffer: Model<Buffer>,
1137 _cursor_position: language::Anchor,
1138 _direction: inline_completion::Direction,
1139 _cx: &mut ModelContext<Self>,
1140 ) {
1141 // Right now we don't support cycling.
1142 }
1143
1144 fn accept(&mut self, _cx: &mut ModelContext<Self>) {
1145 self.pending_completions.clear();
1146 }
1147
1148 fn discard(&mut self, _cx: &mut ModelContext<Self>) {
1149 self.pending_completions.clear();
1150 self.current_completion.take();
1151 }
1152
1153 fn suggest(
1154 &mut self,
1155 buffer: &Model<Buffer>,
1156 cursor_position: language::Anchor,
1157 cx: &mut ModelContext<Self>,
1158 ) -> Option<inline_completion::InlineCompletion> {
1159 let CurrentInlineCompletion {
1160 buffer_id,
1161 completion,
1162 ..
1163 } = self.current_completion.as_mut()?;
1164
1165 // Invalidate previous completion if it was generated for a different buffer.
1166 if *buffer_id != buffer.entity_id() {
1167 self.current_completion.take();
1168 return None;
1169 }
1170
1171 let buffer = buffer.read(cx);
1172 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1173 self.current_completion.take();
1174 return None;
1175 };
1176
1177 let cursor_row = cursor_position.to_point(buffer).row;
1178 let (closest_edit_ix, (closest_edit_range, _)) =
1179 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1180 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1181 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1182 cmp::min(distance_from_start, distance_from_end)
1183 })?;
1184
1185 let mut edit_start_ix = closest_edit_ix;
1186 for (range, _) in edits[..edit_start_ix].iter().rev() {
1187 let distance_from_closest_edit =
1188 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1189 if distance_from_closest_edit <= 1 {
1190 edit_start_ix -= 1;
1191 } else {
1192 break;
1193 }
1194 }
1195
1196 let mut edit_end_ix = closest_edit_ix + 1;
1197 for (range, _) in &edits[edit_end_ix..] {
1198 let distance_from_closest_edit =
1199 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1200 if distance_from_closest_edit <= 1 {
1201 edit_end_ix += 1;
1202 } else {
1203 break;
1204 }
1205 }
1206
1207 Some(inline_completion::InlineCompletion {
1208 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1209 })
1210 }
1211}
1212
// Tests for `InlineCompletion::interpolate` and for `Zeta::request_completion`, the
// latter driven by a fake HTTP client and a fake RPC server.
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Checks how a fixed completion's edits are re-expressed by `interpolate` as the
    /// buffer is edited underneath it.
    ///
    /// The completion was built against "Lorem ipsum dolor". It does two things:
    /// - replaces "rem" (2..5) with "REM";
    /// - deletes "um" (9..11).
    ///
    /// Each step below edits the buffer and asserts the interpolated edits, expressed as
    /// offsets into the buffer's current text.
    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let completion = InlineCompletion {
            edits: to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into(),
            path: Path::new("").into(),
            snapshot: buffer.read(cx).snapshot(),
            id: InlineCompletionId::new(),
            // The remaining fields are placeholders; this test only calls `interpolate`.
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Unchanged buffer: the edits come back exactly as constructed.
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Delete "rem" ("Lo ipsum dolor").
        // "REM" becomes a pure insertion at 2 and the deletion shifts left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
        );

        // Undo restores the original text, and with it the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Replace "rem" with "R", a prefix of "REM" ("LoR ipsum dolor").
        // Only "EM" remains to insert; the deletion shifts left by 2.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
        );

        // Type "E" ("LoRE ipsum dolor"): only "M" remains to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Type "M" ("LoREM ipsum dolor").
        // The first edit is fully applied; only the deletion of "um" remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".to_string())]
        );

        // Delete the "M" just typed: the "M" insertion is proposed again.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Delete "um" (8..10) by hand: only the "M" insertion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string())]
        );

        // Delete " i" (4..6), which the completion never proposed.
        // Interpolation is expected to give up entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
    }

    /// Requests a completion whose editable region runs to the end of the buffer.
    ///
    /// Checks that applying the returned edits produces the text predicted by the
    /// canned model response.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        // Canned model output: the editable region rewrites "lorem\n" as "lorem\nipsum".
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request gets a 200 response whose JSON body carries the canned excerpt.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new_model(|cx| Zeta::new(client, user_store, cx));

        let buffer = cx.new_model(|cx| Buffer::local(buffer_content, cx));
        // Row 1, column 0 of "lorem\n" is the very end of the buffer.
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task =
            zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));

        // The request asks the fake server for an LLM token; reply with an empty token.
        // (Presumably the token is fetched before the HTTP call; confirm in request_completion.)
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`'s current text.
    ///
    /// Range starts use `anchor_after` and range ends use `anchor_before`. Presumably
    /// this keeps text inserted exactly at a range boundary outside the range; confirm
    /// against the anchor bias semantics.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset ranges against `buffer`'s current text,
    /// so assertions can compare plain offsets.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Runs when the test binary is loaded (via `#[ctor::ctor]`).
    ///
    /// Initializes `env_logger` only if `RUST_LOG` is set, so test output stays quiet
    /// by default.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}