1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::KEY_VALUE_STORE;
11pub use init::*;
12use inline_completion::DataCollectionState;
13use license_detection::LICENSE_FILES_TO_CHECK;
14pub use license_detection::is_license_eligible_for_data_collection;
15pub use rate_completion_modal::*;
16
17use anyhow::{Context as _, Result, anyhow};
18use arrayvec::ArrayVec;
19use client::{Client, UserStore};
20use collections::{HashMap, HashSet, VecDeque};
21use futures::AsyncReadExt;
22use gpui::{
23 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
24 Subscription, Task, WeakEntity, actions,
25};
26use http_client::{HttpClient, Method};
27use input_excerpt::excerpt_for_cursor_position;
28use language::{
29 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
30};
31use language_model::{LlmApiToken, RefreshLlmTokenListener};
32use postage::watch;
33use project::Project;
34use release_channel::AppVersion;
35use settings::WorktreeId;
36use std::str::FromStr;
37use std::{
38 borrow::Cow,
39 cmp,
40 fmt::Write,
41 future::Future,
42 mem,
43 ops::Range,
44 path::Path,
45 rc::Rc,
46 sync::Arc,
47 time::{Duration, Instant},
48};
49use telemetry_events::InlineCompletionRating;
50use thiserror::Error;
51use util::ResultExt;
52use uuid::Uuid;
53use workspace::Workspace;
54use workspace::notifications::{ErrorMessagePrompt, NotificationId};
55use worktree::Worktree;
56use zed_llm_client::{
57 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
58 PredictEditsResponse,
59};
60
/// Marker that is stripped from the model output before edits are parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker that may appear at most once in the model output (checked in `parse_edits`).
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker that opens the rewritten region in the model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker that closes the rewritten region in the model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that happen within this interval are coalesced
/// into a single event by `push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key under which the data collection choice is persisted in the key-value store.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

/// Token budget passed to `excerpt_for_cursor_position` for surrounding context.
const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget passed to `excerpt_for_cursor_position` for the rewritable region.
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the recent-events section of the prompt (see `prompt_for_events`).
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track. Once the queue reaches this length, `push_event`
/// drops the oldest half.
const MAX_EVENT_COUNT: usize = 16;
74
// Declares the `ClearHistory` action in the `edit_prediction` namespace.
actions!(edit_prediction, [ClearHistory]);
76
/// Identifier of an [`InlineCompletion`]; wraps the `request_id` returned with the prediction
/// response.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineCompletionId(Uuid);
79
80impl From<InlineCompletionId> for gpui::ElementId {
81 fn from(value: InlineCompletionId) -> Self {
82 gpui::ElementId::Uuid(value.0)
83 }
84}
85
86impl std::fmt::Display for InlineCompletionId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Wrapper used to store the shared [`Zeta`] entity as a gpui global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

// Marker impl so `ZetaGlobal` can be stored with `set_global` and read with `try_global`.
impl Global for ZetaGlobal {}
96
/// A predicted edit for a buffer, together with the inputs that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Identifier taken from the prediction response's `request_id`.
    id: InlineCompletionId,
    /// Full path of the buffer's file, or `"untitled"` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region that the model was asked to rewrite.
    excerpt_range: Range<usize>,
    /// Offset of the cursor in the snapshot the request was built from.
    cursor_offset: usize,
    /// Edits proposed by the model, as (range to replace, replacement text) pairs.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` were last interpolated against.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
    /// Outline section of the prompt that was sent.
    input_outline: Arc<str>,
    /// Recent-events section of the prompt that was sent.
    input_events: Arc<str>,
    /// Excerpt section of the prompt that was sent.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// When the request was started.
    request_sent_at: Instant,
    /// When the response had been fully processed.
    response_received_at: Instant,
}
113
114impl InlineCompletion {
115 fn latency(&self) -> Duration {
116 self.response_received_at
117 .duration_since(self.request_sent_at)
118 }
119
120 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
121 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
122 }
123}
124
/// Adapts model edits computed against `old_snapshot` to the edits the user has made since,
/// as observed in `new_snapshot`.
///
/// Returns `None` when a user edit conflicts with the prediction (it neither lies strictly
/// after an untouched model edit nor types a prefix of the model's replacement over the exact
/// same range), or when no edits remain. Otherwise returns the edits that are still pending.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.into_iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that ends strictly before this user edit starts.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            // The user edit is only compatible if it replaced exactly the range the model
            // wanted to replace...
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // ...and what the user typed is a prefix of the model's replacement. The
                // untyped remainder (if any) is kept as an insertion after the edited range.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit invalidates the prediction.
        return None;
    }

    // Model edits located after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
170
171impl std::fmt::Debug for InlineCompletion {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("InlineCompletion")
174 .field("id", &self.id)
175 .field("path", &self.path)
176 .field("edits", &self.edits)
177 .finish_non_exhaustive()
178 }
179}
180
/// Shared edit-prediction state: recent buffer-change events, observed buffers, shown and rated
/// completions, and the client and token used for prediction requests.
pub struct Zeta {
    /// Workspace in which notifications (such as the "update required" prompt) are shown.
    workspace: Option<WeakEntity<Workspace>>,
    client: Arc<Client>,
    /// Recent buffer-change events, oldest first; rendered into the prompt.
    events: VecDeque<Event>,
    /// Buffers that are being observed for edits, keyed by entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions passed to `completion_shown`, most recent first.
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of the completions that were rated via `rate_completion`.
    rated_completions: HashSet<InlineCompletionId>,
    /// Data collection choice, loaded from the key-value store on construction.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token used to authorize prediction requests.
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Whether the terms of service have been accepted.
    tos_accepted: bool,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Keeps `tos_accepted` in sync with the user store.
    _user_store_subscription: Subscription,
    /// License watchers for the worktrees passed to `register`, keyed by worktree id.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
198
199impl Zeta {
200 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
201 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
202 }
203
204 pub fn register(
205 workspace: Option<WeakEntity<Workspace>>,
206 worktree: Option<Entity<Worktree>>,
207 client: Arc<Client>,
208 user_store: Entity<UserStore>,
209 cx: &mut App,
210 ) -> Entity<Self> {
211 let this = Self::global(cx).unwrap_or_else(|| {
212 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
213 cx.set_global(ZetaGlobal(entity.clone()));
214 entity
215 });
216
217 this.update(cx, move |this, cx| {
218 if let Some(worktree) = worktree {
219 worktree.update(cx, |worktree, cx| {
220 this.license_detection_watchers
221 .entry(worktree.id())
222 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
223 });
224 }
225 });
226
227 this
228 }
229
 /// Discards all recorded buffer-change events.
 pub fn clear_history(&mut self) {
     self.events.clear();
 }
233
234 fn new(
235 workspace: Option<WeakEntity<Workspace>>,
236 client: Arc<Client>,
237 user_store: Entity<UserStore>,
238 cx: &mut Context<Self>,
239 ) -> Self {
240 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
241
242 let data_collection_choice = Self::load_data_collection_choices();
243 let data_collection_choice = cx.new(|_| data_collection_choice);
244
245 Self {
246 workspace,
247 client,
248 events: VecDeque::new(),
249 shown_completions: VecDeque::new(),
250 rated_completions: HashSet::default(),
251 registered_buffers: HashMap::default(),
252 data_collection_choice,
253 llm_token: LlmApiToken::default(),
254 _llm_token_subscription: cx.subscribe(
255 &refresh_llm_token_listener,
256 |this, _listener, _event, cx| {
257 let client = this.client.clone();
258 let llm_token = this.llm_token.clone();
259 cx.spawn(async move |_this, _cx| {
260 llm_token.refresh(&client).await?;
261 anyhow::Ok(())
262 })
263 .detach_and_log_err(cx);
264 },
265 ),
266 tos_accepted: user_store
267 .read(cx)
268 .current_user_has_accepted_terms()
269 .unwrap_or(false),
270 update_required: false,
271 _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
272 match event {
273 client::user::Event::PrivateUserInfoUpdated => {
274 this.tos_accepted = user_store
275 .read(cx)
276 .current_user_has_accepted_terms()
277 .unwrap_or(false);
278 }
279 _ => {}
280 }
281 }),
282 license_detection_watchers: HashMap::default(),
283 }
284 }
285
286 fn push_event(&mut self, event: Event) {
287 if let Some(Event::BufferChange {
288 new_snapshot: last_new_snapshot,
289 timestamp: last_timestamp,
290 ..
291 }) = self.events.back_mut()
292 {
293 // Coalesce edits for the same buffer when they happen one after the other.
294 let Event::BufferChange {
295 old_snapshot,
296 new_snapshot,
297 timestamp,
298 } = &event;
299
300 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
301 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
302 && old_snapshot.version == last_new_snapshot.version
303 {
304 *last_new_snapshot = new_snapshot.clone();
305 *last_timestamp = *timestamp;
306 return;
307 }
308 }
309
310 self.events.push_back(event);
311 if self.events.len() >= MAX_EVENT_COUNT {
312 self.events.drain(..MAX_EVENT_COUNT / 2);
313 }
314 }
315
316 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
317 let buffer_id = buffer.entity_id();
318 let weak_buffer = buffer.downgrade();
319
320 if let std::collections::hash_map::Entry::Vacant(entry) =
321 self.registered_buffers.entry(buffer_id)
322 {
323 let snapshot = buffer.read(cx).snapshot();
324
325 entry.insert(RegisteredBuffer {
326 snapshot,
327 _subscriptions: [
328 cx.subscribe(buffer, move |this, buffer, event, cx| {
329 this.handle_buffer_event(buffer, event, cx);
330 }),
331 cx.observe_release(buffer, move |this, _buffer, _cx| {
332 this.registered_buffers.remove(&weak_buffer.entity_id());
333 }),
334 ],
335 });
336 };
337 }
338
339 fn handle_buffer_event(
340 &mut self,
341 buffer: Entity<Buffer>,
342 event: &language::BufferEvent,
343 cx: &mut Context<Self>,
344 ) {
345 if let language::BufferEvent::Edited = event {
346 self.report_changes_for_buffer(&buffer, cx);
347 }
348 }
349
 /// Shared implementation behind `request_completion` and the test-only `fake_completion`.
 ///
 /// Reports pending buffer changes, builds the prompt inputs (recent events, an excerpt around
 /// `cursor`, the buffer outline and, when the project has a local LSP store, diagnostic
 /// groups), passes them to `perform_predict_edits` and converts the response into an
 /// [`InlineCompletion`].
 ///
 /// `perform_predict_edits` is injected so that callers decide how the request is performed.
 /// If it fails with a [`ZedUpdateRequiredError`], `update_required` is set and, when a
 /// `workspace` was given, a notification linking to the releases page is shown. The error is
 /// returned in either case.
 fn request_completion_impl<F, R>(
     &mut self,
     workspace: Option<Entity<Workspace>>,
     project: Option<&Entity<Project>>,
     buffer: &Entity<Buffer>,
     cursor: language::Anchor,
     can_collect_data: bool,
     cx: &mut Context<Self>,
     perform_predict_edits: F,
 ) -> Task<Result<Option<InlineCompletion>>>
 where
     F: FnOnce(PerformPredictEditsParams) -> R + 'static,
     R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 {
     // Record edits made since the last report so that they are part of `self.events`.
     let snapshot = self.report_changes_for_buffer(&buffer, cx);
     let diagnostic_groups = snapshot.diagnostic_groups(None);
     let cursor_point = cursor.to_point(&snapshot);
     let cursor_offset = cursor_point.to_offset(&snapshot);
     let events = self.events.clone();
     let path: Arc<Path> = snapshot
         .file()
         .map(|f| Arc::from(f.full_path(cx).as_path()))
         .unwrap_or_else(|| Arc::from(Path::new("untitled")));

     let zeta = cx.entity();
     let client = self.client.clone();
     let llm_token = self.llm_token.clone();
     let app_version = AppVersion::global(cx);

     let buffer = buffer.clone();

     // Diagnostics are only attached when the project has a local LSP store; groups whose
     // language server is not running are dropped.
     let local_lsp_store =
         project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
     let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
         Some(
             diagnostic_groups
                 .into_iter()
                 .filter_map(|(language_server_id, diagnostic_group)| {
                     let language_server =
                         local_lsp_store.running_language_server_for_id(language_server_id)?;

                     Some((
                         language_server.name(),
                         diagnostic_group.resolve::<usize>(&snapshot),
                     ))
                 })
                 .collect::<Vec<_>>(),
         )
     } else {
         None
     };

     cx.spawn(async move |_, cx| {
         let request_sent_at = Instant::now();

         // Prompt pieces that are computed on the background executor.
         struct BackgroundValues {
             input_events: String,
             input_excerpt: String,
             speculated_output: String,
             editable_range: Range<usize>,
             input_outline: String,
         }

         let values = cx
             .background_spawn({
                 let snapshot = snapshot.clone();
                 let path = path.clone();
                 async move {
                     let path = path.to_string_lossy();
                     let input_excerpt = excerpt_for_cursor_position(
                         cursor_point,
                         &path,
                         &snapshot,
                         MAX_REWRITE_TOKENS,
                         MAX_CONTEXT_TOKENS,
                     );
                     let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
                     let input_outline = prompt_for_outline(&snapshot);

                     anyhow::Ok(BackgroundValues {
                         input_events,
                         input_excerpt: input_excerpt.prompt,
                         speculated_output: input_excerpt.speculated_output,
                         editable_range: input_excerpt.editable_range.to_offset(&snapshot),
                         input_outline,
                     })
                 }
             })
             .await?;

         log::debug!(
             "Events:\n{}\nExcerpt:\n{:?}",
             values.input_events,
             values.input_excerpt
         );

         let body = PredictEditsBody {
             input_events: values.input_events.clone(),
             input_excerpt: values.input_excerpt.clone(),
             speculated_output: Some(values.speculated_output),
             outline: Some(values.input_outline.clone()),
             can_collect_data,
             // If any diagnostic group fails to serialize, the error is logged and the
             // request is sent without diagnostics.
             diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
                 diagnostic_groups
                     .into_iter()
                     .map(|(name, diagnostic_group)| {
                         Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
                     })
                     .collect::<Result<Vec<_>>>()
                     .log_err()
             }),
         };

         let response = perform_predict_edits(PerformPredictEditsParams {
             client,
             llm_token,
             app_version,
             body,
         })
         .await;
         let response = match response {
             Ok(response) => response,
             Err(err) => {
                 // An outdated client is remembered on `Zeta` and surfaced to the user; the
                 // error itself is still propagated below.
                 if err.is::<ZedUpdateRequiredError>() {
                     cx.update(|cx| {
                         zeta.update(cx, |zeta, _cx| {
                             zeta.update_required = true;
                         });

                         if let Some(workspace) = workspace {
                             workspace.update(cx, |workspace, cx| {
                                 workspace.show_notification(
                                     NotificationId::unique::<ZedUpdateRequiredError>(),
                                     cx,
                                     |cx| {
                                         cx.new(|cx| {
                                             ErrorMessagePrompt::new(err.to_string(), cx)
                                                 .with_link_button(
                                                     "Update Zed",
                                                     "https://zed.dev/releases",
                                                 )
                                         })
                                     },
                                 );
                             });
                         }
                     })
                     .ok();
                 }

                 return Err(err);
             }
         };

         log::debug!("completion response: {}", &response.output_excerpt);

         Self::process_completion_response(
             response,
             buffer,
             &snapshot,
             values.editable_range,
             cursor_offset,
             path,
             values.input_outline,
             values.input_events,
             values.input_excerpt,
             request_sent_at,
             &cx,
         )
         .await
     })
 }
522
523 // Generates several example completions of various states to fill the Zeta completion modal
524 #[cfg(any(test, feature = "test-support"))]
525 pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
526 use language::Point;
527
528 let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
529 And maybe a short line
530
531 Then a few lines
532
533 and then another
534 "#};
535
536 let project = None;
537 let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
538 let position = buffer.read(cx).anchor_before(Point::new(1, 0));
539
540 let completion_tasks = vec![
541 self.fake_completion(
542 project,
543 &buffer,
544 position,
545 PredictEditsResponse {
546 request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
547 output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
548a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
549[here's an edit]
550And maybe a short line
551Then a few lines
552and then another
553{EDITABLE_REGION_END_MARKER}
554 ", ),
555 },
556 cx,
557 ),
558 self.fake_completion(
559 project,
560 &buffer,
561 position,
562 PredictEditsResponse {
563 request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
564 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
565a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
566And maybe a short line
567[and another edit]
568Then a few lines
569and then another
570{EDITABLE_REGION_END_MARKER}
571 "#),
572 },
573 cx,
574 ),
575 self.fake_completion(
576 project,
577 &buffer,
578 position,
579 PredictEditsResponse {
580 request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
581 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
582a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
583And maybe a short line
584
585Then a few lines
586
587and then another
588{EDITABLE_REGION_END_MARKER}
589 "#),
590 },
591 cx,
592 ),
593 self.fake_completion(
594 project,
595 &buffer,
596 position,
597 PredictEditsResponse {
598 request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
599 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
600a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
601And maybe a short line
602
603Then a few lines
604
605and then another
606{EDITABLE_REGION_END_MARKER}
607 "#),
608 },
609 cx,
610 ),
611 self.fake_completion(
612 project,
613 &buffer,
614 position,
615 PredictEditsResponse {
616 request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
617 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
618a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
619And maybe a short line
620Then a few lines
621[a third completion]
622and then another
623{EDITABLE_REGION_END_MARKER}
624 "#),
625 },
626 cx,
627 ),
628 self.fake_completion(
629 project,
630 &buffer,
631 position,
632 PredictEditsResponse {
633 request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
634 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
635a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
636And maybe a short line
637and then another
638[fourth completion example]
639{EDITABLE_REGION_END_MARKER}
640 "#),
641 },
642 cx,
643 ),
644 self.fake_completion(
645 project,
646 &buffer,
647 position,
648 PredictEditsResponse {
649 request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
650 output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
651a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
652And maybe a short line
653Then a few lines
654and then another
655[fifth and final completion]
656{EDITABLE_REGION_END_MARKER}
657 "#),
658 },
659 cx,
660 ),
661 ];
662
663 cx.spawn(async move |zeta, cx| {
664 for task in completion_tasks {
665 task.await.unwrap();
666 }
667
668 zeta.update(cx, |zeta, _cx| {
669 zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
670 zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
671 })
672 .ok();
673 })
674 }
675
676 #[cfg(any(test, feature = "test-support"))]
677 pub fn fake_completion(
678 &mut self,
679 project: Option<&Entity<Project>>,
680 buffer: &Entity<Buffer>,
681 position: language::Anchor,
682 response: PredictEditsResponse,
683 cx: &mut Context<Self>,
684 ) -> Task<Result<Option<InlineCompletion>>> {
685 use std::future::ready;
686
687 self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
688 ready(Ok(response))
689 })
690 }
691
692 pub fn request_completion(
693 &mut self,
694 project: Option<&Entity<Project>>,
695 buffer: &Entity<Buffer>,
696 position: language::Anchor,
697 can_collect_data: bool,
698 cx: &mut Context<Self>,
699 ) -> Task<Result<Option<InlineCompletion>>> {
700 let workspace = self
701 .workspace
702 .as_ref()
703 .and_then(|workspace| workspace.upgrade());
704 self.request_completion_impl(
705 workspace,
706 project,
707 buffer,
708 position,
709 can_collect_data,
710 cx,
711 Self::perform_predict_edits,
712 )
713 }
714
 /// Sends the prediction request and parses the response.
 ///
 /// The request is POSTed to the URL in the `ZED_PREDICT_EDITS_URL` environment variable when
 /// it is set, and to the `/predict_edits/v2` LLM endpoint otherwise.
 ///
 /// # Errors
 ///
 /// - [`ZedUpdateRequiredError`] when the response announces a minimum required version that
 ///   is newer than `app_version`.
 /// - A generic error containing status and body for any other unsuccessful response. A
 ///   response flagged with the expired-token header is retried once with a refreshed token
 ///   first.
 fn perform_predict_edits(
     params: PerformPredictEditsParams,
 ) -> impl Future<Output = Result<PredictEditsResponse>> {
     async move {
         let PerformPredictEditsParams {
             client,
             llm_token,
             app_version,
             body,
             ..
         } = params;

         let http_client = client.http_client();
         let mut token = llm_token.acquire(&client).await?;
         // Guards against refreshing the token more than once.
         let mut did_retry = false;

         loop {
             let request_builder = http_client::Request::builder().method(Method::POST);
             let request_builder =
                 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                     request_builder.uri(predict_edits_url)
                 } else {
                     request_builder.uri(
                         http_client
                             .build_zed_llm_url("/predict_edits/v2", &[])?
                             .as_ref(),
                     )
                 };
             let request = request_builder
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {}", token))
                 .body(serde_json::to_string(&body)?.into())?;

             let mut response = http_client.send(request).await?;

             // The version check happens before the status check, so it applies to
             // successful and failed responses alike. A header that is missing or cannot be
             // parsed is ignored.
             if let Some(minimum_required_version) = response
                 .headers()
                 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
             {
                 if app_version < minimum_required_version {
                     return Err(anyhow!(ZedUpdateRequiredError {
                         minimum_version: minimum_required_version
                     }));
                 }
             }

             if response.status().is_success() {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
                 return Ok(serde_json::from_str(&body)?);
             } else if !did_retry
                 && response
                     .headers()
                     .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                     .is_some()
             {
                 // Expired token: refresh it and go around the loop once more.
                 did_retry = true;
                 token = llm_token.refresh(&client).await?;
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
                 return Err(anyhow!(
                     "error predicting edits.\nStatus: {:?}\nBody: {}",
                     response.status(),
                     body
                 ));
             }
         }
     }
 }
786
 /// Converts a prediction response into an [`InlineCompletion`].
 ///
 /// The edits are parsed from the response on the background executor (relative to
 /// `snapshot`, the snapshot the request was built from) and then interpolated against the
 /// buffer's current contents. Resolves to `Ok(None)` when interpolation yields nothing, i.e.
 /// when the buffer changed incompatibly in the meantime or no edits remain.
 fn process_completion_response(
     prediction_response: PredictEditsResponse,
     buffer: Entity<Buffer>,
     snapshot: &BufferSnapshot,
     editable_range: Range<usize>,
     cursor_offset: usize,
     path: Arc<Path>,
     input_outline: String,
     input_events: String,
     input_excerpt: String,
     request_sent_at: Instant,
     cx: &AsyncApp,
 ) -> Task<Result<Option<InlineCompletion>>> {
     let snapshot = snapshot.clone();
     let request_id = prediction_response.request_id;
     let output_excerpt = prediction_response.output_excerpt;
     cx.spawn(async move |cx| {
         let output_excerpt: Arc<str> = output_excerpt.into();

         let edits: Arc<[(Range<Anchor>, String)]> = cx
             .background_spawn({
                 let output_excerpt = output_excerpt.clone();
                 let editable_range = editable_range.clone();
                 let snapshot = snapshot.clone();
                 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
             })
             .await?
             .into();

         // Re-base the edits onto the buffer's current snapshot and start computing the
         // preview. The `?` inside the closure turns a failed interpolation into `None`.
         let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
             let edits = edits.clone();
             |buffer, cx| {
                 let new_snapshot = buffer.snapshot();
                 let edits: Arc<[(Range<Anchor>, String)]> =
                     interpolate(&snapshot, &new_snapshot, edits)?.into();
                 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
             }
         })?
         else {
             return anyhow::Ok(None);
         };

         let edit_preview = edit_preview.await;

         Ok(Some(InlineCompletion {
             id: InlineCompletionId(request_id),
             path,
             excerpt_range: editable_range,
             cursor_offset,
             edits,
             edit_preview,
             snapshot,
             input_outline: input_outline.into(),
             input_events: input_events.into(),
             input_excerpt: input_excerpt.into(),
             output_excerpt,
             request_sent_at,
             response_received_at: Instant::now(),
         }))
     })
 }
848
849 fn parse_edits(
850 output_excerpt: Arc<str>,
851 editable_range: Range<usize>,
852 snapshot: &BufferSnapshot,
853 ) -> Result<Vec<(Range<Anchor>, String)>> {
854 let content = output_excerpt.replace(CURSOR_MARKER, "");
855
856 let start_markers = content
857 .match_indices(EDITABLE_REGION_START_MARKER)
858 .collect::<Vec<_>>();
859 anyhow::ensure!(
860 start_markers.len() == 1,
861 "expected exactly one start marker, found {}",
862 start_markers.len()
863 );
864
865 let end_markers = content
866 .match_indices(EDITABLE_REGION_END_MARKER)
867 .collect::<Vec<_>>();
868 anyhow::ensure!(
869 end_markers.len() == 1,
870 "expected exactly one end marker, found {}",
871 end_markers.len()
872 );
873
874 let sof_markers = content
875 .match_indices(START_OF_FILE_MARKER)
876 .collect::<Vec<_>>();
877 anyhow::ensure!(
878 sof_markers.len() <= 1,
879 "expected at most one start-of-file marker, found {}",
880 sof_markers.len()
881 );
882
883 let codefence_start = start_markers[0].0;
884 let content = &content[codefence_start..];
885
886 let newline_ix = content.find('\n').context("could not find newline")?;
887 let content = &content[newline_ix + 1..];
888
889 let codefence_end = content
890 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
891 .context("could not find end marker")?;
892 let new_text = &content[..codefence_end];
893
894 let old_text = snapshot
895 .text_for_range(editable_range.clone())
896 .collect::<String>();
897
898 Ok(Self::compute_edits(
899 old_text,
900 new_text,
901 editable_range.start,
902 &snapshot,
903 ))
904 }
905
906 pub fn compute_edits(
907 old_text: String,
908 new_text: &str,
909 offset: usize,
910 snapshot: &BufferSnapshot,
911 ) -> Vec<(Range<Anchor>, String)> {
912 text_diff(&old_text, &new_text)
913 .into_iter()
914 .map(|(mut old_range, new_text)| {
915 old_range.start += offset;
916 old_range.end += offset;
917
918 let prefix_len = common_prefix(
919 snapshot.chars_for_range(old_range.clone()),
920 new_text.chars(),
921 );
922 old_range.start += prefix_len;
923
924 let suffix_len = common_prefix(
925 snapshot.reversed_chars_for_range(old_range.clone()),
926 new_text[prefix_len..].chars().rev(),
927 );
928 old_range.end = old_range.end.saturating_sub(suffix_len);
929
930 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
931 let range = if old_range.is_empty() {
932 let anchor = snapshot.anchor_after(old_range.start);
933 anchor..anchor
934 } else {
935 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
936 };
937 (range, new_text)
938 })
939 .collect()
940 }
941
 /// Returns whether the completion with the given id has been rated.
 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
     self.rated_completions.contains(&completion_id)
 }
945
946 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
947 self.shown_completions.push_front(completion.clone());
948 if self.shown_completions.len() > 50 {
949 let completion = self.shown_completions.pop_back().unwrap();
950 self.rated_completions.remove(&completion.id);
951 }
952 cx.notify();
953 }
954
 /// Marks `completion` as rated and reports an "Edit Prediction Rated" telemetry event.
 ///
 /// The event carries the rating, the prompt inputs, the model output and the user's
 /// `feedback`. Telemetry is flushed right away and observers are notified.
 pub fn rate_completion(
     &mut self,
     completion: &InlineCompletion,
     rating: InlineCompletionRating,
     feedback: String,
     cx: &mut Context<Self>,
 ) {
     self.rated_completions.insert(completion.id);
     telemetry::event!(
         "Edit Prediction Rated",
         rating,
         input_events = completion.input_events,
         input_excerpt = completion.input_excerpt,
         input_outline = completion.input_outline,
         output_excerpt = completion.output_excerpt,
         feedback
     );
     self.client.telemetry().flush_events();
     cx.notify();
 }
975
 /// Iterates over the shown completions, most recently shown first.
 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
     self.shown_completions.iter()
 }
979
 /// Number of shown completions that are currently stored.
 pub fn shown_completions_len(&self) -> usize {
     self.shown_completions.len()
 }
983
984 fn report_changes_for_buffer(
985 &mut self,
986 buffer: &Entity<Buffer>,
987 cx: &mut Context<Self>,
988 ) -> BufferSnapshot {
989 self.register_buffer(buffer, cx);
990
991 let registered_buffer = self
992 .registered_buffers
993 .get_mut(&buffer.entity_id())
994 .unwrap();
995 let new_snapshot = buffer.read(cx).snapshot();
996
997 if new_snapshot.version != registered_buffer.snapshot.version {
998 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
999 self.push_event(Event::BufferChange {
1000 old_snapshot,
1001 new_snapshot: new_snapshot.clone(),
1002 timestamp: Instant::now(),
1003 });
1004 }
1005
1006 new_snapshot
1007 }
1008
1009 fn load_data_collection_choices() -> DataCollectionChoice {
1010 let choice = KEY_VALUE_STORE
1011 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1012 .log_err()
1013 .flatten();
1014
1015 match choice.as_deref() {
1016 Some("true") => DataCollectionChoice::Enabled,
1017 Some("false") => DataCollectionChoice::Disabled,
1018 Some(_) => {
1019 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1020 DataCollectionChoice::NotAnswered
1021 }
1022 None => DataCollectionChoice::NotAnswered,
1023 }
1024 }
1025}
1026
/// Everything needed to perform a single prediction request.
struct PerformPredictEditsParams {
    /// Client that provides the HTTP client and is used to acquire/refresh the token.
    pub client: Arc<Client>,
    /// Token sent as bearer authorization.
    pub llm_token: LlmApiToken,
    /// Version of the running app, compared against the server's minimum required version.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1033
/// Error returned when the server requires a newer app version than the one that is running.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version announced by the server.
    minimum_version: SemanticVersion,
}
1041
/// Tracks whether a worktree was detected to carry an open source license.
struct LicenseDetectionWatcher {
    /// Receives `true` once an eligible license file has been found; starts out `false`.
    is_open_source_rx: watch::Receiver<bool>,
    /// Task that performs the detection; stored alongside the receiver.
    _is_open_source_task: Task<()>,
}
1046
1047impl LicenseDetectionWatcher {
1048 pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
1049 let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
1050
1051 // Check if worktree is a single file, if so we do not need to check for a LICENSE file
1052 let task = if worktree.abs_path().is_file() {
1053 Task::ready(())
1054 } else {
1055 let loaded_files = LICENSE_FILES_TO_CHECK
1056 .iter()
1057 .map(Path::new)
1058 .map(|file| worktree.load_file(file, cx))
1059 .collect::<ArrayVec<_, { LICENSE_FILES_TO_CHECK.len() }>>();
1060
1061 cx.background_spawn(async move {
1062 for loaded_file in loaded_files.into_iter() {
1063 let Ok(loaded_file) = loaded_file.await else {
1064 continue;
1065 };
1066
1067 let path = &loaded_file.file.path;
1068 if is_license_eligible_for_data_collection(&loaded_file.text) {
1069 log::info!("detected '{path:?}' as open source license");
1070 *is_open_source_tx.borrow_mut() = true;
1071 } else {
1072 log::info!("didn't detect '{path:?}' as open source license");
1073 }
1074
1075 // stop on the first license that successfully read
1076 return;
1077 }
1078
1079 log::debug!("didn't find a license file to check, assuming closed source");
1080 })
1081 };
1082
1083 Self {
1084 is_open_source_rx,
1085 _is_open_source_task: task,
1086 }
1087 }
1088
1089 /// Answers false until we find out it's open source
1090 pub fn is_project_open_source(&self) -> bool {
1091 *self.is_open_source_rx.borrow()
1092 }
1093}
1094
/// Returns the length in UTF-8 bytes of the longest common prefix of the two char sequences.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1101
1102fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
1103 let mut input_outline = String::new();
1104
1105 writeln!(
1106 input_outline,
1107 "```{}",
1108 snapshot
1109 .file()
1110 .map_or(Cow::Borrowed("untitled"), |file| file
1111 .path()
1112 .to_string_lossy())
1113 )
1114 .unwrap();
1115
1116 if let Some(outline) = snapshot.outline(None) {
1117 for item in &outline.items {
1118 let spacing = " ".repeat(item.depth);
1119 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1120 }
1121 }
1122
1123 writeln!(input_outline, "```").unwrap();
1124
1125 input_outline
1126}
1127
1128fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1129 let mut result = String::new();
1130 for event in events.iter().rev() {
1131 let event_string = event.to_prompt();
1132 let event_tokens = tokens_for_bytes(event_string.len());
1133 if event_tokens > remaining_tokens {
1134 break;
1135 }
1136
1137 if !result.is_empty() {
1138 result.insert_str(0, "\n\n");
1139 }
1140 result.insert_str(0, &event_string);
1141 remaining_tokens -= event_tokens;
1142 }
1143 result
1144}
1145
1146struct RegisteredBuffer {
1147 snapshot: BufferSnapshot,
1148 _subscriptions: [gpui::Subscription; 2],
1149}
1150
1151#[derive(Clone)]
1152enum Event {
1153 BufferChange {
1154 old_snapshot: BufferSnapshot,
1155 new_snapshot: BufferSnapshot,
1156 timestamp: Instant,
1157 },
1158}
1159
1160impl Event {
1161 fn to_prompt(&self) -> String {
1162 match self {
1163 Event::BufferChange {
1164 old_snapshot,
1165 new_snapshot,
1166 ..
1167 } => {
1168 let mut prompt = String::new();
1169
1170 let old_path = old_snapshot
1171 .file()
1172 .map(|f| f.path().as_ref())
1173 .unwrap_or(Path::new("untitled"));
1174 let new_path = new_snapshot
1175 .file()
1176 .map(|f| f.path().as_ref())
1177 .unwrap_or(Path::new("untitled"));
1178 if old_path != new_path {
1179 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1180 }
1181
1182 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1183 if !diff.is_empty() {
1184 write!(
1185 prompt,
1186 "User edited {:?}:\n```diff\n{}\n```",
1187 new_path, diff
1188 )
1189 .unwrap();
1190 }
1191
1192 prompt
1193 }
1194 }
1195 }
1196}
1197
1198#[derive(Debug, Clone)]
1199struct CurrentInlineCompletion {
1200 buffer_id: EntityId,
1201 completion: InlineCompletion,
1202}
1203
1204impl CurrentInlineCompletion {
1205 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1206 if self.buffer_id != old_completion.buffer_id {
1207 return true;
1208 }
1209
1210 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1211 return true;
1212 };
1213 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1214 return false;
1215 };
1216
1217 if old_edits.len() == 1 && new_edits.len() == 1 {
1218 let (old_range, old_text) = &old_edits[0];
1219 let (new_range, new_text) = &new_edits[0];
1220 new_range == old_range && new_text.starts_with(old_text)
1221 } else {
1222 true
1223 }
1224 }
1225}
1226
/// An in-flight prediction request tracked by the provider.
struct PendingCompletion {
    /// Identifier of the request, compared against when the request finishes.
    id: usize,
    /// Handle of the spawned request task; stored so it lives as long as this entry.
    _task: Task<()>,
}
1231
/// The user's answer to whether data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` for either explicit answer and `false` for
    /// [`DataCollectionChoice::NotAnswered`].
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// Returns the choice that results from flipping this one: `Enabled` becomes
    /// `Disabled`, and both `Disabled` and `NotAnswered` become `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1271
1272pub struct ProviderDataCollection {
1273 /// When set to None, data collection is not possible in the provider buffer
1274 choice: Option<Entity<DataCollectionChoice>>,
1275 license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1276}
1277
1278impl ProviderDataCollection {
1279 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1280 let choice_and_watcher = buffer.and_then(|buffer| {
1281 let file = buffer.read(cx).file()?;
1282
1283 if !file.is_local() || file.is_private() {
1284 return None;
1285 }
1286
1287 let zeta = zeta.read(cx);
1288 let choice = zeta.data_collection_choice.clone();
1289
1290 let license_detection_watcher = zeta
1291 .license_detection_watchers
1292 .get(&file.worktree_id(cx))
1293 .cloned()?;
1294
1295 Some((choice, license_detection_watcher))
1296 });
1297
1298 if let Some((choice, watcher)) = choice_and_watcher {
1299 ProviderDataCollection {
1300 choice: Some(choice),
1301 license_detection_watcher: Some(watcher),
1302 }
1303 } else {
1304 ProviderDataCollection {
1305 choice: None,
1306 license_detection_watcher: None,
1307 }
1308 }
1309 }
1310
1311 pub fn can_collect_data(&self, cx: &App) -> bool {
1312 self.is_data_collection_enabled(cx) && self.is_project_open_source()
1313 }
1314
1315 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1316 self.choice
1317 .as_ref()
1318 .is_some_and(|choice| choice.read(cx).is_enabled())
1319 }
1320
1321 fn is_project_open_source(&self) -> bool {
1322 self.license_detection_watcher
1323 .as_ref()
1324 .is_some_and(|watcher| watcher.is_project_open_source())
1325 }
1326
1327 pub fn toggle(&mut self, cx: &mut App) {
1328 if let Some(choice) = self.choice.as_mut() {
1329 let new_choice = choice.update(cx, |choice, _cx| {
1330 let new_choice = choice.toggle();
1331 *choice = new_choice;
1332 new_choice
1333 });
1334
1335 db::write_and_log(cx, move || {
1336 KEY_VALUE_STORE.write_kvp(
1337 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1338 new_choice.is_enabled().to_string(),
1339 )
1340 });
1341 }
1342 }
1343}
1344
/// Edit-prediction provider backed by a shared `Zeta` entity.
pub struct ZetaInlineCompletionProvider {
    /// Entity the prediction requests are issued through.
    zeta: Entity<Zeta>,
    /// Requests currently in flight; the fixed capacity limits this to two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request; incremented on every refresh that issues one.
    next_pending_completion_id: usize,
    /// The prediction currently held by this provider, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection state for this provider. Whether collection is possible at all
    /// is tracked inside `ProviderDataCollection` itself.
    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used to throttle the next one.
    last_request_timestamp: Instant,
}

impl ZetaInlineCompletionProvider {
    /// Minimum delay enforced between the previous request and the next one.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current prediction. The throttle clock
    /// starts at the time of construction.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1369
1370impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider {
1371 fn name() -> &'static str {
1372 "zed-predict"
1373 }
1374
1375 fn display_name() -> &'static str {
1376 "Zed's Edit Predictions"
1377 }
1378
1379 fn show_completions_in_menu() -> bool {
1380 true
1381 }
1382
1383 fn show_tab_accept_marker() -> bool {
1384 true
1385 }
1386
1387 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1388 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1389
1390 if self.provider_data_collection.is_data_collection_enabled(cx) {
1391 DataCollectionState::Enabled {
1392 is_project_open_source,
1393 }
1394 } else {
1395 DataCollectionState::Disabled {
1396 is_project_open_source,
1397 }
1398 }
1399 }
1400
1401 fn toggle_data_collection(&mut self, cx: &mut App) {
1402 self.provider_data_collection.toggle(cx);
1403 }
1404
1405 fn is_enabled(
1406 &self,
1407 _buffer: &Entity<Buffer>,
1408 _cursor_position: language::Anchor,
1409 _cx: &App,
1410 ) -> bool {
1411 true
1412 }
1413
1414 fn needs_terms_acceptance(&self, cx: &App) -> bool {
1415 !self.zeta.read(cx).tos_accepted
1416 }
1417
1418 fn is_refreshing(&self) -> bool {
1419 !self.pending_completions.is_empty()
1420 }
1421
1422 fn refresh(
1423 &mut self,
1424 project: Option<Entity<Project>>,
1425 buffer: Entity<Buffer>,
1426 position: language::Anchor,
1427 _debounce: bool,
1428 cx: &mut Context<Self>,
1429 ) {
1430 if !self.zeta.read(cx).tos_accepted {
1431 return;
1432 }
1433
1434 if self.zeta.read(cx).update_required {
1435 return;
1436 }
1437
1438 if let Some(current_completion) = self.current_completion.as_ref() {
1439 let snapshot = buffer.read(cx).snapshot();
1440 if current_completion
1441 .completion
1442 .interpolate(&snapshot)
1443 .is_some()
1444 {
1445 return;
1446 }
1447 }
1448
1449 let pending_completion_id = self.next_pending_completion_id;
1450 self.next_pending_completion_id += 1;
1451 let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1452 let last_request_timestamp = self.last_request_timestamp;
1453
1454 let task = cx.spawn(async move |this, cx| {
1455 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1456 .checked_duration_since(Instant::now())
1457 {
1458 cx.background_executor().timer(timeout).await;
1459 }
1460
1461 let completion_request = this.update(cx, |this, cx| {
1462 this.last_request_timestamp = Instant::now();
1463 this.zeta.update(cx, |zeta, cx| {
1464 zeta.request_completion(
1465 project.as_ref(),
1466 &buffer,
1467 position,
1468 can_collect_data,
1469 cx,
1470 )
1471 })
1472 });
1473
1474 let completion = match completion_request {
1475 Ok(completion_request) => {
1476 let completion_request = completion_request.await;
1477 completion_request.map(|c| {
1478 c.map(|completion| CurrentInlineCompletion {
1479 buffer_id: buffer.entity_id(),
1480 completion,
1481 })
1482 })
1483 }
1484 Err(error) => Err(error),
1485 };
1486 let Some(new_completion) = completion
1487 .context("edit prediction failed")
1488 .log_err()
1489 .flatten()
1490 else {
1491 this.update(cx, |this, cx| {
1492 if this.pending_completions[0].id == pending_completion_id {
1493 this.pending_completions.remove(0);
1494 } else {
1495 this.pending_completions.clear();
1496 }
1497
1498 cx.notify();
1499 })
1500 .ok();
1501 return;
1502 };
1503
1504 this.update(cx, |this, cx| {
1505 if this.pending_completions[0].id == pending_completion_id {
1506 this.pending_completions.remove(0);
1507 } else {
1508 this.pending_completions.clear();
1509 }
1510
1511 if let Some(old_completion) = this.current_completion.as_ref() {
1512 let snapshot = buffer.read(cx).snapshot();
1513 if new_completion.should_replace_completion(&old_completion, &snapshot) {
1514 this.zeta.update(cx, |zeta, cx| {
1515 zeta.completion_shown(&new_completion.completion, cx);
1516 });
1517 this.current_completion = Some(new_completion);
1518 }
1519 } else {
1520 this.zeta.update(cx, |zeta, cx| {
1521 zeta.completion_shown(&new_completion.completion, cx);
1522 });
1523 this.current_completion = Some(new_completion);
1524 }
1525
1526 cx.notify();
1527 })
1528 .ok();
1529 });
1530
1531 // We always maintain at most two pending completions. When we already
1532 // have two, we replace the newest one.
1533 if self.pending_completions.len() <= 1 {
1534 self.pending_completions.push(PendingCompletion {
1535 id: pending_completion_id,
1536 _task: task,
1537 });
1538 } else if self.pending_completions.len() == 2 {
1539 self.pending_completions.pop();
1540 self.pending_completions.push(PendingCompletion {
1541 id: pending_completion_id,
1542 _task: task,
1543 });
1544 }
1545 }
1546
1547 fn cycle(
1548 &mut self,
1549 _buffer: Entity<Buffer>,
1550 _cursor_position: language::Anchor,
1551 _direction: inline_completion::Direction,
1552 _cx: &mut Context<Self>,
1553 ) {
1554 // Right now we don't support cycling.
1555 }
1556
1557 fn accept(&mut self, _cx: &mut Context<Self>) {
1558 self.pending_completions.clear();
1559 }
1560
1561 fn discard(&mut self, _cx: &mut Context<Self>) {
1562 self.pending_completions.clear();
1563 self.current_completion.take();
1564 }
1565
1566 fn suggest(
1567 &mut self,
1568 buffer: &Entity<Buffer>,
1569 cursor_position: language::Anchor,
1570 cx: &mut Context<Self>,
1571 ) -> Option<inline_completion::InlineCompletion> {
1572 let CurrentInlineCompletion {
1573 buffer_id,
1574 completion,
1575 ..
1576 } = self.current_completion.as_mut()?;
1577
1578 // Invalidate previous completion if it was generated for a different buffer.
1579 if *buffer_id != buffer.entity_id() {
1580 self.current_completion.take();
1581 return None;
1582 }
1583
1584 let buffer = buffer.read(cx);
1585 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1586 self.current_completion.take();
1587 return None;
1588 };
1589
1590 let cursor_row = cursor_position.to_point(buffer).row;
1591 let (closest_edit_ix, (closest_edit_range, _)) =
1592 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1593 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1594 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1595 cmp::min(distance_from_start, distance_from_end)
1596 })?;
1597
1598 let mut edit_start_ix = closest_edit_ix;
1599 for (range, _) in edits[..edit_start_ix].iter().rev() {
1600 let distance_from_closest_edit =
1601 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1602 if distance_from_closest_edit <= 1 {
1603 edit_start_ix -= 1;
1604 } else {
1605 break;
1606 }
1607 }
1608
1609 let mut edit_end_ix = closest_edit_ix + 1;
1610 for (range, _) in &edits[edit_end_ix..] {
1611 let distance_from_closest_edit =
1612 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1613 if distance_from_closest_edit <= 1 {
1614 edit_end_ix += 1;
1615 } else {
1616 break;
1617 }
1618 }
1619
1620 Some(inline_completion::InlineCompletion {
1621 id: Some(completion.id.to_string().into()),
1622 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1623 edit_preview: Some(completion.edit_preview.clone()),
1624 })
1625 }
1626}
1627
/// Estimates how many tokens a string of `bytes` bytes occupies, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed number of string bytes per token when limiting model input. The guess
    /// is deliberately small so that the estimate stays on the cautious side of limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    let estimated_tokens = bytes / BYTES_PER_TOKEN_GUESS;
    estimated_tokens
}
1634
1635#[cfg(test)]
1636mod tests {
1637 use client::test::FakeServer;
1638 use clock::FakeSystemClock;
1639 use gpui::TestAppContext;
1640 use http_client::FakeHttpClient;
1641 use indoc::indoc;
1642 use language::Point;
1643 use rpc::proto;
1644 use settings::SettingsStore;
1645
1646 use super::*;
1647
    /// Checks how a completion's edits are interpolated as the buffer changes under it.
    ///
    /// The completion is built once against "Lorem ipsum dolor". After each buffer edit
    /// the remaining edits are compared, as offsets, with the expected list.
    #[gpui::test]
    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // The prediction: replace offsets 2..5 with "REM" and delete offsets 9..11.
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        // Only `edits`, `edit_preview` and `snapshot` carry real data; the remaining
        // fields are filled with placeholders.
        let completion = InlineCompletion {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: InlineCompletionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the edits come back unchanged.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Delete the text the first edit would replace: it becomes a pure insertion
            // and the second edit shifts left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Replace 2..5 with the first predicted character: only "EM" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // Insert the next predicted character: only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Insert the last predicted character: the first edit is fully consumed.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Remove the "M" again: the insertion of "M" is predicted once more.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Perform the predicted deletion by hand: only the insertion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // After this deletion the completion no longer interpolates at all.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }
1764
    /// Checks that a predicted excerpt is reduced to small, targeted edits.
    ///
    /// Each case feeds a buffer and a model response through `edits_for_prediction` and
    /// compares the resulting edits, as point ranges, with the expected insertions.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        // Renaming `word` to `word_1` twice on one line yields two "_1" insertions.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        // Extending a string literal and adding a semicolon yields two insertions, one
        // before and one after the existing closing quote.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }
1828
    /// Checks a prediction that appends text at the very end of the buffer.
    ///
    /// The buffer holds "lorem\n" and the fake response adds an "ipsum" line; applying
    /// the returned edits must leave the buffer as "lorem\nipsum".
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request is answered with the canned prediction above.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                            .unwrap(),
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        // Request a completion with the cursor at the start of the second row.
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // The request first asks the server for an LLM token; answer with an empty one.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }
1891
    /// Runs one prediction against a fake backend and returns the resulting edits.
    ///
    /// A buffer is created from `buffer_content`, the fake HTTP client answers every
    /// request with `completion_response` as the output excerpt, and a completion is
    /// requested with the cursor at the start of the second row. The edits are returned
    /// as point ranges resolved against the buffer's snapshot from before the request.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |_| {
            // Clone per request so the handler can be invoked more than once.
            let completion = completion_response.clone();
            async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::new_v4(),
                            output_excerpt: completion,
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap())
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // The request first asks the server for an LLM token; answer with an empty one.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .into_iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }
1943
1944 fn to_completion_edits(
1945 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1946 buffer: &Entity<Buffer>,
1947 cx: &App,
1948 ) -> Vec<(Range<Anchor>, String)> {
1949 let buffer = buffer.read(cx);
1950 iterator
1951 .into_iter()
1952 .map(|(range, text)| {
1953 (
1954 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1955 text,
1956 )
1957 })
1958 .collect()
1959 }
1960
1961 fn from_completion_edits(
1962 editor_edits: &[(Range<Anchor>, String)],
1963 buffer: &Entity<Buffer>,
1964 cx: &App,
1965 ) -> Vec<(Range<usize>, String)> {
1966 let buffer = buffer.read(cx);
1967 editor_edits
1968 .iter()
1969 .map(|(range, text)| {
1970 (
1971 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1972 text.clone(),
1973 )
1974 })
1975 .collect()
1976 }
1977
1978 #[ctor::ctor]
1979 fn init_logger() {
1980 if std::env::var("RUST_LOG").is_ok() {
1981 env_logger::init();
1982 }
1983 }
1984}