1mod completion_diff_element;
2mod rate_completion_modal;
3
4pub(crate) use completion_diff_element::*;
5pub use rate_completion_modal::*;
6
7use anyhow::{anyhow, Context as _, Result};
8use arrayvec::ArrayVec;
9use client::Client;
10use collections::{HashMap, HashSet, VecDeque};
11use futures::AsyncReadExt;
12use gpui::{
13 actions, AppContext, AsyncAppContext, Context, EntityId, Global, Model, ModelContext,
14 Subscription, Task,
15};
16use http_client::{HttpClient, Method};
17use language::{
18 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
19 Point, ToOffset, ToPoint,
20};
21use language_models::LlmApiToken;
22use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
23use std::{
24 borrow::Cow,
25 cmp,
26 fmt::Write,
27 future::Future,
28 mem,
29 ops::Range,
30 path::Path,
31 sync::Arc,
32 time::{Duration, Instant},
33};
34use telemetry_events::InlineCompletionRating;
35use util::ResultExt;
36use uuid::Uuid;
37
/// Marker placed in the prompt excerpt at the cursor position; stripped from model output.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker emitted before the editable region when the excerpt starts at offset 0.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the region of the excerpt that the model is asked to rewrite.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the region of the excerpt that the model is asked to rewrite.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Successive edits to the same buffer no more than this far apart are merged into one event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
43
// Declares the `ClearHistory` action in the `zeta` namespace via gpui's `actions!` macro.
// NOTE(review): presumably dispatched to `Zeta::clear_history`; the handler is not
// registered in this chunk — confirm where it is wired up.
actions!(zeta, [ClearHistory]);
45
46#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
47pub struct InlineCompletionId(Uuid);
48
49impl From<InlineCompletionId> for gpui::ElementId {
50 fn from(value: InlineCompletionId) -> Self {
51 gpui::ElementId::Uuid(value.0)
52 }
53}
54
55impl std::fmt::Display for InlineCompletionId {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(f, "{}", self.0)
58 }
59}
60
61impl InlineCompletionId {
62 fn new() -> Self {
63 Self(Uuid::new_v4())
64 }
65}
66
/// App-wide handle to the single `Zeta` model, installed by `Zeta::register`
/// and read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Model<Zeta>);

impl Global for ZetaGlobal {}
71
/// A completion returned by the prediction service, together with the inputs that
/// produced it (kept for interpolation, display, and rating telemetry).
#[derive(Clone)]
pub struct InlineCompletion {
    /// Unique id; used to track whether the completion has been rated.
    id: InlineCompletionId,
    /// The buffer file's `full_path`, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range in `snapshot` covered by the editable excerpt sent to the model.
    excerpt_range: Range<usize>,
    /// Offset of the cursor in `snapshot` when the request was made.
    cursor_offset: usize,
    /// Edits proposed by the model, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the request was computed against.
    snapshot: BufferSnapshot,
    /// Outline prompt section that was sent.
    input_outline: Arc<str>,
    /// Recent-edits prompt section that was sent.
    input_events: Arc<str>,
    /// Excerpt prompt section that was sent.
    input_excerpt: Arc<str>,
    /// Raw model output, including region markers.
    output_excerpt: Arc<str>,
    /// When the request was started.
    request_sent_at: Instant,
    /// When the response had been processed into this completion.
    response_received_at: Instant,
}
87
impl InlineCompletion {
    /// Time between sending the prediction request and finishing processing of its response.
    fn latency(&self) -> Duration {
        self.response_received_at
            .duration_since(self.request_sent_at)
    }

    /// Re-expresses this completion's edits against `new_snapshot`, accounting for the
    /// edits the user made since the completion was requested (`self.snapshot`).
    ///
    /// For each model edit, in order:
    /// - user edits ending strictly before the model edit starts are skipped;
    /// - if the next user edit starts after the model edit ends, the model edit is kept as-is;
    /// - if the next user edit replaced exactly the model edit's range with text that is a
    ///   prefix of the model's replacement, only the remaining suffix (if non-empty) is kept,
    ///   as an insertion at the end of the user's new text;
    /// - any other overlap invalidates the completion.
    ///
    /// Returns `None` when the completion is invalidated or no edits remain.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
        let mut edits = Vec::new();

        // User edits are expressed in offsets of the completion's original snapshot.
        let mut user_edits = new_snapshot
            .edits_since::<usize>(&self.snapshot.version)
            .peekable();
        for (model_old_range, model_new_text) in self.edits.iter() {
            let model_offset_range = model_old_range.to_offset(&self.snapshot);
            // Skip user edits that lie entirely before this model edit.
            while let Some(next_user_edit) = user_edits.peek() {
                if next_user_edit.old.end < model_offset_range.start {
                    user_edits.next();
                } else {
                    break;
                }
            }

            if let Some(user_edit) = user_edits.peek() {
                if user_edit.old.start > model_offset_range.end {
                    // The user edit comes after this model edit; no conflict.
                    edits.push((model_old_range.clone(), model_new_text.clone()));
                } else if user_edit.old == model_offset_range {
                    let user_new_text = new_snapshot
                        .text_for_range(user_edit.new.clone())
                        .collect::<String>();

                    // The user is typing along with the suggestion: keep what is left of it.
                    if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                        if !model_suffix.is_empty() {
                            edits.push((
                                new_snapshot.anchor_after(user_edit.new.end)
                                    ..new_snapshot.anchor_before(user_edit.new.end),
                                model_suffix.into(),
                            ));
                        }

                        user_edits.next();
                    } else {
                        return None;
                    }
                } else {
                    // Partial overlap between user and model edits: give up.
                    return None;
                }
            } else {
                edits.push((model_old_range.clone(), model_new_text.clone()));
            }
        }

        if edits.is_empty() {
            None
        } else {
            Some(edits)
        }
    }
}
146
147impl std::fmt::Debug for InlineCompletion {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("InlineCompletion")
150 .field("id", &self.id)
151 .field("path", &self.path)
152 .field("edits", &self.edits)
153 .finish_non_exhaustive()
154 }
155}
156
/// Edit-prediction state: recent buffer-change history, tracked buffers,
/// recently shown completions, and the LLM API token used for requests.
pub struct Zeta {
    client: Arc<Client>,
    /// Recent buffer changes, rendered into the prompt by `request_completion_impl`.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by entity id, with the last reported snapshot.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions that were shown, most recent first (capped in `completion_shown`).
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of completions rated via `rate_completion`; pruned when evicted from `shown_completions`.
    rated_completions: HashSet<InlineCompletionId>,
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
}
166
167impl Zeta {
168 pub fn global(cx: &mut AppContext) -> Option<Model<Self>> {
169 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
170 }
171
172 pub fn register(client: Arc<Client>, cx: &mut AppContext) -> Model<Self> {
173 Self::global(cx).unwrap_or_else(|| {
174 let model = cx.new_model(|cx| Self::new(client, cx));
175 cx.set_global(ZetaGlobal(model.clone()));
176 model
177 })
178 }
179
    /// Forgets all recorded buffer-change events, so later prompts carry no prior edit history.
    pub fn clear_history(&mut self) {
        self.events.clear();
    }
183
184 fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
185 let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
186
187 Self {
188 client,
189 events: VecDeque::new(),
190 shown_completions: VecDeque::new(),
191 rated_completions: HashSet::default(),
192 registered_buffers: HashMap::default(),
193 llm_token: LlmApiToken::default(),
194 _llm_token_subscription: cx.subscribe(
195 &refresh_llm_token_listener,
196 |this, _listener, _event, cx| {
197 let client = this.client.clone();
198 let llm_token = this.llm_token.clone();
199 cx.spawn(|_this, _cx| async move {
200 llm_token.refresh(&client).await?;
201 anyhow::Ok(())
202 })
203 .detach_and_log_err(cx);
204 },
205 ),
206 }
207 }
208
209 fn push_event(&mut self, event: Event) {
210 const MAX_EVENT_COUNT: usize = 16;
211
212 if let Some(Event::BufferChange {
213 new_snapshot: last_new_snapshot,
214 timestamp: last_timestamp,
215 ..
216 }) = self.events.back_mut()
217 {
218 // Coalesce edits for the same buffer when they happen one after the other.
219 let Event::BufferChange {
220 old_snapshot,
221 new_snapshot,
222 timestamp,
223 } = &event;
224
225 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
226 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
227 && old_snapshot.version == last_new_snapshot.version
228 {
229 *last_new_snapshot = new_snapshot.clone();
230 *last_timestamp = *timestamp;
231 return;
232 }
233 }
234
235 self.events.push_back(event);
236 if self.events.len() >= MAX_EVENT_COUNT {
237 self.events.drain(..MAX_EVENT_COUNT / 2);
238 }
239 }
240
241 pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
242 let buffer_id = buffer.entity_id();
243 let weak_buffer = buffer.downgrade();
244
245 if let std::collections::hash_map::Entry::Vacant(entry) =
246 self.registered_buffers.entry(buffer_id)
247 {
248 let snapshot = buffer.read(cx).snapshot();
249
250 entry.insert(RegisteredBuffer {
251 snapshot,
252 _subscriptions: [
253 cx.subscribe(buffer, move |this, buffer, event, cx| {
254 this.handle_buffer_event(buffer, event, cx);
255 }),
256 cx.observe_release(buffer, move |this, _buffer, _cx| {
257 this.registered_buffers.remove(&weak_buffer.entity_id());
258 }),
259 ],
260 });
261 };
262 }
263
264 fn handle_buffer_event(
265 &mut self,
266 buffer: Model<Buffer>,
267 event: &language::BufferEvent,
268 cx: &mut ModelContext<Self>,
269 ) {
270 match event {
271 language::BufferEvent::Edited => {
272 self.report_changes_for_buffer(&buffer, cx);
273 }
274 _ => {}
275 }
276 }
277
278 pub fn request_completion_impl<F, R>(
279 &mut self,
280 buffer: &Model<Buffer>,
281 position: language::Anchor,
282 cx: &mut ModelContext<Self>,
283 perform_predict_edits: F,
284 ) -> Task<Result<InlineCompletion>>
285 where
286 F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
287 R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
288 {
289 let snapshot = self.report_changes_for_buffer(buffer, cx);
290 let point = position.to_point(&snapshot);
291 let offset = point.to_offset(&snapshot);
292 let excerpt_range = excerpt_range_for_position(point, &snapshot);
293 let events = self.events.clone();
294 let path = snapshot
295 .file()
296 .map(|f| Arc::from(f.full_path(cx).as_path()))
297 .unwrap_or_else(|| Arc::from(Path::new("untitled")));
298
299 let client = self.client.clone();
300 let llm_token = self.llm_token.clone();
301
302 cx.spawn(|_, cx| async move {
303 let request_sent_at = Instant::now();
304
305 let (input_events, input_excerpt, input_outline) = cx
306 .background_executor()
307 .spawn({
308 let snapshot = snapshot.clone();
309 let excerpt_range = excerpt_range.clone();
310 async move {
311 let mut input_events = String::new();
312 for event in events {
313 if !input_events.is_empty() {
314 input_events.push('\n');
315 input_events.push('\n');
316 }
317 input_events.push_str(&event.to_prompt());
318 }
319
320 let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
321 let input_outline = prompt_for_outline(&snapshot);
322
323 (input_events, input_excerpt, input_outline)
324 }
325 })
326 .await;
327
328 log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
329
330 let body = PredictEditsParams {
331 input_events: input_events.clone(),
332 input_excerpt: input_excerpt.clone(),
333 outline: Some(input_outline.clone()),
334 };
335
336 let response = perform_predict_edits(client, llm_token, body).await?;
337
338 let output_excerpt = response.output_excerpt;
339 log::debug!("completion response: {}", output_excerpt);
340
341 Self::process_completion_response(
342 output_excerpt,
343 &snapshot,
344 excerpt_range,
345 offset,
346 path,
347 input_outline,
348 input_events,
349 input_excerpt,
350 request_sent_at,
351 &cx,
352 )
353 .await
354 })
355 }
356
357 // Generates several example completions of various states to fill the Zeta completion modal
358 #[cfg(any(test, feature = "test-support"))]
359 pub fn fill_with_fake_completions(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
360 let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
361 And maybe a short line
362
363 Then a few lines
364
365 and then another
366 "#};
367
368 let buffer = cx.new_model(|cx| Buffer::local(test_buffer_text, cx));
369 let position = buffer.read(cx).anchor_before(Point::new(1, 0));
370
371 let completion_tasks = vec![
372 self.fake_completion(
373 &buffer,
374 position,
375 PredictEditsResponse {
376 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
377a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
378[here's an edit]
379And maybe a short line
380Then a few lines
381and then another
382{EDITABLE_REGION_END_MARKER}
383 ", ),
384 },
385 cx,
386 ),
387 self.fake_completion(
388 &buffer,
389 position,
390 PredictEditsResponse {
391 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
392a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
393And maybe a short line
394[and another edit]
395Then a few lines
396and then another
397{EDITABLE_REGION_END_MARKER}
398 "#),
399 },
400 cx,
401 ),
402 self.fake_completion(
403 &buffer,
404 position,
405 PredictEditsResponse {
406 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
407a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
408And maybe a short line
409
410Then a few lines
411
412and then another
413{EDITABLE_REGION_END_MARKER}
414 "#),
415 },
416 cx,
417 ),
418 self.fake_completion(
419 &buffer,
420 position,
421 PredictEditsResponse {
422 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
423a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
424And maybe a short line
425
426Then a few lines
427
428and then another
429{EDITABLE_REGION_END_MARKER}
430 "#),
431 },
432 cx,
433 ),
434 self.fake_completion(
435 &buffer,
436 position,
437 PredictEditsResponse {
438 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
439a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
440And maybe a short line
441Then a few lines
442[a third completion]
443and then another
444{EDITABLE_REGION_END_MARKER}
445 "#),
446 },
447 cx,
448 ),
449 self.fake_completion(
450 &buffer,
451 position,
452 PredictEditsResponse {
453 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
454a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
455And maybe a short line
456and then another
457[fourth completion example]
458{EDITABLE_REGION_END_MARKER}
459 "#),
460 },
461 cx,
462 ),
463 self.fake_completion(
464 &buffer,
465 position,
466 PredictEditsResponse {
467 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
468a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
469And maybe a short line
470Then a few lines
471and then another
472[fifth and final completion]
473{EDITABLE_REGION_END_MARKER}
474 "#),
475 },
476 cx,
477 ),
478 ];
479
480 cx.spawn(|zeta, mut cx| async move {
481 for task in completion_tasks {
482 task.await.unwrap();
483 }
484
485 zeta.update(&mut cx, |zeta, _cx| {
486 zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
487 zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
488 })
489 .ok();
490 })
491 }
492
493 #[cfg(any(test, feature = "test-support"))]
494 pub fn fake_completion(
495 &mut self,
496 buffer: &Model<Buffer>,
497 position: language::Anchor,
498 response: PredictEditsResponse,
499 cx: &mut ModelContext<Self>,
500 ) -> Task<Result<InlineCompletion>> {
501 use std::future::ready;
502
503 self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response)))
504 }
505
506 pub fn request_completion(
507 &mut self,
508 buffer: &Model<Buffer>,
509 position: language::Anchor,
510 cx: &mut ModelContext<Self>,
511 ) -> Task<Result<InlineCompletion>> {
512 self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
513 }
514
515 fn perform_predict_edits(
516 client: Arc<Client>,
517 llm_token: LlmApiToken,
518 body: PredictEditsParams,
519 ) -> impl Future<Output = Result<PredictEditsResponse>> {
520 async move {
521 let http_client = client.http_client();
522 let mut token = llm_token.acquire(&client).await?;
523 let mut did_retry = false;
524
525 loop {
526 let request_builder = http_client::Request::builder();
527 let request = request_builder
528 .method(Method::POST)
529 .uri(
530 http_client
531 .build_zed_llm_url("/predict_edits", &[])?
532 .as_ref(),
533 )
534 .header("Content-Type", "application/json")
535 .header("Authorization", format!("Bearer {}", token))
536 .body(serde_json::to_string(&body)?.into())?;
537
538 let mut response = http_client.send(request).await?;
539
540 if response.status().is_success() {
541 let mut body = String::new();
542 response.body_mut().read_to_string(&mut body).await?;
543 return Ok(serde_json::from_str(&body)?);
544 } else if !did_retry
545 && response
546 .headers()
547 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
548 .is_some()
549 {
550 did_retry = true;
551 token = llm_token.refresh(&client).await?;
552 } else {
553 let mut body = String::new();
554 response.body_mut().read_to_string(&mut body).await?;
555 return Err(anyhow!(
556 "error predicting edits.\nStatus: {:?}\nBody: {}",
557 response.status(),
558 body
559 ));
560 }
561 }
562 }
563 }
564
565 #[allow(clippy::too_many_arguments)]
566 fn process_completion_response(
567 output_excerpt: String,
568 snapshot: &BufferSnapshot,
569 excerpt_range: Range<usize>,
570 cursor_offset: usize,
571 path: Arc<Path>,
572 input_outline: String,
573 input_events: String,
574 input_excerpt: String,
575 request_sent_at: Instant,
576 cx: &AsyncAppContext,
577 ) -> Task<Result<InlineCompletion>> {
578 let snapshot = snapshot.clone();
579 cx.background_executor().spawn(async move {
580 let content = output_excerpt.replace(CURSOR_MARKER, "");
581
582 let start_markers = content
583 .match_indices(EDITABLE_REGION_START_MARKER)
584 .collect::<Vec<_>>();
585 anyhow::ensure!(
586 start_markers.len() == 1,
587 "expected exactly one start marker, found {}",
588 start_markers.len()
589 );
590
591 let end_markers = content
592 .match_indices(EDITABLE_REGION_END_MARKER)
593 .collect::<Vec<_>>();
594 anyhow::ensure!(
595 end_markers.len() == 1,
596 "expected exactly one end marker, found {}",
597 end_markers.len()
598 );
599
600 let sof_markers = content
601 .match_indices(START_OF_FILE_MARKER)
602 .collect::<Vec<_>>();
603 anyhow::ensure!(
604 sof_markers.len() <= 1,
605 "expected at most one start-of-file marker, found {}",
606 sof_markers.len()
607 );
608
609 let codefence_start = start_markers[0].0;
610 let content = &content[codefence_start..];
611
612 let newline_ix = content.find('\n').context("could not find newline")?;
613 let content = &content[newline_ix + 1..];
614
615 let codefence_end = content
616 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
617 .context("could not find end marker")?;
618 let new_text = &content[..codefence_end];
619
620 let old_text = snapshot
621 .text_for_range(excerpt_range.clone())
622 .collect::<String>();
623
624 let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
625
626 Ok(InlineCompletion {
627 id: InlineCompletionId::new(),
628 path,
629 excerpt_range,
630 cursor_offset,
631 edits: edits.into(),
632 snapshot: snapshot.clone(),
633 input_outline: input_outline.into(),
634 input_events: input_events.into(),
635 input_excerpt: input_excerpt.into(),
636 output_excerpt: output_excerpt.into(),
637 request_sent_at,
638 response_received_at: Instant::now(),
639 })
640 })
641 }
642
643 pub fn compute_edits(
644 old_text: String,
645 new_text: &str,
646 offset: usize,
647 snapshot: &BufferSnapshot,
648 ) -> Vec<(Range<Anchor>, String)> {
649 let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
650
651 let mut edits: Vec<(Range<usize>, String)> = Vec::new();
652 let mut old_start = offset;
653 for change in diff.iter_all_changes() {
654 let value = change.value();
655 match change.tag() {
656 similar::ChangeTag::Equal => {
657 old_start += value.len();
658 }
659 similar::ChangeTag::Delete => {
660 let old_end = old_start + value.len();
661 if let Some((last_old_range, _)) = edits.last_mut() {
662 if last_old_range.end == old_start {
663 last_old_range.end = old_end;
664 } else {
665 edits.push((old_start..old_end, String::new()));
666 }
667 } else {
668 edits.push((old_start..old_end, String::new()));
669 }
670 old_start = old_end;
671 }
672 similar::ChangeTag::Insert => {
673 if let Some((last_old_range, last_new_text)) = edits.last_mut() {
674 if last_old_range.end == old_start {
675 last_new_text.push_str(value);
676 } else {
677 edits.push((old_start..old_start, value.into()));
678 }
679 } else {
680 edits.push((old_start..old_start, value.into()));
681 }
682 }
683 }
684 }
685
686 edits
687 .into_iter()
688 .map(|(mut old_range, new_text)| {
689 let prefix_len = common_prefix(
690 snapshot.chars_for_range(old_range.clone()),
691 new_text.chars(),
692 );
693 old_range.start += prefix_len;
694 let suffix_len = common_prefix(
695 snapshot.reversed_chars_for_range(old_range.clone()),
696 new_text[prefix_len..].chars().rev(),
697 );
698 old_range.end = old_range.end.saturating_sub(suffix_len);
699
700 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
701 (
702 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
703 new_text,
704 )
705 })
706 .collect()
707 }
708
    /// Returns whether `rate_completion` has been called for `completion_id`
    /// (and the completion has not since been evicted from the shown history).
    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
712
713 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut ModelContext<Self>) {
714 self.shown_completions.push_front(completion.clone());
715 if self.shown_completions.len() > 50 {
716 let completion = self.shown_completions.pop_back().unwrap();
717 self.rated_completions.remove(&completion.id);
718 }
719 cx.notify();
720 }
721
    /// Marks `completion` as rated and sends an "Inline Completion Rated" telemetry event.
    ///
    /// The event carries the rating, the completion's prompt inputs and raw output, and the
    /// user's `feedback`. Telemetry is flushed immediately, and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut ModelContext<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Inline Completion Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events();
        cx.notify();
    }
742
    /// Iterates over shown completions, most recent first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
        self.shown_completions.iter()
    }

    /// Number of shown completions currently retained.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
750
751 fn report_changes_for_buffer(
752 &mut self,
753 buffer: &Model<Buffer>,
754 cx: &mut ModelContext<Self>,
755 ) -> BufferSnapshot {
756 self.register_buffer(buffer, cx);
757
758 let registered_buffer = self
759 .registered_buffers
760 .get_mut(&buffer.entity_id())
761 .unwrap();
762 let new_snapshot = buffer.read(cx).snapshot();
763
764 if new_snapshot.version != registered_buffer.snapshot.version {
765 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
766 self.push_event(Event::BufferChange {
767 old_snapshot,
768 new_snapshot: new_snapshot.clone(),
769 timestamp: Instant::now(),
770 });
771 }
772
773 new_snapshot
774 }
775}
776
/// Returns the length in UTF-8 bytes of the longest common prefix of two char sequences.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
783
784fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
785 let mut input_outline = String::new();
786
787 writeln!(
788 input_outline,
789 "```{}",
790 snapshot
791 .file()
792 .map_or(Cow::Borrowed("untitled"), |file| file
793 .path()
794 .to_string_lossy())
795 )
796 .unwrap();
797
798 if let Some(outline) = snapshot.outline(None) {
799 let guess_size = outline.items.len() * 15;
800 input_outline.reserve(guess_size);
801 for item in outline.items.iter() {
802 let spacing = " ".repeat(item.depth);
803 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
804 }
805 }
806
807 writeln!(input_outline, "```").unwrap();
808
809 input_outline
810}
811
/// Renders the excerpt prompt: a fenced block labelled with the file's path (or `untitled`).
///
/// The block contains, in order:
/// - a start-of-file marker if the excerpt starts at offset 0;
/// - one line of context preceding the excerpt, if that line exists and is not blank;
/// - the editable region between its start/end markers, with `CURSOR_MARKER` inserted at
///   `offset`;
/// - one line of context following the excerpt, if that line exists and is not blank.
fn prompt_for_excerpt(
    snapshot: &BufferSnapshot,
    excerpt_range: &Range<usize>,
    offset: usize,
) -> String {
    let mut prompt_excerpt = String::new();
    writeln!(
        prompt_excerpt,
        "```{}",
        snapshot
            .file()
            .map_or(Cow::Borrowed("untitled"), |file| file
                .path()
                .to_string_lossy())
    )
    .unwrap();

    if excerpt_range.start == 0 {
        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
    }

    let point_range = excerpt_range.to_point(snapshot);
    // Previous line plus its newline, as read-only context.
    if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
        let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
        for chunk in snapshot.text_for_range(extra_context_line_range) {
            prompt_excerpt.push_str(chunk);
        }
    }
    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
        prompt_excerpt.push_str(chunk);
    }
    prompt_excerpt.push_str(CURSOR_MARKER);
    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
        prompt_excerpt.push_str(chunk);
    }
    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();

    // Newline plus the following line, as read-only context.
    if point_range.end.row < snapshot.max_point().row
        && !snapshot.is_line_blank(point_range.end.row + 1)
    {
        let extra_context_line_range = point_range.end
            ..Point::new(
                point_range.end.row + 1,
                snapshot.line_len(point_range.end.row + 1),
            );
        for chunk in snapshot.text_for_range(extra_context_line_range) {
            prompt_excerpt.push_str(chunk);
        }
    }

    write!(prompt_excerpt, "\n```").unwrap();
    prompt_excerpt
}
866
867fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
868 const CONTEXT_LINES: u32 = 32;
869
870 let mut context_lines_before = CONTEXT_LINES;
871 let mut context_lines_after = CONTEXT_LINES;
872 if point.row < CONTEXT_LINES {
873 context_lines_after += CONTEXT_LINES - point.row;
874 } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
875 context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
876 }
877
878 let excerpt_start_row = point.row.saturating_sub(context_lines_before);
879 let excerpt_start = Point::new(excerpt_start_row, 0);
880 let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
881 let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
882 excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
883}
884
/// Tracking state for a buffer registered with `Zeta`.
struct RegisteredBuffer {
    /// Snapshot as of the last `report_changes_for_buffer` call (or registration).
    snapshot: BufferSnapshot,
    /// Buffer-event subscription and release observer; kept alive while registered.
    _subscriptions: [gpui::Subscription; 2],
}
889
/// An entry in `Zeta`'s edit history.
#[derive(Clone)]
enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. Consecutive changes may be
    /// coalesced into one event by `push_event`; `timestamp` is when the latest merged
    /// change was recorded.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
898
899impl Event {
900 fn to_prompt(&self) -> String {
901 match self {
902 Event::BufferChange {
903 old_snapshot,
904 new_snapshot,
905 ..
906 } => {
907 let mut prompt = String::new();
908
909 let old_path = old_snapshot
910 .file()
911 .map(|f| f.path().as_ref())
912 .unwrap_or(Path::new("untitled"));
913 let new_path = new_snapshot
914 .file()
915 .map(|f| f.path().as_ref())
916 .unwrap_or(Path::new("untitled"));
917 if old_path != new_path {
918 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
919 }
920
921 let diff =
922 similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
923 .unified_diff()
924 .to_string();
925 if !diff.is_empty() {
926 write!(
927 prompt,
928 "User edited {:?}:\n```diff\n{}\n```",
929 new_path, diff
930 )
931 .unwrap();
932 }
933
934 prompt
935 }
936 }
937 }
938}
939
/// The completion currently offered by the provider, tagged with the buffer it targets.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion was requested for.
    buffer_id: EntityId,
    completion: InlineCompletion,
}
945
946impl CurrentInlineCompletion {
947 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
948 if self.buffer_id != old_completion.buffer_id {
949 return true;
950 }
951
952 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
953 return true;
954 };
955 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
956 return false;
957 };
958
959 if old_edits.len() == 1 && new_edits.len() == 1 {
960 let (old_range, old_text) = &old_edits[0];
961 let (new_range, new_text) = &new_edits[0];
962 new_range == old_range && new_text.starts_with(old_text)
963 } else {
964 true
965 }
966 }
967}
968
/// An in-flight completion request started by `refresh`.
struct PendingCompletion {
    /// Sequence number assigned by `refresh` from `next_pending_completion_id`.
    id: usize,
    /// The request task, held so it stays alive while pending.
    /// NOTE(review): dropping it presumably cancels the request (gpui `Task` semantics) — confirm.
    _task: Task<()>,
}
973
/// Inline-completion provider backed by a `Zeta` model.
pub struct ZetaInlineCompletionProvider {
    zeta: Model<Zeta>,
    /// At most two in-flight requests, oldest first (see `refresh`).
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request started by `refresh`.
    next_pending_completion_id: usize,
    /// Completion currently offered to the editor, if any.
    current_completion: Option<CurrentInlineCompletion>,
}
980
impl ZetaInlineCompletionProvider {
    /// Delay awaited before requesting when `refresh` is called with `debounce` set.
    pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);

    /// Creates a provider with no pending requests and no current completion.
    pub fn new(zeta: Model<Zeta>) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
        }
    }
}
993
994impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
    /// Internal identifier of this provider.
    fn name() -> &'static str {
        "zeta"
    }

    /// Human-readable provider name.
    fn display_name() -> &'static str {
        "Zeta"
    }

    // NOTE(review): the two flags below are consumed by the `inline_completion` trait's
    // callers (not shown here); their exact UI effect should be confirmed there.
    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_completions_in_normal_mode() -> bool {
        true
    }
1010
1011 fn is_enabled(
1012 &self,
1013 buffer: &Model<Buffer>,
1014 cursor_position: language::Anchor,
1015 cx: &AppContext,
1016 ) -> bool {
1017 let buffer = buffer.read(cx);
1018 let file = buffer.file();
1019 let language = buffer.language_at(cursor_position);
1020 let settings = all_language_settings(file, cx);
1021 settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1022 }
1023
    /// True while any completion request started by `refresh` is still pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }
1027
1028 fn refresh(
1029 &mut self,
1030 buffer: Model<Buffer>,
1031 position: language::Anchor,
1032 debounce: bool,
1033 cx: &mut ModelContext<Self>,
1034 ) {
1035 let pending_completion_id = self.next_pending_completion_id;
1036 self.next_pending_completion_id += 1;
1037
1038 let task = cx.spawn(|this, mut cx| async move {
1039 if debounce {
1040 cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1041 }
1042
1043 let completion_request = this.update(&mut cx, |this, cx| {
1044 this.zeta.update(cx, |zeta, cx| {
1045 zeta.request_completion(&buffer, position, cx)
1046 })
1047 });
1048
1049 let completion = match completion_request {
1050 Ok(completion_request) => {
1051 let completion_request = completion_request.await;
1052 completion_request.map(|completion| CurrentInlineCompletion {
1053 buffer_id: buffer.entity_id(),
1054 completion,
1055 })
1056 }
1057 Err(error) => Err(error),
1058 };
1059
1060 this.update(&mut cx, |this, cx| {
1061 if this.pending_completions[0].id == pending_completion_id {
1062 this.pending_completions.remove(0);
1063 } else {
1064 this.pending_completions.clear();
1065 }
1066
1067 if let Some(new_completion) = completion.context("zeta prediction failed").log_err()
1068 {
1069 if let Some(old_completion) = this.current_completion.as_ref() {
1070 let snapshot = buffer.read(cx).snapshot();
1071 if new_completion.should_replace_completion(&old_completion, &snapshot) {
1072 this.zeta.update(cx, |zeta, cx| {
1073 zeta.completion_shown(&new_completion.completion, cx);
1074 });
1075 this.current_completion = Some(new_completion);
1076 }
1077 } else {
1078 this.zeta.update(cx, |zeta, cx| {
1079 zeta.completion_shown(&new_completion.completion, cx);
1080 });
1081 this.current_completion = Some(new_completion);
1082 }
1083 }
1084
1085 cx.notify();
1086 })
1087 .ok();
1088 });
1089
1090 // We always maintain at most two pending completions. When we already
1091 // have two, we replace the newest one.
1092 if self.pending_completions.len() <= 1 {
1093 self.pending_completions.push(PendingCompletion {
1094 id: pending_completion_id,
1095 _task: task,
1096 });
1097 } else if self.pending_completions.len() == 2 {
1098 self.pending_completions.pop();
1099 self.pending_completions.push(PendingCompletion {
1100 id: pending_completion_id,
1101 _task: task,
1102 });
1103 }
1104 }
1105
    /// Cycling between alternative completions is not supported; this is a no-op.
    fn cycle(
        &mut self,
        _buffer: Model<Buffer>,
        _cursor_position: language::Anchor,
        _direction: inline_completion::Direction,
        _cx: &mut ModelContext<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Drops all pending requests; the current completion is left in place.
    fn accept(&mut self, _cx: &mut ModelContext<Self>) {
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current completion.
    fn discard(&mut self, _cx: &mut ModelContext<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }
1124
1125 fn suggest(
1126 &mut self,
1127 buffer: &Model<Buffer>,
1128 cursor_position: language::Anchor,
1129 cx: &mut ModelContext<Self>,
1130 ) -> Option<inline_completion::InlineCompletion> {
1131 let CurrentInlineCompletion {
1132 buffer_id,
1133 completion,
1134 ..
1135 } = self.current_completion.as_mut()?;
1136
1137 // Invalidate previous completion if it was generated for a different buffer.
1138 if *buffer_id != buffer.entity_id() {
1139 self.current_completion.take();
1140 return None;
1141 }
1142
1143 let buffer = buffer.read(cx);
1144 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1145 self.current_completion.take();
1146 return None;
1147 };
1148
1149 let cursor_row = cursor_position.to_point(buffer).row;
1150 let (closest_edit_ix, (closest_edit_range, _)) =
1151 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1152 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1153 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1154 cmp::min(distance_from_start, distance_from_end)
1155 })?;
1156
1157 let mut edit_start_ix = closest_edit_ix;
1158 for (range, _) in edits[..edit_start_ix].iter().rev() {
1159 let distance_from_closest_edit =
1160 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1161 if distance_from_closest_edit <= 1 {
1162 edit_start_ix -= 1;
1163 } else {
1164 break;
1165 }
1166 }
1167
1168 let mut edit_end_ix = closest_edit_ix + 1;
1169 for (range, _) in &edits[edit_end_ix..] {
1170 let distance_from_closest_edit =
1171 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1172 if distance_from_closest_edit <= 1 {
1173 edit_end_ix += 1;
1174 } else {
1175 break;
1176 }
1177 }
1178
1179 Some(inline_completion::InlineCompletion {
1180 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1181 })
1182 }
1183}
1184
// Tests for `InlineCompletion::interpolate` and for `Zeta::request_completion`
// driven by a fake HTTP client and a fake RPC server.
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Checks how a completion's edits are re-mapped by `interpolate` as the user edits
    /// the buffer:
    /// - they shift with unrelated edits;
    /// - they shrink as the user types the predicted text;
    /// - they become `None` after an edit the completion does not predict.
    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Predicted edits: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let completion = InlineCompletion {
            edits: to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into(),
            path: Path::new("").into(),
            snapshot: buffer.read(cx).snapshot(),
            id: InlineCompletionId::new(),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Unmodified buffer: the edits come back unchanged.
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // The user deletes "rem": the replacement becomes an insertion at 2 and the
        // deletion shifts left by three characters.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
        );

        // Undo restores the original text and therefore the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // The user types "R" over "rem": only the untyped suffix "EM" is still suggested.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
        );

        // Typing "E" leaves only "M" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Typing "M" completes the replacement; only the deletion of "um" remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".to_string())]
        );

        // Deleting the typed "M" brings the "M" insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // The user performs the predicted deletion of "um"; only the "M" insertion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string())]
        );

        // Deleting " i" (4..6), an edit the completion does not predict, yields `None`.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
    }

    /// Requests a completion for a buffer that ends with a newline, with the cursor on
    /// the final empty line. Checks that applying the returned edits appends "ipsum".
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        // Canned model output: the editable region gains an "ipsum" line after "lorem".
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request is answered with status 200 and the canned response above.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new_model(|cx| Zeta::new(client, cx));
        let buffer = cx.new_model(|cx| Buffer::local(buffer_content, cx));
        // Point (1, 0) is the empty line that follows the trailing newline.
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task =
            zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));

        // The request asks the RPC server for an LLM token; answer with an empty token.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        // Apply the predicted edits; the expected text has no trailing newline.
        let completion = completion_task.await.unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`'s current text.
    ///
    /// Each range start uses `anchor_after` and each range end uses `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offset ranges in `buffer`'s current text,
    /// so that tests can compare them against plain `usize` ranges.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Registered with `ctor` so that it runs before the tests. It enables
    /// `env_logger` output only when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}