1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_banner;
6mod onboarding_modal;
7mod onboarding_telemetry;
8mod rate_completion_modal;
9
10pub(crate) use completion_diff_element::*;
11use db::kvp::KEY_VALUE_STORE;
12pub use init::*;
13use inline_completion::DataCollectionState;
14pub use license_detection::is_license_eligible_for_data_collection;
15use license_detection::LICENSE_FILES_TO_CHECK;
16pub use onboarding_banner::*;
17pub use rate_completion_modal::*;
18
19use anyhow::{anyhow, Context as _, Result};
20use arrayvec::ArrayVec;
21use client::{Client, UserStore};
22use collections::{HashMap, HashSet, VecDeque};
23use futures::AsyncReadExt;
24use gpui::{
25 actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
26 Subscription, Task, WeakEntity,
27};
28use http_client::{HttpClient, Method};
29use input_excerpt::excerpt_for_cursor_position;
30use language::{
31 text_diff, Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint,
32};
33use language_model::{LlmApiToken, RefreshLlmTokenListener};
34use postage::watch;
35use project::Project;
36use release_channel::AppVersion;
37use settings::WorktreeId;
38use std::str::FromStr;
39use std::{
40 borrow::Cow,
41 cmp,
42 fmt::Write,
43 future::Future,
44 mem,
45 ops::Range,
46 path::Path,
47 rc::Rc,
48 sync::Arc,
49 time::{Duration, Instant},
50};
51use telemetry_events::InlineCompletionRating;
52use thiserror::Error;
53use util::ResultExt;
54use uuid::Uuid;
55use workspace::notifications::{ErrorMessagePrompt, NotificationId};
56use workspace::Workspace;
57use worktree::Worktree;
58use zed_llm_client::{
59 PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
60 MINIMUM_REQUIRED_VERSION_HEADER_NAME,
61};
62
/// Marker stripped from the model output before the edits are parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker that may appear at most once in a model response.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker opening the editable region in a model response; exactly one is expected.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker closing the editable region in a model response; exactly one is expected.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that are at most this far apart are
/// coalesced into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key under which the data collection choice is persisted in the key-value store.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

/// Token budget passed to the excerpt builder for surrounding context.
const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget passed to the excerpt builder for the rewritable region.
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the edit-history section of the prompt.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track. Once the event queue reaches this
/// length, its oldest half is dropped.
const MAX_EVENT_COUNT: usize = 16;
76
// Declares the `ClearHistory` action under `edit_prediction`.
actions!(edit_prediction, [ClearHistory]);
78
/// Identifier of an [`InlineCompletion`], wrapping a [`Uuid`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineCompletionId(Uuid);
81
82impl From<InlineCompletionId> for gpui::ElementId {
83 fn from(value: InlineCompletionId) -> Self {
84 gpui::ElementId::Uuid(value.0)
85 }
86}
87
88impl std::fmt::Display for InlineCompletionId {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 write!(f, "{}", self.0)
91 }
92}
93
/// Wrapper that lets the shared [`Zeta`] entity be stored as a GPUI global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
98
/// A prediction returned for a buffer, together with the inputs that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Identifier built from the response's request id.
    id: InlineCompletionId,
    /// Path of the file the prediction was requested for (`untitled` when the buffer has no file).
    path: Arc<Path>,
    /// Offset range of the editable region that was sent to the model.
    excerpt_range: Range<usize>,
    /// Offset of the cursor at the time of the request.
    cursor_offset: usize,
    /// Predicted edits as anchor ranges paired with replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the `edits` were interpolated against.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
    /// Outline section of the prompt that was sent.
    input_outline: Arc<str>,
    /// Edit-history section of the prompt that was sent.
    input_events: Arc<str>,
    /// Excerpt section of the prompt that was sent.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// When the request was started.
    request_sent_at: Instant,
    /// When the completion was fully constructed from the response.
    response_received_at: Instant,
}
115
116impl InlineCompletion {
117 fn latency(&self) -> Duration {
118 self.response_received_at
119 .duration_since(self.request_sent_at)
120 }
121
122 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
123 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
124 }
125}
126
127fn interpolate(
128 old_snapshot: &BufferSnapshot,
129 new_snapshot: &BufferSnapshot,
130 current_edits: Arc<[(Range<Anchor>, String)]>,
131) -> Option<Vec<(Range<Anchor>, String)>> {
132 let mut edits = Vec::new();
133
134 let mut model_edits = current_edits.into_iter().peekable();
135 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
136 while let Some((model_old_range, _)) = model_edits.peek() {
137 let model_old_range = model_old_range.to_offset(old_snapshot);
138 if model_old_range.end < user_edit.old.start {
139 let (model_old_range, model_new_text) = model_edits.next().unwrap();
140 edits.push((model_old_range.clone(), model_new_text.clone()));
141 } else {
142 break;
143 }
144 }
145
146 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
147 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
148 if user_edit.old == model_old_offset_range {
149 let user_new_text = new_snapshot
150 .text_for_range(user_edit.new.clone())
151 .collect::<String>();
152
153 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
154 if !model_suffix.is_empty() {
155 let anchor = old_snapshot.anchor_after(user_edit.old.end);
156 edits.push((anchor..anchor, model_suffix.to_string()));
157 }
158
159 model_edits.next();
160 continue;
161 }
162 }
163 }
164
165 return None;
166 }
167
168 edits.extend(model_edits.cloned());
169
170 if edits.is_empty() {
171 None
172 } else {
173 Some(edits)
174 }
175}
176
177impl std::fmt::Debug for InlineCompletion {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("InlineCompletion")
180 .field("id", &self.id)
181 .field("path", &self.path)
182 .field("edits", &self.edits)
183 .finish_non_exhaustive()
184 }
185}
186
/// State behind edit predictions: recent edit history, registered buffers,
/// shown and rated completions, and the handles used to reach the prediction service.
pub struct Zeta {
    /// Workspace used to show a notification when an update is required.
    workspace: Option<WeakEntity<Workspace>>,
    /// Client providing the HTTP client, LLM token acquisition and telemetry.
    client: Arc<Client>,
    /// Recent buffer change events, oldest first; used to build the prompt.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions that were shown, most recent first.
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of completions that have been rated.
    rated_completions: HashSet<InlineCompletionId>,
    /// The data collection choice loaded from the key-value store.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token used to authorize prediction requests.
    llm_token: LlmApiToken,
    /// Subscription that refreshes `llm_token` when the refresh listener emits.
    _llm_token_subscription: Subscription,
    /// Whether the terms of service have been accepted.
    tos_accepted: bool,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Subscription that keeps `tos_accepted` in sync with the user store.
    _user_store_subscription: Subscription,
    /// License detection state per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
204
205impl Zeta {
206 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
207 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
208 }
209
210 pub fn register(
211 workspace: Option<WeakEntity<Workspace>>,
212 worktree: Option<Entity<Worktree>>,
213 client: Arc<Client>,
214 user_store: Entity<UserStore>,
215 cx: &mut App,
216 ) -> Entity<Self> {
217 let this = Self::global(cx).unwrap_or_else(|| {
218 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
219 cx.set_global(ZetaGlobal(entity.clone()));
220 entity
221 });
222
223 this.update(cx, move |this, cx| {
224 if let Some(worktree) = worktree {
225 worktree.update(cx, |worktree, cx| {
226 this.license_detection_watchers
227 .entry(worktree.id())
228 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
229 });
230 }
231 });
232
233 this
234 }
235
236 pub fn clear_history(&mut self) {
237 self.events.clear();
238 }
239
240 fn new(
241 workspace: Option<WeakEntity<Workspace>>,
242 client: Arc<Client>,
243 user_store: Entity<UserStore>,
244 cx: &mut Context<Self>,
245 ) -> Self {
246 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
247
248 let data_collection_choice = Self::load_data_collection_choices();
249 let data_collection_choice = cx.new(|_| data_collection_choice);
250
251 Self {
252 workspace,
253 client,
254 events: VecDeque::new(),
255 shown_completions: VecDeque::new(),
256 rated_completions: HashSet::default(),
257 registered_buffers: HashMap::default(),
258 data_collection_choice,
259 llm_token: LlmApiToken::default(),
260 _llm_token_subscription: cx.subscribe(
261 &refresh_llm_token_listener,
262 |this, _listener, _event, cx| {
263 let client = this.client.clone();
264 let llm_token = this.llm_token.clone();
265 cx.spawn(|_this, _cx| async move {
266 llm_token.refresh(&client).await?;
267 anyhow::Ok(())
268 })
269 .detach_and_log_err(cx);
270 },
271 ),
272 tos_accepted: user_store
273 .read(cx)
274 .current_user_has_accepted_terms()
275 .unwrap_or(false),
276 update_required: false,
277 _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
278 match event {
279 client::user::Event::PrivateUserInfoUpdated => {
280 this.tos_accepted = user_store
281 .read(cx)
282 .current_user_has_accepted_terms()
283 .unwrap_or(false);
284 }
285 _ => {}
286 }
287 }),
288 license_detection_watchers: HashMap::default(),
289 }
290 }
291
292 fn push_event(&mut self, event: Event) {
293 if let Some(Event::BufferChange {
294 new_snapshot: last_new_snapshot,
295 timestamp: last_timestamp,
296 ..
297 }) = self.events.back_mut()
298 {
299 // Coalesce edits for the same buffer when they happen one after the other.
300 let Event::BufferChange {
301 old_snapshot,
302 new_snapshot,
303 timestamp,
304 } = &event;
305
306 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
307 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
308 && old_snapshot.version == last_new_snapshot.version
309 {
310 *last_new_snapshot = new_snapshot.clone();
311 *last_timestamp = *timestamp;
312 return;
313 }
314 }
315
316 self.events.push_back(event);
317 if self.events.len() >= MAX_EVENT_COUNT {
318 self.events.drain(..MAX_EVENT_COUNT / 2);
319 }
320 }
321
322 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
323 let buffer_id = buffer.entity_id();
324 let weak_buffer = buffer.downgrade();
325
326 if let std::collections::hash_map::Entry::Vacant(entry) =
327 self.registered_buffers.entry(buffer_id)
328 {
329 let snapshot = buffer.read(cx).snapshot();
330
331 entry.insert(RegisteredBuffer {
332 snapshot,
333 _subscriptions: [
334 cx.subscribe(buffer, move |this, buffer, event, cx| {
335 this.handle_buffer_event(buffer, event, cx);
336 }),
337 cx.observe_release(buffer, move |this, _buffer, _cx| {
338 this.registered_buffers.remove(&weak_buffer.entity_id());
339 }),
340 ],
341 });
342 };
343 }
344
345 fn handle_buffer_event(
346 &mut self,
347 buffer: Entity<Buffer>,
348 event: &language::BufferEvent,
349 cx: &mut Context<Self>,
350 ) {
351 if let language::BufferEvent::Edited = event {
352 self.report_changes_for_buffer(&buffer, cx);
353 }
354 }
355
/// Builds a prediction request for `buffer` at `cursor` and turns the
/// response into an [`InlineCompletion`].
///
/// The actual request is performed by `perform_predict_edits`, which lets
/// callers substitute a canned response. Before the request is built, any
/// pending change to `buffer` is recorded in the event history.
///
/// If the request fails with a [`ZedUpdateRequiredError`], `update_required`
/// is set and, when `workspace` is provided, a notification linking to the
/// releases page is shown. The error is returned to the caller either way.
#[allow(clippy::too_many_arguments)]
fn request_completion_impl<F, R>(
    &mut self,
    workspace: Option<Entity<Workspace>>,
    project: Option<&Entity<Project>>,
    buffer: &Entity<Buffer>,
    cursor: language::Anchor,
    can_collect_data: bool,
    cx: &mut Context<Self>,
    perform_predict_edits: F,
) -> Task<Result<Option<InlineCompletion>>>
where
    F: FnOnce(PerformPredictEditsParams) -> R + 'static,
    R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
{
    // Record pending changes first so the event history includes them.
    let snapshot = self.report_changes_for_buffer(&buffer, cx);
    let diagnostic_groups = snapshot.diagnostic_groups(None);
    let cursor_point = cursor.to_point(&snapshot);
    let cursor_offset = cursor_point.to_offset(&snapshot);
    let events = self.events.clone();
    // Buffers without a backing file are reported as "untitled".
    let path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let zeta = cx.entity();
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);

    let buffer = buffer.clone();

    // Diagnostics are only attached when the project has a local LSP store;
    // groups whose language server is not currently running are dropped.
    let local_lsp_store =
        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
    let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
        Some(
            diagnostic_groups
                .into_iter()
                .filter_map(|(language_server_id, diagnostic_group)| {
                    let language_server =
                        local_lsp_store.running_language_server_for_id(language_server_id)?;

                    Some((
                        language_server.name(),
                        diagnostic_group.resolve::<usize>(&snapshot),
                    ))
                })
                .collect::<Vec<_>>(),
        )
    } else {
        None
    };

    cx.spawn(|_, cx| async move {
        let request_sent_at = Instant::now();

        // Prompt pieces computed on the background executor.
        struct BackgroundValues {
            input_events: String,
            input_excerpt: String,
            speculated_output: String,
            editable_range: Range<usize>,
            input_outline: String,
        }

        let values = cx
            .background_spawn({
                let snapshot = snapshot.clone();
                let path = path.clone();
                async move {
                    let path = path.to_string_lossy();
                    let input_excerpt = excerpt_for_cursor_position(
                        cursor_point,
                        &path,
                        &snapshot,
                        MAX_REWRITE_TOKENS,
                        MAX_CONTEXT_TOKENS,
                    );
                    let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
                    let input_outline = prompt_for_outline(&snapshot);

                    anyhow::Ok(BackgroundValues {
                        input_events,
                        input_excerpt: input_excerpt.prompt,
                        speculated_output: input_excerpt.speculated_output,
                        editable_range: input_excerpt.editable_range.to_offset(&snapshot),
                        input_outline,
                    })
                }
            })
            .await?;

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            values.input_events,
            values.input_excerpt
        );

        let body = PredictEditsBody {
            input_events: values.input_events.clone(),
            input_excerpt: values.input_excerpt.clone(),
            speculated_output: Some(values.speculated_output),
            outline: Some(values.input_outline.clone()),
            can_collect_data,
            // A serialization failure is logged and the diagnostics are omitted
            // from the request altogether.
            diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
                diagnostic_groups
                    .into_iter()
                    .map(|(name, diagnostic_group)| {
                        Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
                    })
                    .collect::<Result<Vec<_>>>()
                    .log_err()
            }),
        };

        let response = perform_predict_edits(PerformPredictEditsParams {
            client,
            llm_token,
            app_version,
            body,
        })
        .await;
        let response = match response {
            Ok(response) => response,
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        zeta.update(cx, |zeta, _cx| {
                            zeta.update_required = true;
                        });

                        if let Some(workspace) = workspace {
                            workspace.update(cx, |workspace, cx| {
                                workspace.show_notification(
                                    NotificationId::unique::<ZedUpdateRequiredError>(),
                                    cx,
                                    |cx| {
                                        cx.new(|_| {
                                            ErrorMessagePrompt::new(err.to_string())
                                                .with_link_button(
                                                    "Update Zed",
                                                    "https://zed.dev/releases",
                                                )
                                        })
                                    },
                                );
                            });
                        }
                    })
                    .ok();
                }

                // The error is propagated whether or not a notification was shown.
                return Err(err);
            }
        };

        log::debug!("completion response: {}", &response.output_excerpt);

        Self::process_completion_response(
            response,
            buffer,
            &snapshot,
            values.editable_range,
            cursor_offset,
            path,
            values.input_outline,
            values.input_events,
            values.input_excerpt,
            request_sent_at,
            &cx,
        )
        .await
    })
}
529
/// Generates several example completions of various states to fill the Zeta completion modal.
///
/// Seven canned responses are run through `fake_completion` against a
/// throwaway local buffer; once all of them have resolved, the edits of the
/// shown completions at indices 2 and 3 are replaced with empty edit lists.
///
/// NOTE(review): nothing in this method visibly pushes into
/// `shown_completions`, so the `unwrap`s at the end assume at least four
/// entries already exist by then. Confirm against callers.
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
    use language::Point;

    let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 And maybe a short line

 Then a few lines

 and then another
 "#};

    let project = None;
    let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
    let position = buffer.read(cx).anchor_before(Point::new(1, 0));

    let completion_tasks = vec![
        // Inserts a line after the first line.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 ", ),
            },
            cx,
        ),
        // Inserts a line after the second line.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        // Response text matches the buffer text.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        // Same response text as the previous one, under a different request id.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
    ];

    cx.spawn(|zeta, mut cx| async move {
        // Every fake completion must succeed; a failure panics.
        for task in completion_tasks {
            task.await.unwrap();
        }

        // Blank out the edits of two of the shown completions.
        zeta.update(&mut cx, |zeta, _cx| {
            zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
            zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
        })
        .ok();
    })
}
682
/// Runs the regular completion pipeline, but answers the request with the
/// supplied `response` instead of contacting the server.
///
/// No workspace is passed along and data collection is disabled.
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
    &mut self,
    project: Option<&Entity<Project>>,
    buffer: &Entity<Buffer>,
    position: language::Anchor,
    response: PredictEditsResponse,
    cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
    let respond_with_canned =
        move |_params: PerformPredictEditsParams| std::future::ready(anyhow::Ok(response));

    self.request_completion_impl(
        None,
        project,
        buffer,
        position,
        false,
        cx,
        respond_with_canned,
    )
}
698
699 pub fn request_completion(
700 &mut self,
701 project: Option<&Entity<Project>>,
702 buffer: &Entity<Buffer>,
703 position: language::Anchor,
704 can_collect_data: bool,
705 cx: &mut Context<Self>,
706 ) -> Task<Result<Option<InlineCompletion>>> {
707 let workspace = self
708 .workspace
709 .as_ref()
710 .and_then(|workspace| workspace.upgrade());
711 self.request_completion_impl(
712 workspace,
713 project,
714 buffer,
715 position,
716 can_collect_data,
717 cx,
718 Self::perform_predict_edits,
719 )
720 }
721
722 fn perform_predict_edits(
723 params: PerformPredictEditsParams,
724 ) -> impl Future<Output = Result<PredictEditsResponse>> {
725 async move {
726 let PerformPredictEditsParams {
727 client,
728 llm_token,
729 app_version,
730 body,
731 ..
732 } = params;
733
734 let http_client = client.http_client();
735 let mut token = llm_token.acquire(&client).await?;
736 let mut did_retry = false;
737
738 loop {
739 let request_builder = http_client::Request::builder().method(Method::POST);
740 let request_builder =
741 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
742 request_builder.uri(predict_edits_url)
743 } else {
744 request_builder.uri(
745 http_client
746 .build_zed_llm_url("/predict_edits/v2", &[])?
747 .as_ref(),
748 )
749 };
750 let request = request_builder
751 .header("Content-Type", "application/json")
752 .header("Authorization", format!("Bearer {}", token))
753 .body(serde_json::to_string(&body)?.into())?;
754
755 let mut response = http_client.send(request).await?;
756
757 if let Some(minimum_required_version) = response
758 .headers()
759 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
760 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
761 {
762 if app_version < minimum_required_version {
763 return Err(anyhow!(ZedUpdateRequiredError {
764 minimum_version: minimum_required_version
765 }));
766 }
767 }
768
769 if response.status().is_success() {
770 let mut body = String::new();
771 response.body_mut().read_to_string(&mut body).await?;
772 return Ok(serde_json::from_str(&body)?);
773 } else if !did_retry
774 && response
775 .headers()
776 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
777 .is_some()
778 {
779 did_retry = true;
780 token = llm_token.refresh(&client).await?;
781 } else {
782 let mut body = String::new();
783 response.body_mut().read_to_string(&mut body).await?;
784 return Err(anyhow!(
785 "error predicting edits.\nStatus: {:?}\nBody: {}",
786 response.status(),
787 body
788 ));
789 }
790 }
791 }
792 }
793
794 #[allow(clippy::too_many_arguments)]
795 fn process_completion_response(
796 prediction_response: PredictEditsResponse,
797 buffer: Entity<Buffer>,
798 snapshot: &BufferSnapshot,
799 editable_range: Range<usize>,
800 cursor_offset: usize,
801 path: Arc<Path>,
802 input_outline: String,
803 input_events: String,
804 input_excerpt: String,
805 request_sent_at: Instant,
806 cx: &AsyncApp,
807 ) -> Task<Result<Option<InlineCompletion>>> {
808 let snapshot = snapshot.clone();
809 let request_id = prediction_response.request_id;
810 let output_excerpt = prediction_response.output_excerpt;
811 cx.spawn(|cx| async move {
812 let output_excerpt: Arc<str> = output_excerpt.into();
813
814 let edits: Arc<[(Range<Anchor>, String)]> = cx
815 .background_spawn({
816 let output_excerpt = output_excerpt.clone();
817 let editable_range = editable_range.clone();
818 let snapshot = snapshot.clone();
819 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
820 })
821 .await?
822 .into();
823
824 let Some((edits, snapshot, edit_preview)) = buffer.read_with(&cx, {
825 let edits = edits.clone();
826 |buffer, cx| {
827 let new_snapshot = buffer.snapshot();
828 let edits: Arc<[(Range<Anchor>, String)]> =
829 interpolate(&snapshot, &new_snapshot, edits)?.into();
830 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
831 }
832 })?
833 else {
834 return anyhow::Ok(None);
835 };
836
837 let edit_preview = edit_preview.await;
838
839 Ok(Some(InlineCompletion {
840 id: InlineCompletionId(request_id),
841 path,
842 excerpt_range: editable_range,
843 cursor_offset,
844 edits,
845 edit_preview,
846 snapshot,
847 input_outline: input_outline.into(),
848 input_events: input_events.into(),
849 input_excerpt: input_excerpt.into(),
850 output_excerpt,
851 request_sent_at,
852 response_received_at: Instant::now(),
853 }))
854 })
855 }
856
857 fn parse_edits(
858 output_excerpt: Arc<str>,
859 editable_range: Range<usize>,
860 snapshot: &BufferSnapshot,
861 ) -> Result<Vec<(Range<Anchor>, String)>> {
862 let content = output_excerpt.replace(CURSOR_MARKER, "");
863
864 let start_markers = content
865 .match_indices(EDITABLE_REGION_START_MARKER)
866 .collect::<Vec<_>>();
867 anyhow::ensure!(
868 start_markers.len() == 1,
869 "expected exactly one start marker, found {}",
870 start_markers.len()
871 );
872
873 let end_markers = content
874 .match_indices(EDITABLE_REGION_END_MARKER)
875 .collect::<Vec<_>>();
876 anyhow::ensure!(
877 end_markers.len() == 1,
878 "expected exactly one end marker, found {}",
879 end_markers.len()
880 );
881
882 let sof_markers = content
883 .match_indices(START_OF_FILE_MARKER)
884 .collect::<Vec<_>>();
885 anyhow::ensure!(
886 sof_markers.len() <= 1,
887 "expected at most one start-of-file marker, found {}",
888 sof_markers.len()
889 );
890
891 let codefence_start = start_markers[0].0;
892 let content = &content[codefence_start..];
893
894 let newline_ix = content.find('\n').context("could not find newline")?;
895 let content = &content[newline_ix + 1..];
896
897 let codefence_end = content
898 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
899 .context("could not find end marker")?;
900 let new_text = &content[..codefence_end];
901
902 let old_text = snapshot
903 .text_for_range(editable_range.clone())
904 .collect::<String>();
905
906 Ok(Self::compute_edits(
907 old_text,
908 new_text,
909 editable_range.start,
910 &snapshot,
911 ))
912 }
913
914 pub fn compute_edits(
915 old_text: String,
916 new_text: &str,
917 offset: usize,
918 snapshot: &BufferSnapshot,
919 ) -> Vec<(Range<Anchor>, String)> {
920 text_diff(&old_text, &new_text)
921 .into_iter()
922 .map(|(mut old_range, new_text)| {
923 old_range.start += offset;
924 old_range.end += offset;
925
926 let prefix_len = common_prefix(
927 snapshot.chars_for_range(old_range.clone()),
928 new_text.chars(),
929 );
930 old_range.start += prefix_len;
931
932 let suffix_len = common_prefix(
933 snapshot.reversed_chars_for_range(old_range.clone()),
934 new_text[prefix_len..].chars().rev(),
935 );
936 old_range.end = old_range.end.saturating_sub(suffix_len);
937
938 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
939 let range = if old_range.is_empty() {
940 let anchor = snapshot.anchor_after(old_range.start);
941 anchor..anchor
942 } else {
943 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
944 };
945 (range, new_text)
946 })
947 .collect()
948 }
949
950 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
951 self.rated_completions.contains(&completion_id)
952 }
953
954 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
955 self.shown_completions.push_front(completion.clone());
956 if self.shown_completions.len() > 50 {
957 let completion = self.shown_completions.pop_back().unwrap();
958 self.rated_completions.remove(&completion.id);
959 }
960 cx.notify();
961 }
962
/// Marks `completion` as rated and reports the rating through telemetry.
///
/// The telemetry event carries the rating, the free-form `feedback`, and the
/// prompt inputs and model output stored on the completion. Telemetry is
/// flushed right away and observers are notified.
pub fn rate_completion(
    &mut self,
    completion: &InlineCompletion,
    rating: InlineCompletionRating,
    feedback: String,
    cx: &mut Context<Self>,
) {
    self.rated_completions.insert(completion.id);
    telemetry::event!(
        "Edit Prediction Rated",
        rating,
        input_events = completion.input_events,
        input_excerpt = completion.input_excerpt,
        input_outline = completion.input_outline,
        output_excerpt = completion.output_excerpt,
        feedback
    );
    // Flush immediately rather than waiting for a later flush.
    self.client.telemetry().flush_events();
    cx.notify();
}
983
984 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
985 self.shown_completions.iter()
986 }
987
988 pub fn shown_completions_len(&self) -> usize {
989 self.shown_completions.len()
990 }
991
992 fn report_changes_for_buffer(
993 &mut self,
994 buffer: &Entity<Buffer>,
995 cx: &mut Context<Self>,
996 ) -> BufferSnapshot {
997 self.register_buffer(buffer, cx);
998
999 let registered_buffer = self
1000 .registered_buffers
1001 .get_mut(&buffer.entity_id())
1002 .unwrap();
1003 let new_snapshot = buffer.read(cx).snapshot();
1004
1005 if new_snapshot.version != registered_buffer.snapshot.version {
1006 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1007 self.push_event(Event::BufferChange {
1008 old_snapshot,
1009 new_snapshot: new_snapshot.clone(),
1010 timestamp: Instant::now(),
1011 });
1012 }
1013
1014 new_snapshot
1015 }
1016
1017 fn load_data_collection_choices() -> DataCollectionChoice {
1018 let choice = KEY_VALUE_STORE
1019 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1020 .log_err()
1021 .flatten();
1022
1023 match choice.as_deref() {
1024 Some("true") => DataCollectionChoice::Enabled,
1025 Some("false") => DataCollectionChoice::Disabled,
1026 Some(_) => {
1027 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1028 DataCollectionChoice::NotAnswered
1029 }
1030 None => DataCollectionChoice::NotAnswered,
1031 }
1032 }
1033}
1034
/// Everything a prediction request needs, bundled so the request function can
/// be swapped out for a canned one.
struct PerformPredictEditsParams {
    /// Client providing the HTTP client and used to acquire or refresh the token.
    pub client: Arc<Client>,
    /// Token used for the `Authorization` header.
    pub llm_token: LlmApiToken,
    /// Running application version, compared against the server's minimum.
    pub app_version: SemanticVersion,
    /// Request payload, serialized as JSON.
    pub body: PredictEditsBody,
}
1041
/// Error returned when the server requires a newer application version than
/// the one currently running.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
1049
/// Tracks whether a worktree was detected as open source based on its license file.
struct LicenseDetectionWatcher {
    /// Receiving side of the detection result; starts out as `false`.
    is_open_source_rx: watch::Receiver<bool>,
    /// Task performing the detection. Presumably held so the task is not
    /// dropped before it finishes; confirm against GPUI's `Task` semantics.
    _is_open_source_task: Task<()>,
}
1054
1055impl LicenseDetectionWatcher {
1056 pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
1057 let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
1058
1059 // Check if worktree is a single file, if so we do not need to check for a LICENSE file
1060 let task = if worktree.abs_path().is_file() {
1061 Task::ready(())
1062 } else {
1063 let loaded_files = LICENSE_FILES_TO_CHECK
1064 .iter()
1065 .map(Path::new)
1066 .map(|file| worktree.load_file(file, cx))
1067 .collect::<ArrayVec<_, { LICENSE_FILES_TO_CHECK.len() }>>();
1068
1069 cx.background_spawn(async move {
1070 for loaded_file in loaded_files.into_iter() {
1071 let Ok(loaded_file) = loaded_file.await else {
1072 continue;
1073 };
1074
1075 let path = &loaded_file.file.path;
1076 if is_license_eligible_for_data_collection(&loaded_file.text) {
1077 log::info!("detected '{path:?}' as open source license");
1078 *is_open_source_tx.borrow_mut() = true;
1079 } else {
1080 log::info!("didn't detect '{path:?}' as open source license");
1081 }
1082
1083 // stop on the first license that successfully read
1084 return;
1085 }
1086
1087 log::debug!("didn't find a license file to check, assuming closed source");
1088 })
1089 };
1090
1091 Self {
1092 is_open_source_rx,
1093 _is_open_source_task: task,
1094 }
1095 }
1096
1097 /// Answers false until we find out it's open source
1098 pub fn is_project_open_source(&self) -> bool {
1099 *self.is_open_source_rx.borrow()
1100 }
1101}
1102
/// Returns the length in UTF-8 bytes of the longest common prefix of the two
/// character sequences.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1109
1110fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
1111 let mut input_outline = String::new();
1112
1113 writeln!(
1114 input_outline,
1115 "```{}",
1116 snapshot
1117 .file()
1118 .map_or(Cow::Borrowed("untitled"), |file| file
1119 .path()
1120 .to_string_lossy())
1121 )
1122 .unwrap();
1123
1124 if let Some(outline) = snapshot.outline(None) {
1125 for item in &outline.items {
1126 let spacing = " ".repeat(item.depth);
1127 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1128 }
1129 }
1130
1131 writeln!(input_outline, "```").unwrap();
1132
1133 input_outline
1134}
1135
1136fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1137 let mut result = String::new();
1138 for event in events.iter().rev() {
1139 let event_string = event.to_prompt();
1140 let event_tokens = tokens_for_bytes(event_string.len());
1141 if event_tokens > remaining_tokens {
1142 break;
1143 }
1144
1145 if !result.is_empty() {
1146 result.insert_str(0, "\n\n");
1147 }
1148 result.insert_str(0, &event_string);
1149 remaining_tokens -= event_tokens;
1150 }
1151 result
1152}
1153
/// State kept for one buffer: a snapshot of it plus the subscriptions held on
/// its behalf.
struct RegisteredBuffer {
    /// A snapshot of the buffer.
    /// NOTE(review): presumably the most recently observed one — confirm where it is updated.
    snapshot: BufferSnapshot,
    /// Two subscriptions owned for as long as this value exists; never read directly.
    _subscriptions: [gpui::Subscription; 2],
}
1158
/// An entry of the history that `prompt_for_events` renders into the prompt.
#[derive(Clone)]
enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// The buffer before the change.
        old_snapshot: BufferSnapshot,
        /// The buffer after the change.
        new_snapshot: BufferSnapshot,
        /// NOTE(review): presumably when the change was recorded — it is not read in this part of the file.
        timestamp: Instant,
    },
}
1167
1168impl Event {
1169 fn to_prompt(&self) -> String {
1170 match self {
1171 Event::BufferChange {
1172 old_snapshot,
1173 new_snapshot,
1174 ..
1175 } => {
1176 let mut prompt = String::new();
1177
1178 let old_path = old_snapshot
1179 .file()
1180 .map(|f| f.path().as_ref())
1181 .unwrap_or(Path::new("untitled"));
1182 let new_path = new_snapshot
1183 .file()
1184 .map(|f| f.path().as_ref())
1185 .unwrap_or(Path::new("untitled"));
1186 if old_path != new_path {
1187 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1188 }
1189
1190 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1191 if !diff.is_empty() {
1192 write!(
1193 prompt,
1194 "User edited {:?}:\n```diff\n{}\n```",
1195 new_path, diff
1196 )
1197 .unwrap();
1198 }
1199
1200 prompt
1201 }
1202 }
1203 }
1204}
1205
/// A completion together with the identity of the buffer it was requested for.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion belongs to.
    buffer_id: EntityId,
    /// The completion itself.
    completion: InlineCompletion,
}
1211
1212impl CurrentInlineCompletion {
1213 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1214 if self.buffer_id != old_completion.buffer_id {
1215 return true;
1216 }
1217
1218 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1219 return true;
1220 };
1221 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1222 return false;
1223 };
1224
1225 if old_edits.len() == 1 && new_edits.len() == 1 {
1226 let (old_range, old_text) = &old_edits[0];
1227 let (new_range, new_text) = &new_edits[0];
1228 new_range == old_range && new_text.starts_with(old_text)
1229 } else {
1230 true
1231 }
1232 }
1233}
1234
/// A completion request that has been started but has not finished yet.
struct PendingCompletion {
    /// Identifier assigned when the request was started; compared against by
    /// the request's task when it finishes.
    id: usize,
    /// The task driving the request. It is owned here and never read directly.
    _task: Task<()>,
}
1239
/// Whether data collection has been enabled, disabled, or not decided yet.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// True only for `Enabled`; an unanswered choice counts as not enabled.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True once the choice is either `Enabled` or `Disabled`.
    pub fn is_answered(self) -> bool {
        matches!(self, Self::Enabled | Self::Disabled)
    }

    /// Returns the opposite choice. An unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1279
/// Data-collection state held by a provider: the data-collection choice entity
/// and the license watcher that apply to the provider's buffer.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License watcher looked up by the worktree id of the buffer's file.
    /// `ProviderDataCollection::new` sets it to `None` exactly when `choice` is `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1285
1286impl ProviderDataCollection {
1287 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1288 let choice_and_watcher = buffer.and_then(|buffer| {
1289 let file = buffer.read(cx).file()?;
1290
1291 if !file.is_local() || file.is_private() {
1292 return None;
1293 }
1294
1295 let zeta = zeta.read(cx);
1296 let choice = zeta.data_collection_choice.clone();
1297
1298 let license_detection_watcher = zeta
1299 .license_detection_watchers
1300 .get(&file.worktree_id(cx))
1301 .cloned()?;
1302
1303 Some((choice, license_detection_watcher))
1304 });
1305
1306 if let Some((choice, watcher)) = choice_and_watcher {
1307 ProviderDataCollection {
1308 choice: Some(choice),
1309 license_detection_watcher: Some(watcher),
1310 }
1311 } else {
1312 ProviderDataCollection {
1313 choice: None,
1314 license_detection_watcher: None,
1315 }
1316 }
1317 }
1318
1319 pub fn can_collect_data(&self, cx: &App) -> bool {
1320 self.is_data_collection_enabled(cx) && self.is_project_open_source()
1321 }
1322
1323 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1324 self.choice
1325 .as_ref()
1326 .is_some_and(|choice| choice.read(cx).is_enabled())
1327 }
1328
1329 fn is_project_open_source(&self) -> bool {
1330 self.license_detection_watcher
1331 .as_ref()
1332 .is_some_and(|watcher| watcher.is_project_open_source())
1333 }
1334
1335 pub fn toggle(&mut self, cx: &mut App) {
1336 if let Some(choice) = self.choice.as_mut() {
1337 let new_choice = choice.update(cx, |choice, _cx| {
1338 let new_choice = choice.toggle();
1339 *choice = new_choice;
1340 new_choice
1341 });
1342
1343 db::write_and_log(cx, move || {
1344 KEY_VALUE_STORE.write_kvp(
1345 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1346 new_choice.is_enabled().to_string(),
1347 )
1348 });
1349 }
1350 }
1351}
1352
/// Edit prediction provider that requests completions from a `Zeta` entity.
pub struct ZetaInlineCompletionProvider {
    /// The entity that completion requests are sent to.
    zeta: Entity<Zeta>,
    /// Requests that are still in flight; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Identifier that the next request will receive.
    next_pending_completion_id: usize,
    /// The completion currently held for suggestion, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection state for this provider. Its own fields are `None` when
    /// data collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the last request was issued; `refresh` uses it to throttle requests.
    last_request_timestamp: Instant,
}
1362
1363impl ZetaInlineCompletionProvider {
1364 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1365
1366 pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1367 Self {
1368 zeta,
1369 pending_completions: ArrayVec::new(),
1370 next_pending_completion_id: 0,
1371 current_completion: None,
1372 provider_data_collection,
1373 last_request_timestamp: Instant::now(),
1374 }
1375 }
1376}
1377
1378impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider {
1379 fn name() -> &'static str {
1380 "zed-predict"
1381 }
1382
1383 fn display_name() -> &'static str {
1384 "Zed's Edit Predictions"
1385 }
1386
1387 fn show_completions_in_menu() -> bool {
1388 true
1389 }
1390
1391 fn show_tab_accept_marker() -> bool {
1392 true
1393 }
1394
1395 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1396 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1397
1398 if self.provider_data_collection.is_data_collection_enabled(cx) {
1399 DataCollectionState::Enabled {
1400 is_project_open_source,
1401 }
1402 } else {
1403 DataCollectionState::Disabled {
1404 is_project_open_source,
1405 }
1406 }
1407 }
1408
1409 fn toggle_data_collection(&mut self, cx: &mut App) {
1410 self.provider_data_collection.toggle(cx);
1411 }
1412
1413 fn is_enabled(
1414 &self,
1415 _buffer: &Entity<Buffer>,
1416 _cursor_position: language::Anchor,
1417 _cx: &App,
1418 ) -> bool {
1419 true
1420 }
1421
1422 fn needs_terms_acceptance(&self, cx: &App) -> bool {
1423 !self.zeta.read(cx).tos_accepted
1424 }
1425
1426 fn is_refreshing(&self) -> bool {
1427 !self.pending_completions.is_empty()
1428 }
1429
1430 fn refresh(
1431 &mut self,
1432 project: Option<Entity<Project>>,
1433 buffer: Entity<Buffer>,
1434 position: language::Anchor,
1435 _debounce: bool,
1436 cx: &mut Context<Self>,
1437 ) {
1438 if !self.zeta.read(cx).tos_accepted {
1439 return;
1440 }
1441
1442 if self.zeta.read(cx).update_required {
1443 return;
1444 }
1445
1446 if let Some(current_completion) = self.current_completion.as_ref() {
1447 let snapshot = buffer.read(cx).snapshot();
1448 if current_completion
1449 .completion
1450 .interpolate(&snapshot)
1451 .is_some()
1452 {
1453 return;
1454 }
1455 }
1456
1457 let pending_completion_id = self.next_pending_completion_id;
1458 self.next_pending_completion_id += 1;
1459 let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1460 let last_request_timestamp = self.last_request_timestamp;
1461
1462 let task = cx.spawn(|this, mut cx| async move {
1463 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1464 .checked_duration_since(Instant::now())
1465 {
1466 cx.background_executor().timer(timeout).await;
1467 }
1468
1469 let completion_request = this.update(&mut cx, |this, cx| {
1470 this.last_request_timestamp = Instant::now();
1471 this.zeta.update(cx, |zeta, cx| {
1472 zeta.request_completion(
1473 project.as_ref(),
1474 &buffer,
1475 position,
1476 can_collect_data,
1477 cx,
1478 )
1479 })
1480 });
1481
1482 let completion = match completion_request {
1483 Ok(completion_request) => {
1484 let completion_request = completion_request.await;
1485 completion_request.map(|c| {
1486 c.map(|completion| CurrentInlineCompletion {
1487 buffer_id: buffer.entity_id(),
1488 completion,
1489 })
1490 })
1491 }
1492 Err(error) => Err(error),
1493 };
1494 let Some(new_completion) = completion
1495 .context("edit prediction failed")
1496 .log_err()
1497 .flatten()
1498 else {
1499 this.update(&mut cx, |this, cx| {
1500 if this.pending_completions[0].id == pending_completion_id {
1501 this.pending_completions.remove(0);
1502 } else {
1503 this.pending_completions.clear();
1504 }
1505
1506 cx.notify();
1507 })
1508 .ok();
1509 return;
1510 };
1511
1512 this.update(&mut cx, |this, cx| {
1513 if this.pending_completions[0].id == pending_completion_id {
1514 this.pending_completions.remove(0);
1515 } else {
1516 this.pending_completions.clear();
1517 }
1518
1519 if let Some(old_completion) = this.current_completion.as_ref() {
1520 let snapshot = buffer.read(cx).snapshot();
1521 if new_completion.should_replace_completion(&old_completion, &snapshot) {
1522 this.zeta.update(cx, |zeta, cx| {
1523 zeta.completion_shown(&new_completion.completion, cx);
1524 });
1525 this.current_completion = Some(new_completion);
1526 }
1527 } else {
1528 this.zeta.update(cx, |zeta, cx| {
1529 zeta.completion_shown(&new_completion.completion, cx);
1530 });
1531 this.current_completion = Some(new_completion);
1532 }
1533
1534 cx.notify();
1535 })
1536 .ok();
1537 });
1538
1539 // We always maintain at most two pending completions. When we already
1540 // have two, we replace the newest one.
1541 if self.pending_completions.len() <= 1 {
1542 self.pending_completions.push(PendingCompletion {
1543 id: pending_completion_id,
1544 _task: task,
1545 });
1546 } else if self.pending_completions.len() == 2 {
1547 self.pending_completions.pop();
1548 self.pending_completions.push(PendingCompletion {
1549 id: pending_completion_id,
1550 _task: task,
1551 });
1552 }
1553 }
1554
1555 fn cycle(
1556 &mut self,
1557 _buffer: Entity<Buffer>,
1558 _cursor_position: language::Anchor,
1559 _direction: inline_completion::Direction,
1560 _cx: &mut Context<Self>,
1561 ) {
1562 // Right now we don't support cycling.
1563 }
1564
1565 fn accept(&mut self, _cx: &mut Context<Self>) {
1566 self.pending_completions.clear();
1567 }
1568
1569 fn discard(&mut self, _cx: &mut Context<Self>) {
1570 self.pending_completions.clear();
1571 self.current_completion.take();
1572 }
1573
1574 fn suggest(
1575 &mut self,
1576 buffer: &Entity<Buffer>,
1577 cursor_position: language::Anchor,
1578 cx: &mut Context<Self>,
1579 ) -> Option<inline_completion::InlineCompletion> {
1580 let CurrentInlineCompletion {
1581 buffer_id,
1582 completion,
1583 ..
1584 } = self.current_completion.as_mut()?;
1585
1586 // Invalidate previous completion if it was generated for a different buffer.
1587 if *buffer_id != buffer.entity_id() {
1588 self.current_completion.take();
1589 return None;
1590 }
1591
1592 let buffer = buffer.read(cx);
1593 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1594 self.current_completion.take();
1595 return None;
1596 };
1597
1598 let cursor_row = cursor_position.to_point(buffer).row;
1599 let (closest_edit_ix, (closest_edit_range, _)) =
1600 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1601 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1602 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1603 cmp::min(distance_from_start, distance_from_end)
1604 })?;
1605
1606 let mut edit_start_ix = closest_edit_ix;
1607 for (range, _) in edits[..edit_start_ix].iter().rev() {
1608 let distance_from_closest_edit =
1609 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1610 if distance_from_closest_edit <= 1 {
1611 edit_start_ix -= 1;
1612 } else {
1613 break;
1614 }
1615 }
1616
1617 let mut edit_end_ix = closest_edit_ix + 1;
1618 for (range, _) in &edits[edit_end_ix..] {
1619 let distance_from_closest_edit =
1620 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1621 if distance_from_closest_edit <= 1 {
1622 edit_end_ix += 1;
1623 } else {
1624 break;
1625 }
1626 }
1627
1628 Some(inline_completion::InlineCompletion {
1629 id: Some(completion.id.to_string().into()),
1630 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1631 edit_preview: Some(completion.edit_preview.clone()),
1632 })
1633 }
1634}
1635
/// Estimates how many tokens `bytes` bytes of text amount to, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

    let estimated_tokens = bytes / BYTES_PER_TOKEN_ESTIMATE;
    estimated_tokens
}
1642
// Tests for completion interpolation and for turning model responses into
// buffer edits, plus the helpers they share.
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    // Builds a completion with two edits over "Lorem ipsum dolor", then edits
    // the buffer step by step and checks what `interpolate` returns each time.
    #[gpui::test]
    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = InlineCompletion {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: InlineCompletionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the edits come back unchanged.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing the predicted text one character at a time shrinks the
            // remaining first edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // After this edit the completion no longer interpolates.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    // Checks the edits produced for two buffer/response pairs via
    // `edits_for_prediction`.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    // Requests a completion with the cursor at the start of row 1 of "lorem\n"
    // and checks that applying its edits yields "lorem\nipsum".
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake HTTP client that answers every request with the canned response.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                            .unwrap(),
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // Answer the LLM token request that the fake server receives.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    // Requests a completion for `buffer_content` from a fake HTTP client that
    // answers with `completion_response`, with the cursor at the start of row 1.
    // Returns the completion's edits as point ranges resolved against the
    // snapshot taken before the request.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |_| {
            let completion = completion_response.clone();
            async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::new_v4(),
                            output_excerpt: completion,
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap())
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // Answer the LLM token request that the fake server receives.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .into_iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    // Converts offset-based edits into anchor-based edits for `buffer`, using
    // `anchor_after` for each range start and `anchor_before` for each end.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    // Resolves anchor-based edits to offset ranges in the current state of `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    // Initializes `env_logger`, but only when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}