1mod rate_completion_modal;
2
3pub use rate_completion_modal::*;
4
5use anyhow::{anyhow, Context as _, Result};
6use arrayvec::ArrayVec;
7use client::Client;
8use collections::{HashMap, HashSet, VecDeque};
9use futures::AsyncReadExt;
10use gpui::{
11 actions, AppContext, AsyncAppContext, Context, EntityId, Global, Model, ModelContext,
12 Subscription, Task,
13};
14use http_client::{HttpClient, Method};
15use language::{
16 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
17 Point, ToOffset, ToPoint,
18};
19use language_models::LlmApiToken;
20use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
21use std::{
22 borrow::Cow,
23 cmp,
24 fmt::Write,
25 future::Future,
26 mem,
27 ops::Range,
28 path::Path,
29 sync::Arc,
30 time::{Duration, Instant},
31};
32use telemetry_events::InlineCompletionRating;
33use uuid::Uuid;
34
// Special tokens exchanged with the Zeta model. The prompt wraps the editable
// excerpt between the region markers and marks the cursor position; the
// model's response is expected to use the same region markers.

/// Marks the user's cursor position inside the editable region.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Emitted before the excerpt when it begins at byte 0 of the buffer.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the region of the excerpt the model is allowed to rewrite.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the region of the excerpt the model is allowed to rewrite.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer within this window are coalesced
/// into a single history event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
40
// Declares the `ClearHistory` action in the `zeta` action namespace.
// NOTE(review): presumably dispatched to `Zeta::clear_history`; the handler
// registration is not visible in this file chunk.
actions!(zeta, [ClearHistory]);
42
43#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
44pub struct InlineCompletionId(Uuid);
45
46impl From<InlineCompletionId> for gpui::ElementId {
47 fn from(value: InlineCompletionId) -> Self {
48 gpui::ElementId::Uuid(value.0)
49 }
50}
51
52impl std::fmt::Display for InlineCompletionId {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "{}", self.0)
55 }
56}
57
58impl InlineCompletionId {
59 fn new() -> Self {
60 Self(Uuid::new_v4())
61 }
62}
63
64#[derive(Clone)]
65struct ZetaGlobal(Model<Zeta>);
66
67impl Global for ZetaGlobal {}
68
69#[derive(Clone)]
70pub struct InlineCompletion {
71 id: InlineCompletionId,
72 path: Arc<Path>,
73 excerpt_range: Range<usize>,
74 edits: Arc<[(Range<Anchor>, String)]>,
75 snapshot: BufferSnapshot,
76 input_events: Arc<str>,
77 input_excerpt: Arc<str>,
78 output_excerpt: Arc<str>,
79 request_sent_at: Instant,
80 response_received_at: Instant,
81}
82
83impl InlineCompletion {
84 fn latency(&self) -> Duration {
85 self.response_received_at
86 .duration_since(self.request_sent_at)
87 }
88
89 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
90 let mut edits = Vec::new();
91
92 let mut user_edits = new_snapshot
93 .edits_since::<usize>(&self.snapshot.version)
94 .peekable();
95 for (model_old_range, model_new_text) in self.edits.iter() {
96 let model_offset_range = model_old_range.to_offset(&self.snapshot);
97 while let Some(next_user_edit) = user_edits.peek() {
98 if next_user_edit.old.end < model_offset_range.start {
99 user_edits.next();
100 } else {
101 break;
102 }
103 }
104
105 if let Some(user_edit) = user_edits.peek() {
106 if user_edit.old.start > model_offset_range.end {
107 edits.push((model_old_range.clone(), model_new_text.clone()));
108 } else if user_edit.old == model_offset_range {
109 let user_new_text = new_snapshot
110 .text_for_range(user_edit.new.clone())
111 .collect::<String>();
112
113 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
114 if !model_suffix.is_empty() {
115 edits.push((
116 new_snapshot.anchor_after(user_edit.new.end)
117 ..new_snapshot.anchor_before(user_edit.new.end),
118 model_suffix.into(),
119 ));
120 }
121
122 user_edits.next();
123 } else {
124 return None;
125 }
126 } else {
127 return None;
128 }
129 } else {
130 edits.push((model_old_range.clone(), model_new_text.clone()));
131 }
132 }
133
134 if edits.is_empty() {
135 None
136 } else {
137 Some(edits)
138 }
139 }
140}
141
142impl std::fmt::Debug for InlineCompletion {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 f.debug_struct("InlineCompletion")
145 .field("id", &self.id)
146 .field("path", &self.path)
147 .field("edits", &self.edits)
148 .finish_non_exhaustive()
149 }
150}
151
152pub struct Zeta {
153 client: Arc<Client>,
154 events: VecDeque<Event>,
155 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
156 recent_completions: VecDeque<InlineCompletion>,
157 rated_completions: HashSet<InlineCompletionId>,
158 shown_completions: HashSet<InlineCompletionId>,
159 llm_token: LlmApiToken,
160 _llm_token_subscription: Subscription,
161}
162
163impl Zeta {
164 pub fn global(cx: &mut AppContext) -> Option<Model<Self>> {
165 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
166 }
167
168 pub fn register(client: Arc<Client>, cx: &mut AppContext) -> Model<Self> {
169 Self::global(cx).unwrap_or_else(|| {
170 let model = cx.new_model(|cx| Self::new(client, cx));
171 cx.set_global(ZetaGlobal(model.clone()));
172 model
173 })
174 }
175
176 pub fn clear_history(&mut self) {
177 self.events.clear();
178 }
179
180 fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
181 let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
182
183 Self {
184 client,
185 events: VecDeque::new(),
186 recent_completions: VecDeque::new(),
187 rated_completions: HashSet::default(),
188 shown_completions: HashSet::default(),
189 registered_buffers: HashMap::default(),
190 llm_token: LlmApiToken::default(),
191 _llm_token_subscription: cx.subscribe(
192 &refresh_llm_token_listener,
193 |this, _listener, _event, cx| {
194 let client = this.client.clone();
195 let llm_token = this.llm_token.clone();
196 cx.spawn(|_this, _cx| async move {
197 llm_token.refresh(&client).await?;
198 anyhow::Ok(())
199 })
200 .detach_and_log_err(cx);
201 },
202 ),
203 }
204 }
205
206 fn push_event(&mut self, event: Event) {
207 if let Some(Event::BufferChange {
208 new_snapshot: last_new_snapshot,
209 timestamp: last_timestamp,
210 ..
211 }) = self.events.back_mut()
212 {
213 // Coalesce edits for the same buffer when they happen one after the other.
214 let Event::BufferChange {
215 old_snapshot,
216 new_snapshot,
217 timestamp,
218 } = &event;
219
220 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
221 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
222 && old_snapshot.version == last_new_snapshot.version
223 {
224 *last_new_snapshot = new_snapshot.clone();
225 *last_timestamp = *timestamp;
226 return;
227 }
228 }
229
230 self.events.push_back(event);
231 if self.events.len() > 10 {
232 self.events.pop_front();
233 }
234 }
235
236 pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
237 let buffer_id = buffer.entity_id();
238 let weak_buffer = buffer.downgrade();
239
240 if let std::collections::hash_map::Entry::Vacant(entry) =
241 self.registered_buffers.entry(buffer_id)
242 {
243 let snapshot = buffer.read(cx).snapshot();
244
245 entry.insert(RegisteredBuffer {
246 snapshot,
247 _subscriptions: [
248 cx.subscribe(buffer, move |this, buffer, event, cx| {
249 this.handle_buffer_event(buffer, event, cx);
250 }),
251 cx.observe_release(buffer, move |this, _buffer, _cx| {
252 this.registered_buffers.remove(&weak_buffer.entity_id());
253 }),
254 ],
255 });
256 };
257 }
258
259 fn handle_buffer_event(
260 &mut self,
261 buffer: Model<Buffer>,
262 event: &language::BufferEvent,
263 cx: &mut ModelContext<Self>,
264 ) {
265 match event {
266 language::BufferEvent::Edited => {
267 self.report_changes_for_buffer(&buffer, cx);
268 }
269 _ => {}
270 }
271 }
272
/// Shared implementation behind `request_completion` and the test-only
/// `fake_completion`.
///
/// Flushes any pending change to `buffer` into the event history, builds the
/// prompt (recent edit events plus an excerpt around `position`), passes it
/// to `perform_predict_edits`, and converts the response into an
/// [`InlineCompletion`]. The result is also pushed onto the front of
/// `recent_completions`, which is capped at 50 entries.
///
/// `perform_predict_edits` receives the client, the LLM token, and the
/// request params; injecting it lets tests bypass the network.
pub fn request_completion_impl<F, R>(
    &mut self,
    buffer: &Model<Buffer>,
    position: language::Anchor,
    cx: &mut ModelContext<Self>,
    perform_predict_edits: F,
) -> Task<Result<InlineCompletion>>
where
    F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
    R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
{
    // Record the latest edit first so both the history and the excerpt
    // reflect the buffer as it is now.
    let snapshot = self.report_changes_for_buffer(buffer, cx);
    let point = position.to_point(&snapshot);
    let offset = point.to_offset(&snapshot);
    let excerpt_range = excerpt_range_for_position(point, &snapshot);
    let events = self.events.clone();
    let path = snapshot
        .file()
        .map(|f| f.path().clone())
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = self.client.clone();
    let llm_token = self.llm_token.clone();

    cx.spawn(|this, mut cx| async move {
        // Taken before the prompt is built, so `latency()` includes prompt construction.
        let request_sent_at = Instant::now();

        // Events are separated by a blank line. An event whose prompt is
        // empty adds nothing and does not trigger a separator of its own.
        let mut input_events = String::new();
        for event in events {
            if !input_events.is_empty() {
                input_events.push('\n');
                input_events.push('\n');
            }
            input_events.push_str(&event.to_prompt());
        }
        let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);

        log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);

        let body = PredictEditsParams {
            input_events: input_events.clone(),
            input_excerpt: input_excerpt.clone(),
        };

        let response = perform_predict_edits(client, llm_token, body).await?;

        let output_excerpt = response.output_excerpt;
        log::debug!("completion response: {}", output_excerpt);

        let inline_completion = Self::process_completion_response(
            output_excerpt,
            &snapshot,
            excerpt_range,
            path,
            input_events,
            input_excerpt,
            request_sent_at,
            &cx,
        )
        .await?;

        // Remember the completion (newest first) for the rating UI; when the
        // cap is exceeded, evict the oldest and its shown/rated bookkeeping.
        this.update(&mut cx, |this, cx| {
            this.recent_completions
                .push_front(inline_completion.clone());
            if this.recent_completions.len() > 50 {
                let completion = this.recent_completions.pop_back().unwrap();
                this.shown_completions.remove(&completion.id);
                this.rated_completions.remove(&completion.id);
            }
            cx.notify();
        })?;

        Ok(inline_completion)
    })
}
348
/// Generates several example completions in various states to fill the Zeta
/// completion modal.
///
/// Each canned response is passed through `fake_completion`, so it goes
/// through the normal response-processing path against a small in-memory
/// buffer. Once every task has resolved, the entries at indices 2 and 3 of
/// `recent_completions` (newest first) have their edits cleared, so the modal
/// also shows completions that produced no edits.
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
    let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
        And maybe a short line

        Then a few lines

        and then another
        "#};

    let buffer = cx.new_model(|cx| Buffer::local(test_buffer_text, cx));
    let position = buffer.read(cx).anchor_before(Point::new(1, 0));

    let completion_tasks = vec![
        self.fake_completion(
            &buffer,
            position,
            PredictEditsResponse {
                output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                ", ),
            },
            cx,
        ),
        self.fake_completion(
            &buffer,
            position,
            PredictEditsResponse {
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                "#),
            },
            cx,
        ),
        self.fake_completion(
            &buffer,
            position,
            PredictEditsResponse {
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                "#),
            },
            cx,
        ),
        self.fake_completion(
            &buffer,
            position,
            PredictEditsResponse {
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                "#),
            },
            cx,
        ),
        self.fake_completion(
            &buffer,
            position,
            PredictEditsResponse {
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                "#),
            },
            cx,
        ),
        self.fake_completion(
            &buffer,
            position,
            PredictEditsResponse {
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                "#),
            },
            cx,
        ),
        self.fake_completion(
            &buffer,
            position,
            PredictEditsResponse {
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                "#),
            },
            cx,
        ),
    ];

    cx.spawn(|zeta, mut cx| async move {
        // Test-support only: any failure here is a bug in the fixtures.
        for task in completion_tasks {
            task.await.unwrap();
        }

        // Simulate completions that produced no edits.
        zeta.update(&mut cx, |zeta, _cx| {
            zeta.recent_completions.get_mut(2).unwrap().edits = Arc::new([]);
            zeta.recent_completions.get_mut(3).unwrap().edits = Arc::new([]);
        })
        .ok();
    })
}
484
/// Test-support variant of `request_completion` that skips the network and
/// resolves immediately with the provided `response`.
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
    &mut self,
    buffer: &Model<Buffer>,
    position: language::Anchor,
    response: PredictEditsResponse,
    cx: &mut ModelContext<Self>,
) -> Task<Result<InlineCompletion>> {
    self.request_completion_impl(buffer, position, cx, move |_client, _llm_token, _params| {
        std::future::ready(Ok(response))
    })
}
497
498 pub fn request_completion(
499 &mut self,
500 buffer: &Model<Buffer>,
501 position: language::Anchor,
502 cx: &mut ModelContext<Self>,
503 ) -> Task<Result<InlineCompletion>> {
504 self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
505 }
506
507 fn perform_predict_edits(
508 client: Arc<Client>,
509 llm_token: LlmApiToken,
510 body: PredictEditsParams,
511 ) -> impl Future<Output = Result<PredictEditsResponse>> {
512 async move {
513 let http_client = client.http_client();
514 let mut token = llm_token.acquire(&client).await?;
515 let mut did_retry = false;
516
517 loop {
518 let request_builder = http_client::Request::builder();
519 let request = request_builder
520 .method(Method::POST)
521 .uri(
522 http_client
523 .build_zed_llm_url("/predict_edits", &[])?
524 .as_ref(),
525 )
526 .header("Content-Type", "application/json")
527 .header("Authorization", format!("Bearer {}", token))
528 .body(serde_json::to_string(&body)?.into())?;
529
530 let mut response = http_client.send(request).await?;
531
532 if response.status().is_success() {
533 let mut body = String::new();
534 response.body_mut().read_to_string(&mut body).await?;
535 return Ok(serde_json::from_str(&body)?);
536 } else if !did_retry
537 && response
538 .headers()
539 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
540 .is_some()
541 {
542 did_retry = true;
543 token = llm_token.refresh(&client).await?;
544 } else {
545 let mut body = String::new();
546 response.body_mut().read_to_string(&mut body).await?;
547 return Err(anyhow!(
548 "error predicting edits.\nStatus: {:?}\nBody: {}",
549 response.status(),
550 body
551 ));
552 }
553 }
554 }
555 }
556
557 #[allow(clippy::too_many_arguments)]
558 fn process_completion_response(
559 output_excerpt: String,
560 snapshot: &BufferSnapshot,
561 excerpt_range: Range<usize>,
562 path: Arc<Path>,
563 input_events: String,
564 input_excerpt: String,
565 request_sent_at: Instant,
566 cx: &AsyncAppContext,
567 ) -> Task<Result<InlineCompletion>> {
568 let snapshot = snapshot.clone();
569 cx.background_executor().spawn(async move {
570 let content = output_excerpt.replace(CURSOR_MARKER, "");
571
572 let codefence_start = content
573 .find(EDITABLE_REGION_START_MARKER)
574 .context("could not find start marker")?;
575 let content = &content[codefence_start..];
576
577 let newline_ix = content.find('\n').context("could not find newline")?;
578 let content = &content[newline_ix + 1..];
579
580 let codefence_end = content
581 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
582 .context("could not find end marker")?;
583 let new_text = &content[..codefence_end];
584
585 let old_text = snapshot
586 .text_for_range(excerpt_range.clone())
587 .collect::<String>();
588
589 let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
590
591 Ok(InlineCompletion {
592 id: InlineCompletionId::new(),
593 path,
594 excerpt_range,
595 edits: edits.into(),
596 snapshot: snapshot.clone(),
597 input_events: input_events.into(),
598 input_excerpt: input_excerpt.into(),
599 output_excerpt: output_excerpt.into(),
600 request_sent_at,
601 response_received_at: Instant::now(),
602 })
603 })
604 }
605
606 pub fn compute_edits(
607 old_text: String,
608 new_text: &str,
609 offset: usize,
610 snapshot: &BufferSnapshot,
611 ) -> Vec<(Range<Anchor>, String)> {
612 let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
613
614 let mut edits: Vec<(Range<usize>, String)> = Vec::new();
615 let mut old_start = offset;
616 for change in diff.iter_all_changes() {
617 let value = change.value();
618 match change.tag() {
619 similar::ChangeTag::Equal => {
620 old_start += value.len();
621 }
622 similar::ChangeTag::Delete => {
623 let old_end = old_start + value.len();
624 if let Some((last_old_range, _)) = edits.last_mut() {
625 if last_old_range.end == old_start {
626 last_old_range.end = old_end;
627 } else {
628 edits.push((old_start..old_end, String::new()));
629 }
630 } else {
631 edits.push((old_start..old_end, String::new()));
632 }
633 old_start = old_end;
634 }
635 similar::ChangeTag::Insert => {
636 if let Some((last_old_range, last_new_text)) = edits.last_mut() {
637 if last_old_range.end == old_start {
638 last_new_text.push_str(value);
639 } else {
640 edits.push((old_start..old_start, value.into()));
641 }
642 } else {
643 edits.push((old_start..old_start, value.into()));
644 }
645 }
646 }
647 }
648
649 edits
650 .into_iter()
651 .map(|(mut old_range, new_text)| {
652 let prefix_len = common_prefix(
653 snapshot.chars_for_range(old_range.clone()),
654 new_text.chars(),
655 );
656 old_range.start += prefix_len;
657 let suffix_len = common_prefix(
658 snapshot.reversed_chars_for_range(old_range.clone()),
659 new_text[prefix_len..].chars().rev(),
660 );
661 old_range.end = old_range.end.saturating_sub(suffix_len);
662
663 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
664 (
665 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
666 new_text,
667 )
668 })
669 .collect()
670 }
671
672 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
673 self.rated_completions.contains(&completion_id)
674 }
675
676 pub fn was_completion_shown(&self, completion_id: InlineCompletionId) -> bool {
677 self.shown_completions.contains(&completion_id)
678 }
679
680 pub fn completion_shown(&mut self, completion_id: InlineCompletionId) {
681 self.shown_completions.insert(completion_id);
682 }
683
684 pub fn rate_completion(
685 &mut self,
686 completion: &InlineCompletion,
687 rating: InlineCompletionRating,
688 feedback: String,
689 cx: &mut ModelContext<Self>,
690 ) {
691 telemetry::event!(
692 "Inline Completion Rated",
693 rating,
694 input_events = completion.input_events,
695 input_excerpt = completion.input_excerpt,
696 output_excerpt = completion.output_excerpt,
697 feedback
698 );
699 self.client.telemetry().flush_events();
700 cx.notify();
701 }
702
703 pub fn recent_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
704 self.recent_completions.iter()
705 }
706
707 pub fn recent_completions_len(&self) -> usize {
708 self.recent_completions.len()
709 }
710
711 fn report_changes_for_buffer(
712 &mut self,
713 buffer: &Model<Buffer>,
714 cx: &mut ModelContext<Self>,
715 ) -> BufferSnapshot {
716 self.register_buffer(buffer, cx);
717
718 let registered_buffer = self
719 .registered_buffers
720 .get_mut(&buffer.entity_id())
721 .unwrap();
722 let new_snapshot = buffer.read(cx).snapshot();
723
724 if new_snapshot.version != registered_buffer.snapshot.version {
725 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
726 self.push_event(Event::BufferChange {
727 old_snapshot,
728 new_snapshot: new_snapshot.clone(),
729 timestamp: Instant::now(),
730 });
731 }
732
733 new_snapshot
734 }
735}
736
/// Returns the length in UTF-8 bytes of the longest common prefix of two
/// character streams.
///
/// Feed reversed iterators to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
743
744fn prompt_for_excerpt(
745 snapshot: &BufferSnapshot,
746 excerpt_range: &Range<usize>,
747 offset: usize,
748) -> String {
749 let mut prompt_excerpt = String::new();
750 writeln!(
751 prompt_excerpt,
752 "```{}",
753 snapshot
754 .file()
755 .map_or(Cow::Borrowed("untitled"), |file| file
756 .path()
757 .to_string_lossy())
758 )
759 .unwrap();
760
761 if excerpt_range.start == 0 {
762 writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
763 }
764
765 let point_range = excerpt_range.to_point(snapshot);
766 if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
767 let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
768 for chunk in snapshot.text_for_range(extra_context_line_range) {
769 prompt_excerpt.push_str(chunk);
770 }
771 }
772 writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
773 for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
774 prompt_excerpt.push_str(chunk);
775 }
776 prompt_excerpt.push_str(CURSOR_MARKER);
777 for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
778 prompt_excerpt.push_str(chunk);
779 }
780 write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
781
782 if point_range.end.row < snapshot.max_point().row
783 && !snapshot.is_line_blank(point_range.end.row + 1)
784 {
785 let extra_context_line_range = point_range.end
786 ..Point::new(
787 point_range.end.row + 1,
788 snapshot.line_len(point_range.end.row + 1),
789 );
790 for chunk in snapshot.text_for_range(extra_context_line_range) {
791 prompt_excerpt.push_str(chunk);
792 }
793 }
794
795 write!(prompt_excerpt, "\n```").unwrap();
796 prompt_excerpt
797}
798
799fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
800 const CONTEXT_LINES: u32 = 32;
801
802 let mut context_lines_before = CONTEXT_LINES;
803 let mut context_lines_after = CONTEXT_LINES;
804 if point.row < CONTEXT_LINES {
805 context_lines_after += CONTEXT_LINES - point.row;
806 } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
807 context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
808 }
809
810 let excerpt_start_row = point.row.saturating_sub(context_lines_before);
811 let excerpt_start = Point::new(excerpt_start_row, 0);
812 let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
813 let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
814 excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
815}
816
817struct RegisteredBuffer {
818 snapshot: BufferSnapshot,
819 _subscriptions: [gpui::Subscription; 2],
820}
821
822#[derive(Clone)]
823enum Event {
824 BufferChange {
825 old_snapshot: BufferSnapshot,
826 new_snapshot: BufferSnapshot,
827 timestamp: Instant,
828 },
829}
830
831impl Event {
832 fn to_prompt(&self) -> String {
833 match self {
834 Event::BufferChange {
835 old_snapshot,
836 new_snapshot,
837 ..
838 } => {
839 let mut prompt = String::new();
840
841 let old_path = old_snapshot
842 .file()
843 .map(|f| f.path().as_ref())
844 .unwrap_or(Path::new("untitled"));
845 let new_path = new_snapshot
846 .file()
847 .map(|f| f.path().as_ref())
848 .unwrap_or(Path::new("untitled"));
849 if old_path != new_path {
850 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
851 }
852
853 let diff =
854 similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
855 .unified_diff()
856 .to_string();
857 if !diff.is_empty() {
858 write!(
859 prompt,
860 "User edited {:?}:\n```diff\n{}\n```",
861 new_path, diff
862 )
863 .unwrap();
864 }
865
866 prompt
867 }
868 }
869 }
870}
871
872#[derive(Debug, Clone)]
873struct CurrentInlineCompletion {
874 buffer_id: EntityId,
875 completion: InlineCompletion,
876}
877
878impl CurrentInlineCompletion {
879 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
880 if self.buffer_id != old_completion.buffer_id {
881 return true;
882 }
883
884 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
885 return true;
886 };
887 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
888 return false;
889 };
890
891 if old_edits.len() == 1 && new_edits.len() == 1 {
892 let (old_range, old_text) = &old_edits[0];
893 let (new_range, new_text) = &new_edits[0];
894 new_range == old_range && new_text.starts_with(old_text)
895 } else {
896 true
897 }
898 }
899}
900
901struct PendingCompletion {
902 id: usize,
903 _task: Task<Result<()>>,
904}
905
906pub struct ZetaInlineCompletionProvider {
907 zeta: Model<Zeta>,
908 pending_completions: ArrayVec<PendingCompletion, 2>,
909 next_pending_completion_id: usize,
910 current_completion: Option<CurrentInlineCompletion>,
911}
912
913impl ZetaInlineCompletionProvider {
914 pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
915
916 pub fn new(zeta: Model<Zeta>) -> Self {
917 Self {
918 zeta,
919 pending_completions: ArrayVec::new(),
920 next_pending_completion_id: 0,
921 current_completion: None,
922 }
923 }
924}
925
926impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
/// Internal identifier of this provider.
fn name() -> &'static str {
    "zeta"
}

/// Human-readable provider name.
fn display_name() -> &'static str {
    "Zeta"
}

/// Zeta completions are also listed in the completions menu.
fn show_completions_in_menu() -> bool {
    true
}
938
939 fn is_enabled(
940 &self,
941 buffer: &Model<Buffer>,
942 cursor_position: language::Anchor,
943 cx: &AppContext,
944 ) -> bool {
945 let buffer = buffer.read(cx);
946 let file = buffer.file();
947 let language = buffer.language_at(cursor_position);
948 let settings = all_language_settings(file, cx);
949 settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
950 }
951
952 fn refresh(
953 &mut self,
954 buffer: Model<Buffer>,
955 position: language::Anchor,
956 debounce: bool,
957 cx: &mut ModelContext<Self>,
958 ) {
959 let pending_completion_id = self.next_pending_completion_id;
960 self.next_pending_completion_id += 1;
961
962 let task = cx.spawn(|this, mut cx| async move {
963 if debounce {
964 cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
965 }
966
967 let completion_request = this.update(&mut cx, |this, cx| {
968 this.zeta.update(cx, |zeta, cx| {
969 zeta.request_completion(&buffer, position, cx)
970 })
971 });
972
973 let mut completion = None;
974 if let Ok(completion_request) = completion_request {
975 completion = Some(CurrentInlineCompletion {
976 buffer_id: buffer.entity_id(),
977 completion: completion_request.await?,
978 });
979 }
980
981 this.update(&mut cx, |this, cx| {
982 if this.pending_completions[0].id == pending_completion_id {
983 this.pending_completions.remove(0);
984 } else {
985 this.pending_completions.clear();
986 }
987
988 if let Some(new_completion) = completion {
989 if let Some(old_completion) = this.current_completion.as_ref() {
990 let snapshot = buffer.read(cx).snapshot();
991 if new_completion.should_replace_completion(&old_completion, &snapshot) {
992 this.zeta.update(cx, |zeta, _cx| {
993 zeta.completion_shown(new_completion.completion.id)
994 });
995 this.current_completion = Some(new_completion);
996 }
997 } else {
998 this.zeta.update(cx, |zeta, _cx| {
999 zeta.completion_shown(new_completion.completion.id)
1000 });
1001 this.current_completion = Some(new_completion);
1002 }
1003 } else {
1004 this.current_completion = None;
1005 }
1006
1007 cx.notify();
1008 })
1009 });
1010
1011 // We always maintain at most two pending completions. When we already
1012 // have two, we replace the newest one.
1013 if self.pending_completions.len() <= 1 {
1014 self.pending_completions.push(PendingCompletion {
1015 id: pending_completion_id,
1016 _task: task,
1017 });
1018 } else if self.pending_completions.len() == 2 {
1019 self.pending_completions.pop();
1020 self.pending_completions.push(PendingCompletion {
1021 id: pending_completion_id,
1022 _task: task,
1023 });
1024 }
1025 }
1026
1027 fn cycle(
1028 &mut self,
1029 _buffer: Model<Buffer>,
1030 _cursor_position: language::Anchor,
1031 _direction: inline_completion::Direction,
1032 _cx: &mut ModelContext<Self>,
1033 ) {
1034 // Right now we don't support cycling.
1035 }
1036
1037 fn accept(&mut self, _cx: &mut ModelContext<Self>) {
1038 self.pending_completions.clear();
1039 }
1040
1041 fn discard(&mut self, _cx: &mut ModelContext<Self>) {
1042 self.pending_completions.clear();
1043 self.current_completion.take();
1044 }
1045
/// Returns the part of the current completion to display near the cursor.
///
/// The current completion is dropped (and `None` returned) if it belongs to a
/// different buffer, or if it can no longer be interpolated onto the buffer's
/// current snapshot. Otherwise the edit whose start or end row is closest to
/// the cursor row is chosen. It is extended with consecutive neighbouring
/// edits, stopping at the first one that is more than one row away from that
/// closest edit.
fn suggest(
    &mut self,
    buffer: &Model<Buffer>,
    cursor_position: language::Anchor,
    cx: &mut ModelContext<Self>,
) -> Option<inline_completion::InlineCompletion> {
    let CurrentInlineCompletion {
        buffer_id,
        completion,
        ..
    } = self.current_completion.as_mut()?;

    // Invalidate previous completion if it was generated for a different buffer.
    if *buffer_id != buffer.entity_id() {
        self.current_completion.take();
        return None;
    }

    let buffer = buffer.read(cx);
    // Drop the completion if the user's edits since then conflict with it.
    let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
        self.current_completion.take();
        return None;
    };

    // Pick the edit whose start or end row is nearest to the cursor row.
    let cursor_row = cursor_position.to_point(buffer).row;
    let (closest_edit_ix, (closest_edit_range, _)) =
        edits.iter().enumerate().min_by_key(|(_, (range, _))| {
            let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
            let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
            cmp::min(distance_from_start, distance_from_end)
        })?;

    // Extend backwards over earlier edits that end within one row of the
    // closest edit's start.
    // NOTE(review): the row subtractions here and below assume `edits` are in
    // buffer order (as `compute_edits`/`interpolate` appear to produce them);
    // out-of-order edits would underflow.
    let mut edit_start_ix = closest_edit_ix;
    for (range, _) in edits[..edit_start_ix].iter().rev() {
        let distance_from_closest_edit =
            closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
        if distance_from_closest_edit <= 1 {
            edit_start_ix -= 1;
        } else {
            break;
        }
    }

    // Extend forwards over later edits that start within one row of the
    // closest edit's end.
    let mut edit_end_ix = closest_edit_ix + 1;
    for (range, _) in &edits[edit_end_ix..] {
        let distance_from_closest_edit =
            range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
        if distance_from_closest_edit <= 1 {
            edit_end_ix += 1;
        } else {
            break;
        }
    }

    Some(inline_completion::InlineCompletion {
        edits: edits[edit_start_ix..edit_end_ix].to_vec(),
    })
}
1104}
1105
#[cfg(test)]
mod tests {
    use super::*;

    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    // Interpolation: while the user types text that agrees with the prediction, the
    // remaining predicted edits are re-expressed against the current buffer contents;
    // once the buffer diverges from the prediction, interpolation yields `None`.
    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let predicted = to_completion_edits(
            [(2..5, "REM".to_string()), (9..11, "".to_string())],
            &buffer,
            cx,
        );
        let completion = InlineCompletion {
            id: InlineCompletionId::new(),
            path: Path::new("").into(),
            edits: predicted.into(),
            snapshot: buffer.read(cx).snapshot(),
            excerpt_range: 0..0,
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Untouched buffer: the prediction is reported as-is.
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
        );

        // The replaced text was deleted, so the first edit becomes an insertion at 2.
        buffer.update(cx, |buf, cx| buf.edit([(2..5, "")], None, cx));
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
        );

        // Undoing restores the original prediction.
        buffer.update(cx, |buf, cx| buf.undo(cx));
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
        );

        // Typing a prefix of the predicted text leaves only the remainder.
        buffer.update(cx, |buf, cx| buf.edit([(2..5, "R")], None, cx));
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
        );

        buffer.update(cx, |buf, cx| buf.edit([(3..3, "E")], None, cx));
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
        );

        // Completing the first edit by hand drops it from the result.
        buffer.update(cx, |buf, cx| buf.edit([(4..4, "M")], None, cx));
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(9..11, "".to_string())])
        );

        // Removing the typed character brings the pending insertion back.
        buffer.update(cx, |buf, cx| buf.edit([(4..5, "")], None, cx));
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
        );

        // Performing the predicted deletion by hand drops that edit.
        buffer.update(cx, |buf, cx| buf.edit([(8..10, "")], None, cx));
        assert_eq!(
            interpolated_offsets(&completion, &buffer, cx),
            Some(vec![(4..4, "M".to_string())])
        );

        // Deleting text the prediction did not expect invalidates the completion.
        buffer.update(cx, |buf, cx| buf.edit([(4..6, "")], None, cx));
        assert_eq!(interpolated_offsets(&completion, &buffer, cx), None);
    }

    // A prediction that appends a line after the buffer's trailing newline must
    // apply cleanly at the very end of the buffer.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request is answered with the canned prediction above.
        let fake_http = FakeHttpClient::create(move |_| async move {
            let body = serde_json::to_string(&PredictEditsResponse {
                output_excerpt: completion_response.to_string(),
            })
            .unwrap();
            Ok(http_client::Response::builder()
                .status(200)
                .body(body.into())
                .unwrap())
        });

        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), fake_http, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new_model(|cx| Zeta::new(client, cx));
        let buffer = cx.new_model(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buf, _| buf.anchor_before(Point::new(1, 0)));
        let pending =
            zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));

        // The request asks the server for an LLM token first; answer with an empty one.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = pending.await.unwrap();
        buffer.update(cx, |buf, cx| buf.edit(completion.edits.iter().cloned(), None, cx));
        let final_text = buffer.read_with(cx, |buf, _| buf.text());
        assert_eq!(final_text, "lorem\nipsum");
    }

    /// Interpolates `completion` against the current contents of `buffer` and converts
    /// the resulting anchor-based edits into offset-based ones.
    ///
    /// Returns `None` when `InlineCompletion::interpolate` returns `None`.
    fn interpolated_offsets(
        completion: &InlineCompletion,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = completion.interpolate(&snapshot)?;
        Some(from_completion_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`.
    ///
    /// Each range start is anchored with `anchor_after` and each end with `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<Anchor>, String)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, text) in iterator {
            let start = buf.anchor_after(offsets.start);
            let end = buf.anchor_before(offsets.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back into offset ranges in `buffer`'s current text.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<usize>, String)> {
        let buf = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (anchors, text) in editor_edits {
            let start = anchors.start.to_offset(buf);
            let end = anchors.end.to_offset(buf);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }

    // Enable `env_logger` output only when `RUST_LOG` is set (and valid Unicode).
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_err() {
            return;
        }
        env_logger::init();
    }
}