1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::KEY_VALUE_STORE;
11pub use init::*;
12use inline_completion::DataCollectionState;
13use license_detection::LICENSE_FILES_TO_CHECK;
14pub use license_detection::is_license_eligible_for_data_collection;
15pub use rate_completion_modal::*;
16
17use anyhow::{Context as _, Result, anyhow};
18use arrayvec::ArrayVec;
19use client::{Client, EditPredictionUsage, UserStore};
20use collections::{HashMap, HashSet, VecDeque};
21use futures::AsyncReadExt;
22use gpui::{
23 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
24 Subscription, Task, WeakEntity, actions,
25};
26use http_client::{AsyncBody, HttpClient, Method, Request, Response};
27use input_excerpt::excerpt_for_cursor_position;
28use language::{
29 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
30};
31use language_model::{LlmApiToken, RefreshLlmTokenListener};
32use postage::watch;
33use project::Project;
34use release_channel::AppVersion;
35use settings::WorktreeId;
36use std::str::FromStr;
37use std::{
38 borrow::Cow,
39 cmp,
40 fmt::Write,
41 future::Future,
42 mem,
43 ops::Range,
44 path::Path,
45 rc::Rc,
46 sync::Arc,
47 time::{Duration, Instant},
48};
49use telemetry_events::InlineCompletionRating;
50use thiserror::Error;
51use util::ResultExt;
52use uuid::Uuid;
53use workspace::Workspace;
54use workspace::notifications::{ErrorMessagePrompt, NotificationId};
55use worktree::Worktree;
56use zed_llm_client::{
57 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
58 PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
59};
60
/// Marker for the user's cursor position; it is stripped from the model output
/// before edits are parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker for the start of the file; at most one may appear in the model output.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker that opens the editable region of an excerpt.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker that closes the editable region of an excerpt.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that are at most this far apart are
/// coalesced into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key under which the data collection choice is persisted in the key-value store.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

/// Token budget for the context surrounding the editable region of the excerpt.
const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget for the editable (rewritable) region of the excerpt.
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the rendered event history included in the prompt.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track. Once the history reaches this size, its
/// oldest half is dropped.
const MAX_EVENT_COUNT: usize = 16;
74
// Declares the actions that live in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
82
/// Identifier of an inline completion, wrapping a [`Uuid`].
///
/// Completions built in this file take their id from the `request_id` of the
/// prediction response.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineCompletionId(Uuid);
85
86impl From<InlineCompletionId> for gpui::ElementId {
87 fn from(value: InlineCompletionId) -> Self {
88 gpui::ElementId::Uuid(value.0)
89 }
90}
91
92impl std::fmt::Display for InlineCompletionId {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}", self.0)
95 }
96}
97
/// Global wrapper that stores the shared [`Zeta`] entity in the app context.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
102
/// A prediction returned for a buffer, together with the inputs that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Identifier taken from the prediction response's `request_id`.
    id: InlineCompletionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region that was sent to the model.
    excerpt_range: Range<usize>,
    /// Offset of the cursor in the snapshot the request was built from.
    cursor_offset: usize,
    /// Predicted edits as anchored ranges paired with their replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` were interpolated against.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
    /// Outline text that was sent with the request.
    input_outline: Arc<str>,
    /// Rendered event history that was sent with the request.
    input_events: Arc<str>,
    /// Excerpt prompt that was sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// Instant captured right before the request inputs were prepared.
    request_sent_at: Instant,
    /// Instant captured when this completion was constructed from the response.
    response_received_at: Instant,
}
119
120impl InlineCompletion {
121 fn latency(&self) -> Duration {
122 self.response_received_at
123 .duration_since(self.request_sent_at)
124 }
125
126 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
127 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
128 }
129}
130
131fn interpolate(
132 old_snapshot: &BufferSnapshot,
133 new_snapshot: &BufferSnapshot,
134 current_edits: Arc<[(Range<Anchor>, String)]>,
135) -> Option<Vec<(Range<Anchor>, String)>> {
136 let mut edits = Vec::new();
137
138 let mut model_edits = current_edits.into_iter().peekable();
139 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
140 while let Some((model_old_range, _)) = model_edits.peek() {
141 let model_old_range = model_old_range.to_offset(old_snapshot);
142 if model_old_range.end < user_edit.old.start {
143 let (model_old_range, model_new_text) = model_edits.next().unwrap();
144 edits.push((model_old_range.clone(), model_new_text.clone()));
145 } else {
146 break;
147 }
148 }
149
150 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
151 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
152 if user_edit.old == model_old_offset_range {
153 let user_new_text = new_snapshot
154 .text_for_range(user_edit.new.clone())
155 .collect::<String>();
156
157 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
158 if !model_suffix.is_empty() {
159 let anchor = old_snapshot.anchor_after(user_edit.old.end);
160 edits.push((anchor..anchor, model_suffix.to_string()));
161 }
162
163 model_edits.next();
164 continue;
165 }
166 }
167 }
168
169 return None;
170 }
171
172 edits.extend(model_edits.cloned());
173
174 if edits.is_empty() { None } else { Some(edits) }
175}
176
177impl std::fmt::Debug for InlineCompletion {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("InlineCompletion")
180 .field("id", &self.id)
181 .field("path", &self.path)
182 .field("edits", &self.edits)
183 .finish_non_exhaustive()
184 }
185}
186
/// State backing edit predictions: the event history, tracked buffers,
/// shown and rated completions, and the handles needed to talk to the server.
pub struct Zeta {
    /// Workspace used to surface notifications, if one was supplied at creation.
    workspace: Option<WeakEntity<Workspace>>,
    /// Client providing the HTTP client and telemetry.
    client: Arc<Client>,
    /// History of recent buffer-change events that is rendered into the prompt.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions that were shown, most recent first.
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of the shown completions that have been rated.
    rated_completions: HashSet<InlineCompletionId>,
    /// Data collection choice, loaded from the key-value store on creation.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token used to authorize requests to the LLM service.
    llm_token: LlmApiToken,
    /// Keeps alive the subscription that refreshes `llm_token`.
    _llm_token_subscription: Subscription,
    /// Whether the terms of service have been accepted.
    tos_accepted: bool,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Store queried for the accepted terms and the edit prediction usage.
    user_store: Entity<UserStore>,
    /// Keeps alive the subscription that updates `tos_accepted`.
    _user_store_subscription: Subscription,
    /// License detection watchers, one per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
205
206impl Zeta {
207 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
208 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
209 }
210
211 pub fn register(
212 workspace: Option<WeakEntity<Workspace>>,
213 worktree: Option<Entity<Worktree>>,
214 client: Arc<Client>,
215 user_store: Entity<UserStore>,
216 cx: &mut App,
217 ) -> Entity<Self> {
218 let this = Self::global(cx).unwrap_or_else(|| {
219 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
220 cx.set_global(ZetaGlobal(entity.clone()));
221 entity
222 });
223
224 this.update(cx, move |this, cx| {
225 if let Some(worktree) = worktree {
226 worktree.update(cx, |worktree, cx| {
227 this.license_detection_watchers
228 .entry(worktree.id())
229 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
230 });
231 }
232 });
233
234 this
235 }
236
237 pub fn clear_history(&mut self) {
238 self.events.clear();
239 }
240
241 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
242 self.user_store.read(cx).edit_prediction_usage()
243 }
244
245 fn new(
246 workspace: Option<WeakEntity<Workspace>>,
247 client: Arc<Client>,
248 user_store: Entity<UserStore>,
249 cx: &mut Context<Self>,
250 ) -> Self {
251 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
252
253 let data_collection_choice = Self::load_data_collection_choices();
254 let data_collection_choice = cx.new(|_| data_collection_choice);
255
256 Self {
257 workspace,
258 client,
259 events: VecDeque::new(),
260 shown_completions: VecDeque::new(),
261 rated_completions: HashSet::default(),
262 registered_buffers: HashMap::default(),
263 data_collection_choice,
264 llm_token: LlmApiToken::default(),
265 _llm_token_subscription: cx.subscribe(
266 &refresh_llm_token_listener,
267 |this, _listener, _event, cx| {
268 let client = this.client.clone();
269 let llm_token = this.llm_token.clone();
270 cx.spawn(async move |_this, _cx| {
271 llm_token.refresh(&client).await?;
272 anyhow::Ok(())
273 })
274 .detach_and_log_err(cx);
275 },
276 ),
277 tos_accepted: user_store
278 .read(cx)
279 .current_user_has_accepted_terms()
280 .unwrap_or(false),
281 update_required: false,
282 _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
283 match event {
284 client::user::Event::PrivateUserInfoUpdated => {
285 this.tos_accepted = user_store
286 .read(cx)
287 .current_user_has_accepted_terms()
288 .unwrap_or(false);
289 }
290 _ => {}
291 }
292 }),
293 license_detection_watchers: HashMap::default(),
294 user_store,
295 }
296 }
297
298 fn push_event(&mut self, event: Event) {
299 if let Some(Event::BufferChange {
300 new_snapshot: last_new_snapshot,
301 timestamp: last_timestamp,
302 ..
303 }) = self.events.back_mut()
304 {
305 // Coalesce edits for the same buffer when they happen one after the other.
306 let Event::BufferChange {
307 old_snapshot,
308 new_snapshot,
309 timestamp,
310 } = &event;
311
312 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
313 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
314 && old_snapshot.version == last_new_snapshot.version
315 {
316 *last_new_snapshot = new_snapshot.clone();
317 *last_timestamp = *timestamp;
318 return;
319 }
320 }
321
322 self.events.push_back(event);
323 if self.events.len() >= MAX_EVENT_COUNT {
324 // These are halved instead of popping to improve prompt caching.
325 self.events.drain(..MAX_EVENT_COUNT / 2);
326 }
327 }
328
329 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
330 let buffer_id = buffer.entity_id();
331 let weak_buffer = buffer.downgrade();
332
333 if let std::collections::hash_map::Entry::Vacant(entry) =
334 self.registered_buffers.entry(buffer_id)
335 {
336 let snapshot = buffer.read(cx).snapshot();
337
338 entry.insert(RegisteredBuffer {
339 snapshot,
340 _subscriptions: [
341 cx.subscribe(buffer, move |this, buffer, event, cx| {
342 this.handle_buffer_event(buffer, event, cx);
343 }),
344 cx.observe_release(buffer, move |this, _buffer, _cx| {
345 this.registered_buffers.remove(&weak_buffer.entity_id());
346 }),
347 ],
348 });
349 };
350 }
351
352 fn handle_buffer_event(
353 &mut self,
354 buffer: Entity<Buffer>,
355 event: &language::BufferEvent,
356 cx: &mut Context<Self>,
357 ) {
358 if let language::BufferEvent::Edited = event {
359 self.report_changes_for_buffer(&buffer, cx);
360 }
361 }
362
/// Shared implementation behind `request_completion` and `fake_completion`.
///
/// Records pending changes for `buffer`, builds the prediction request (event
/// history, excerpt around the cursor, outline and optional diagnostics), runs
/// it through `perform_predict_edits`, and hands the response to
/// `process_completion_response`.
///
/// * `workspace` - when present, used to show a notification if the server
///   requires a newer version of Zed.
/// * `project` - when present and its LSP store is local, diagnostic groups of
///   running language servers are attached to the request.
/// * `buffer`, `cursor` - buffer and position to predict edits for.
/// * `can_collect_data` - forwarded unchanged in the request body.
/// * `perform_predict_edits` - performs the actual request; callers can inject
///   a canned response here.
///
/// The task fails when the request fails; on a `ZedUpdateRequiredError` it
/// additionally sets `update_required` before returning the error.
fn request_completion_impl<F, R>(
    &mut self,
    workspace: Option<Entity<Workspace>>,
    project: Option<&Entity<Project>>,
    buffer: &Entity<Buffer>,
    cursor: language::Anchor,
    can_collect_data: bool,
    cx: &mut Context<Self>,
    perform_predict_edits: F,
) -> Task<Result<Option<InlineCompletion>>>
where
    F: FnOnce(PerformPredictEditsParams) -> R + 'static,
    R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
        + Send
        + 'static,
{
    // Make sure the latest edits are part of the event history before it is cloned.
    let snapshot = self.report_changes_for_buffer(&buffer, cx);
    let diagnostic_groups = snapshot.diagnostic_groups(None);
    let cursor_point = cursor.to_point(&snapshot);
    let cursor_offset = cursor_point.to_offset(&snapshot);
    let events = self.events.clone();
    // Buffers without a file are reported under the path `untitled`.
    let path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let zeta = cx.entity();
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);

    let buffer = buffer.clone();

    // Diagnostics are only included when a local LSP store is available, and
    // only for language servers that are currently running.
    let local_lsp_store =
        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
    let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
        Some(
            diagnostic_groups
                .into_iter()
                .filter_map(|(language_server_id, diagnostic_group)| {
                    let language_server =
                        local_lsp_store.running_language_server_for_id(language_server_id)?;

                    Some((
                        language_server.name(),
                        diagnostic_group.resolve::<usize>(&snapshot),
                    ))
                })
                .collect::<Vec<_>>(),
        )
    } else {
        None
    };

    cx.spawn(async move |this, cx| {
        let request_sent_at = Instant::now();

        // Prompt inputs that are computed on the background executor.
        struct BackgroundValues {
            input_events: String,
            input_excerpt: String,
            speculated_output: String,
            editable_range: Range<usize>,
            input_outline: String,
        }

        let values = cx
            .background_spawn({
                let snapshot = snapshot.clone();
                let path = path.clone();
                async move {
                    let path = path.to_string_lossy();
                    let input_excerpt = excerpt_for_cursor_position(
                        cursor_point,
                        &path,
                        &snapshot,
                        MAX_REWRITE_TOKENS,
                        MAX_CONTEXT_TOKENS,
                    );
                    let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
                    let input_outline = prompt_for_outline(&snapshot);

                    anyhow::Ok(BackgroundValues {
                        input_events,
                        input_excerpt: input_excerpt.prompt,
                        speculated_output: input_excerpt.speculated_output,
                        editable_range: input_excerpt.editable_range.to_offset(&snapshot),
                        input_outline,
                    })
                }
            })
            .await?;

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            values.input_events,
            values.input_excerpt
        );

        let body = PredictEditsBody {
            input_events: values.input_events.clone(),
            input_excerpt: values.input_excerpt.clone(),
            speculated_output: Some(values.speculated_output),
            outline: Some(values.input_outline.clone()),
            can_collect_data,
            // If any diagnostic group fails to serialize, the error is logged
            // and no diagnostics are sent at all.
            diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
                diagnostic_groups
                    .into_iter()
                    .map(|(name, diagnostic_group)| {
                        Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
                    })
                    .collect::<Result<Vec<_>>>()
                    .log_err()
            }),
        };

        let response = perform_predict_edits(PerformPredictEditsParams {
            client,
            llm_token,
            app_version,
            body,
        })
        .await;
        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // An outdated client is remembered on the entity and, when a
                // workspace is available, reported with a link to the releases page.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        zeta.update(cx, |zeta, _cx| {
                            zeta.update_required = true;
                        });

                        if let Some(workspace) = workspace {
                            workspace.update(cx, |workspace, cx| {
                                workspace.show_notification(
                                    NotificationId::unique::<ZedUpdateRequiredError>(),
                                    cx,
                                    |cx| {
                                        cx.new(|cx| {
                                            ErrorMessagePrompt::new(err.to_string(), cx)
                                                .with_link_button(
                                                    "Update Zed",
                                                    "https://zed.dev/releases",
                                                )
                                        })
                                    },
                                );
                            });
                        }
                    })
                    .ok();
                }

                return Err(err);
            }
        };

        log::debug!("completion response: {}", &response.output_excerpt);

        // Usage reported with the response is forwarded to the user store;
        // a failure to update the entity is ignored.
        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        Self::process_completion_response(
            response,
            buffer,
            &snapshot,
            values.editable_range,
            cursor_offset,
            path,
            values.input_outline,
            values.input_events,
            values.input_excerpt,
            request_sent_at,
            &cx,
        )
        .await
    })
}
546
/// Generates several example completions of various states to fill the Zeta completion modal.
///
/// Creates a local buffer with fixed text and requests seven fake completions
/// against it, each with a canned response. After all of them have resolved,
/// the edits of the shown completions at index 2 and 3 are replaced with an
/// empty list.
// NOTE(review): nothing in this function pushes to `shown_completions`; the
// `unwrap` calls at the end assume at least four entries already exist — confirm.
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
    use language::Point;

    let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 And maybe a short line

 Then a few lines

 and then another
 "#};

    let project = None;
    let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
    // Every fake request uses the start of the second line as the cursor position.
    let position = buffer.read(cx).anchor_before(Point::new(1, 0));

    let completion_tasks = vec![
        // Adds a line after the first line and drops the blank lines.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 ", ),
            },
            cx,
        ),
        // Adds a line after the second line and drops the blank lines.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        // Repeats the lines of the test buffer between the markers.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        // Same excerpt as the previous response, under a different request id.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        // Adds a line before the last line and drops the blank lines.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        // Removes "Then a few lines" and appends a line at the end.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
        // Appends a line at the end and drops the blank lines.
        self.fake_completion(
            project,
            &buffer,
            position,
            PredictEditsResponse {
                request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
 "#),
            },
            cx,
        ),
    ];

    cx.spawn(async move |zeta, cx| {
        // Await the fake requests one by one, panicking if any of them failed.
        for task in completion_tasks {
            task.await.unwrap();
        }

        // Clear the edits of two shown completions so the list contains
        // entries without edits as well.
        zeta.update(cx, |zeta, _cx| {
            zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
            zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
        })
        .ok();
    })
}
699
/// Requests a completion that resolves with the given canned `response`
/// instead of contacting the server.
///
/// No workspace is passed, data collection is disabled, and no usage is reported.
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
    &mut self,
    project: Option<&Entity<Project>>,
    buffer: &Entity<Buffer>,
    position: language::Anchor,
    response: PredictEditsResponse,
    cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
    let canned: Result<(PredictEditsResponse, Option<EditPredictionUsage>)> =
        Ok((response, None));

    self.request_completion_impl(
        None,
        project,
        buffer,
        position,
        false,
        cx,
        move |_params| std::future::ready(canned),
    )
}
715
716 pub fn request_completion(
717 &mut self,
718 project: Option<&Entity<Project>>,
719 buffer: &Entity<Buffer>,
720 position: language::Anchor,
721 can_collect_data: bool,
722 cx: &mut Context<Self>,
723 ) -> Task<Result<Option<InlineCompletion>>> {
724 let workspace = self
725 .workspace
726 .as_ref()
727 .and_then(|workspace| workspace.upgrade());
728 self.request_completion_impl(
729 workspace,
730 project,
731 buffer,
732 position,
733 can_collect_data,
734 cx,
735 Self::perform_predict_edits,
736 )
737 }
738
739 fn perform_predict_edits(
740 params: PerformPredictEditsParams,
741 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
742 async move {
743 let PerformPredictEditsParams {
744 client,
745 llm_token,
746 app_version,
747 body,
748 ..
749 } = params;
750
751 let http_client = client.http_client();
752 let mut token = llm_token.acquire(&client).await?;
753 let mut did_retry = false;
754
755 loop {
756 let request_builder = http_client::Request::builder().method(Method::POST);
757 let request_builder =
758 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
759 request_builder.uri(predict_edits_url)
760 } else {
761 request_builder.uri(
762 http_client
763 .build_zed_llm_url("/predict_edits/v2", &[])?
764 .as_ref(),
765 )
766 };
767 let request = request_builder
768 .header("Content-Type", "application/json")
769 .header("Authorization", format!("Bearer {}", token))
770 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
771 .body(serde_json::to_string(&body)?.into())?;
772
773 let mut response = http_client.send(request).await?;
774
775 if let Some(minimum_required_version) = response
776 .headers()
777 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
778 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
779 {
780 anyhow::ensure!(
781 app_version >= minimum_required_version,
782 ZedUpdateRequiredError {
783 minimum_version: minimum_required_version
784 }
785 );
786 }
787
788 if response.status().is_success() {
789 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
790
791 let mut body = String::new();
792 response.body_mut().read_to_string(&mut body).await?;
793 return Ok((serde_json::from_str(&body)?, usage));
794 } else if !did_retry
795 && response
796 .headers()
797 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
798 .is_some()
799 {
800 did_retry = true;
801 token = llm_token.refresh(&client).await?;
802 } else {
803 let mut body = String::new();
804 response.body_mut().read_to_string(&mut body).await?;
805 anyhow::bail!(
806 "error predicting edits.\nStatus: {:?}\nBody: {}",
807 response.status(),
808 body
809 );
810 }
811 }
812 }
813 }
814
815 fn accept_edit_prediction(
816 &mut self,
817 request_id: InlineCompletionId,
818 cx: &mut Context<Self>,
819 ) -> Task<Result<()>> {
820 let client = self.client.clone();
821 let llm_token = self.llm_token.clone();
822 let app_version = AppVersion::global(cx);
823 cx.spawn(async move |this, cx| {
824 let http_client = client.http_client();
825 let mut response = llm_token_retry(&llm_token, &client, |token| {
826 let request_builder = http_client::Request::builder().method(Method::POST);
827 let request_builder =
828 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
829 request_builder.uri(accept_prediction_url)
830 } else {
831 request_builder.uri(
832 http_client
833 .build_zed_llm_url("/predict_edits/accept", &[])?
834 .as_ref(),
835 )
836 };
837 Ok(request_builder
838 .header("Content-Type", "application/json")
839 .header("Authorization", format!("Bearer {}", token))
840 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
841 .body(
842 serde_json::to_string(&AcceptEditPredictionBody {
843 request_id: request_id.0,
844 })?
845 .into(),
846 )?)
847 })
848 .await?;
849
850 if let Some(minimum_required_version) = response
851 .headers()
852 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
853 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
854 {
855 if app_version < minimum_required_version {
856 return Err(anyhow!(ZedUpdateRequiredError {
857 minimum_version: minimum_required_version
858 }));
859 }
860 }
861
862 if response.status().is_success() {
863 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
864 this.update(cx, |this, cx| {
865 this.user_store.update(cx, |user_store, cx| {
866 user_store.update_edit_prediction_usage(usage, cx);
867 });
868 })?;
869 }
870
871 Ok(())
872 } else {
873 let mut body = String::new();
874 response.body_mut().read_to_string(&mut body).await?;
875 Err(anyhow!(
876 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
877 response.status(),
878 body
879 ))
880 }
881 })
882 }
883
884 fn process_completion_response(
885 prediction_response: PredictEditsResponse,
886 buffer: Entity<Buffer>,
887 snapshot: &BufferSnapshot,
888 editable_range: Range<usize>,
889 cursor_offset: usize,
890 path: Arc<Path>,
891 input_outline: String,
892 input_events: String,
893 input_excerpt: String,
894 request_sent_at: Instant,
895 cx: &AsyncApp,
896 ) -> Task<Result<Option<InlineCompletion>>> {
897 let snapshot = snapshot.clone();
898 let request_id = prediction_response.request_id;
899 let output_excerpt = prediction_response.output_excerpt;
900 cx.spawn(async move |cx| {
901 let output_excerpt: Arc<str> = output_excerpt.into();
902
903 let edits: Arc<[(Range<Anchor>, String)]> = cx
904 .background_spawn({
905 let output_excerpt = output_excerpt.clone();
906 let editable_range = editable_range.clone();
907 let snapshot = snapshot.clone();
908 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
909 })
910 .await?
911 .into();
912
913 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
914 let edits = edits.clone();
915 |buffer, cx| {
916 let new_snapshot = buffer.snapshot();
917 let edits: Arc<[(Range<Anchor>, String)]> =
918 interpolate(&snapshot, &new_snapshot, edits)?.into();
919 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
920 }
921 })?
922 else {
923 return anyhow::Ok(None);
924 };
925
926 let edit_preview = edit_preview.await;
927
928 Ok(Some(InlineCompletion {
929 id: InlineCompletionId(request_id),
930 path,
931 excerpt_range: editable_range,
932 cursor_offset,
933 edits,
934 edit_preview,
935 snapshot,
936 input_outline: input_outline.into(),
937 input_events: input_events.into(),
938 input_excerpt: input_excerpt.into(),
939 output_excerpt,
940 request_sent_at,
941 response_received_at: Instant::now(),
942 }))
943 })
944 }
945
946 fn parse_edits(
947 output_excerpt: Arc<str>,
948 editable_range: Range<usize>,
949 snapshot: &BufferSnapshot,
950 ) -> Result<Vec<(Range<Anchor>, String)>> {
951 let content = output_excerpt.replace(CURSOR_MARKER, "");
952
953 let start_markers = content
954 .match_indices(EDITABLE_REGION_START_MARKER)
955 .collect::<Vec<_>>();
956 anyhow::ensure!(
957 start_markers.len() == 1,
958 "expected exactly one start marker, found {}",
959 start_markers.len()
960 );
961
962 let end_markers = content
963 .match_indices(EDITABLE_REGION_END_MARKER)
964 .collect::<Vec<_>>();
965 anyhow::ensure!(
966 end_markers.len() == 1,
967 "expected exactly one end marker, found {}",
968 end_markers.len()
969 );
970
971 let sof_markers = content
972 .match_indices(START_OF_FILE_MARKER)
973 .collect::<Vec<_>>();
974 anyhow::ensure!(
975 sof_markers.len() <= 1,
976 "expected at most one start-of-file marker, found {}",
977 sof_markers.len()
978 );
979
980 let codefence_start = start_markers[0].0;
981 let content = &content[codefence_start..];
982
983 let newline_ix = content.find('\n').context("could not find newline")?;
984 let content = &content[newline_ix + 1..];
985
986 let codefence_end = content
987 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
988 .context("could not find end marker")?;
989 let new_text = &content[..codefence_end];
990
991 let old_text = snapshot
992 .text_for_range(editable_range.clone())
993 .collect::<String>();
994
995 Ok(Self::compute_edits(
996 old_text,
997 new_text,
998 editable_range.start,
999 &snapshot,
1000 ))
1001 }
1002
1003 pub fn compute_edits(
1004 old_text: String,
1005 new_text: &str,
1006 offset: usize,
1007 snapshot: &BufferSnapshot,
1008 ) -> Vec<(Range<Anchor>, String)> {
1009 text_diff(&old_text, &new_text)
1010 .into_iter()
1011 .map(|(mut old_range, new_text)| {
1012 old_range.start += offset;
1013 old_range.end += offset;
1014
1015 let prefix_len = common_prefix(
1016 snapshot.chars_for_range(old_range.clone()),
1017 new_text.chars(),
1018 );
1019 old_range.start += prefix_len;
1020
1021 let suffix_len = common_prefix(
1022 snapshot.reversed_chars_for_range(old_range.clone()),
1023 new_text[prefix_len..].chars().rev(),
1024 );
1025 old_range.end = old_range.end.saturating_sub(suffix_len);
1026
1027 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1028 let range = if old_range.is_empty() {
1029 let anchor = snapshot.anchor_after(old_range.start);
1030 anchor..anchor
1031 } else {
1032 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1033 };
1034 (range, new_text)
1035 })
1036 .collect()
1037 }
1038
1039 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
1040 self.rated_completions.contains(&completion_id)
1041 }
1042
1043 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
1044 self.shown_completions.push_front(completion.clone());
1045 if self.shown_completions.len() > 50 {
1046 let completion = self.shown_completions.pop_back().unwrap();
1047 self.rated_completions.remove(&completion.id);
1048 }
1049 cx.notify();
1050 }
1051
/// Marks `completion` as rated and reports the rating through telemetry.
///
/// The telemetry event carries the rating, the feedback text, and the inputs
/// and output of the completion. Events are flushed right away and observers
/// are notified.
pub fn rate_completion(
    &mut self,
    completion: &InlineCompletion,
    rating: InlineCompletionRating,
    feedback: String,
    cx: &mut Context<Self>,
) {
    self.rated_completions.insert(completion.id);
    telemetry::event!(
        "Edit Prediction Rated",
        rating,
        input_events = completion.input_events,
        input_excerpt = completion.input_excerpt,
        input_outline = completion.input_outline,
        output_excerpt = completion.output_excerpt,
        feedback
    );
    // Flush in the background instead of waiting for the task to finish.
    self.client.telemetry().flush_events().detach();
    cx.notify();
}
1072
1073 pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
1074 self.shown_completions.iter()
1075 }
1076
1077 pub fn shown_completions_len(&self) -> usize {
1078 self.shown_completions.len()
1079 }
1080
1081 fn report_changes_for_buffer(
1082 &mut self,
1083 buffer: &Entity<Buffer>,
1084 cx: &mut Context<Self>,
1085 ) -> BufferSnapshot {
1086 self.register_buffer(buffer, cx);
1087
1088 let registered_buffer = self
1089 .registered_buffers
1090 .get_mut(&buffer.entity_id())
1091 .unwrap();
1092 let new_snapshot = buffer.read(cx).snapshot();
1093
1094 if new_snapshot.version != registered_buffer.snapshot.version {
1095 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1096 self.push_event(Event::BufferChange {
1097 old_snapshot,
1098 new_snapshot: new_snapshot.clone(),
1099 timestamp: Instant::now(),
1100 });
1101 }
1102
1103 new_snapshot
1104 }
1105
1106 fn load_data_collection_choices() -> DataCollectionChoice {
1107 let choice = KEY_VALUE_STORE
1108 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1109 .log_err()
1110 .flatten();
1111
1112 match choice.as_deref() {
1113 Some("true") => DataCollectionChoice::Enabled,
1114 Some("false") => DataCollectionChoice::Disabled,
1115 Some(_) => {
1116 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1117 DataCollectionChoice::NotAnswered
1118 }
1119 None => DataCollectionChoice::NotAnswered,
1120 }
1121 }
1122}
1123
/// Values bundled together for a predict-edits request.
///
/// NOTE(review): the consumer of this struct is outside this part of the file; field
/// descriptions below are derived from the field types only.
struct PerformPredictEditsParams {
    /// Shared client handle.
    pub client: Arc<Client>,
    /// LLM API token handle.
    pub llm_token: LlmApiToken,
    /// Version of the running application.
    pub app_version: SemanticVersion,
    /// Request body to send.
    pub body: PredictEditsBody,
}
1130
/// Error whose message tells the user that edit predictions require at least
/// `minimum_version` of Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest version named in the error message.
    minimum_version: SemanticVersion,
}
1138
/// Tracks whether a worktree has been detected as open source.
struct LicenseDetectionWatcher {
    /// Receiving side of the flag; starts out `false` and is set to `true` by the detection task.
    is_open_source_rx: watch::Receiver<bool>,
    /// Detection task, held for as long as the watcher exists.
    /// NOTE(review): presumably dropping it cancels detection — confirm against gpui's `Task`.
    _is_open_source_task: Task<()>,
}
1143
impl LicenseDetectionWatcher {
    /// Creates a watcher for `worktree` and starts license detection.
    ///
    /// For a single-file worktree no detection runs and the flag stays `false`. Otherwise every
    /// path in `LICENSE_FILES_TO_CHECK` is loaded, and the first one that loads successfully
    /// decides the outcome: the flag becomes `true` only if that file's text passes
    /// `is_license_eligible_for_data_collection`.
    pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
        let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);

        // If the worktree is a single file, there is no LICENSE file to look for.
        let task = if worktree.abs_path().is_file() {
            Task::ready(())
        } else {
            // Request all candidate files up front; they are awaited in order below.
            let loaded_files = LICENSE_FILES_TO_CHECK
                .iter()
                .map(Path::new)
                .map(|file| worktree.load_file(file, cx))
                .collect::<ArrayVec<_, { LICENSE_FILES_TO_CHECK.len() }>>();

            cx.background_spawn(async move {
                for loaded_file in loaded_files.into_iter() {
                    // Candidates that fail to load are skipped.
                    let Ok(loaded_file) = loaded_file.await else {
                        continue;
                    };

                    let path = &loaded_file.file.path;
                    if is_license_eligible_for_data_collection(&loaded_file.text) {
                        log::info!("detected '{path:?}' as open source license");
                        *is_open_source_tx.borrow_mut() = true;
                    } else {
                        log::info!("didn't detect '{path:?}' as open source license");
                    }

                    // Stop at the first license file that was read successfully.
                    return;
                }

                log::debug!("didn't find a license file to check, assuming closed source");
            })
        };

        Self {
            is_open_source_rx,
            _is_open_source_task: task,
        }
    }

    /// Answers false until we find out it's open source
    pub fn is_project_open_source(&self) -> bool {
        *self.is_open_source_rx.borrow()
    }
}
1191
/// Returns the UTF-8 byte length of the longest common prefix of two character sequences.
///
/// Characters are compared pairwise; counting stops at the first mismatch or as soon as either
/// sequence runs out.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1198
1199fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
1200 let mut input_outline = String::new();
1201
1202 writeln!(
1203 input_outline,
1204 "```{}",
1205 snapshot
1206 .file()
1207 .map_or(Cow::Borrowed("untitled"), |file| file
1208 .path()
1209 .to_string_lossy())
1210 )
1211 .unwrap();
1212
1213 if let Some(outline) = snapshot.outline(None) {
1214 for item in &outline.items {
1215 let spacing = " ".repeat(item.depth);
1216 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1217 }
1218 }
1219
1220 writeln!(input_outline, "```").unwrap();
1221
1222 input_outline
1223}
1224
1225fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1226 let mut result = String::new();
1227 for event in events.iter().rev() {
1228 let event_string = event.to_prompt();
1229 let event_tokens = tokens_for_bytes(event_string.len());
1230 if event_tokens > remaining_tokens {
1231 break;
1232 }
1233
1234 if !result.is_empty() {
1235 result.insert_str(0, "\n\n");
1236 }
1237 result.insert_str(0, &event_string);
1238 remaining_tokens -= event_tokens;
1239 }
1240 result
1241}
1242
/// Per-buffer state kept for a registered buffer.
struct RegisteredBuffer {
    /// Snapshot recorded the last time the buffer's changes were reported.
    snapshot: BufferSnapshot,
    /// Subscriptions held for as long as the buffer stays registered.
    _subscriptions: [gpui::Subscription; 2],
}
1247
/// An event that can be rendered into the prediction prompt.
#[derive(Clone)]
enum Event {
    /// A buffer moved from one snapshot to another.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// Moment at which the change was recorded.
        timestamp: Instant,
    },
}
1256
1257impl Event {
1258 fn to_prompt(&self) -> String {
1259 match self {
1260 Event::BufferChange {
1261 old_snapshot,
1262 new_snapshot,
1263 ..
1264 } => {
1265 let mut prompt = String::new();
1266
1267 let old_path = old_snapshot
1268 .file()
1269 .map(|f| f.path().as_ref())
1270 .unwrap_or(Path::new("untitled"));
1271 let new_path = new_snapshot
1272 .file()
1273 .map(|f| f.path().as_ref())
1274 .unwrap_or(Path::new("untitled"));
1275 if old_path != new_path {
1276 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1277 }
1278
1279 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1280 if !diff.is_empty() {
1281 write!(
1282 prompt,
1283 "User edited {:?}:\n```diff\n{}\n```",
1284 new_path, diff
1285 )
1286 .unwrap();
1287 }
1288
1289 prompt
1290 }
1291 }
1292 }
1293}
1294
/// A completion together with the buffer it was requested for.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion belongs to.
    buffer_id: EntityId,
    /// The completion itself.
    completion: InlineCompletion,
}
1300
1301impl CurrentInlineCompletion {
1302 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1303 if self.buffer_id != old_completion.buffer_id {
1304 return true;
1305 }
1306
1307 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1308 return true;
1309 };
1310 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1311 return false;
1312 };
1313
1314 if old_edits.len() == 1 && new_edits.len() == 1 {
1315 let (old_range, old_text) = &old_edits[0];
1316 let (new_range, new_text) = &new_edits[0];
1317 new_range == old_range && new_text.starts_with(old_text)
1318 } else {
1319 true
1320 }
1321 }
1322}
1323
/// An in-flight completion request.
struct PendingCompletion {
    /// Identifier used to tell pending requests apart.
    id: usize,
    /// Task performing the request, held while the request is pending.
    _task: Task<()>,
}
1328
/// The user's answer to whether data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection is turned on.
    Enabled,
    /// Data collection is turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`Self::Enabled`]; an unanswered choice counts as not enabled.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` when the choice is either enabled or disabled.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the flipped choice. An unanswered choice flips to [`Self::Enabled`].
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1368
/// Data-collection state attached to an edit prediction provider.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License watcher of the buffer's worktree; `None` whenever `choice` is `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1374
1375impl ProviderDataCollection {
1376 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1377 let choice_and_watcher = buffer.and_then(|buffer| {
1378 let file = buffer.read(cx).file()?;
1379
1380 if !file.is_local() || file.is_private() {
1381 return None;
1382 }
1383
1384 let zeta = zeta.read(cx);
1385 let choice = zeta.data_collection_choice.clone();
1386
1387 let license_detection_watcher = zeta
1388 .license_detection_watchers
1389 .get(&file.worktree_id(cx))
1390 .cloned()?;
1391
1392 Some((choice, license_detection_watcher))
1393 });
1394
1395 if let Some((choice, watcher)) = choice_and_watcher {
1396 ProviderDataCollection {
1397 choice: Some(choice),
1398 license_detection_watcher: Some(watcher),
1399 }
1400 } else {
1401 ProviderDataCollection {
1402 choice: None,
1403 license_detection_watcher: None,
1404 }
1405 }
1406 }
1407
1408 pub fn can_collect_data(&self, cx: &App) -> bool {
1409 self.is_data_collection_enabled(cx) && self.is_project_open_source()
1410 }
1411
1412 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1413 self.choice
1414 .as_ref()
1415 .is_some_and(|choice| choice.read(cx).is_enabled())
1416 }
1417
1418 fn is_project_open_source(&self) -> bool {
1419 self.license_detection_watcher
1420 .as_ref()
1421 .is_some_and(|watcher| watcher.is_project_open_source())
1422 }
1423
1424 pub fn toggle(&mut self, cx: &mut App) {
1425 if let Some(choice) = self.choice.as_mut() {
1426 let new_choice = choice.update(cx, |choice, _cx| {
1427 let new_choice = choice.toggle();
1428 *choice = new_choice;
1429 new_choice
1430 });
1431
1432 db::write_and_log(cx, move || {
1433 KEY_VALUE_STORE.write_kvp(
1434 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1435 new_choice.is_enabled().to_string(),
1436 )
1437 });
1438 }
1439 }
1440}
1441
1442async fn llm_token_retry(
1443 llm_token: &LlmApiToken,
1444 client: &Arc<Client>,
1445 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1446) -> Result<Response<AsyncBody>> {
1447 let mut did_retry = false;
1448 let http_client = client.http_client();
1449 let mut token = llm_token.acquire(client).await?;
1450 loop {
1451 let request = build_request(token.clone())?;
1452 let response = http_client.send(request).await?;
1453
1454 if !did_retry
1455 && !response.status().is_success()
1456 && response
1457 .headers()
1458 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1459 .is_some()
1460 {
1461 did_retry = true;
1462 token = llm_token.refresh(client).await?;
1463 continue;
1464 }
1465
1466 return Ok(response);
1467 }
1468}
1469
/// Edit prediction provider backed by a `Zeta` entity.
pub struct ZetaInlineCompletionProvider {
    /// Entity that performs the completion requests.
    zeta: Entity<Zeta>,
    /// In-flight requests; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next pending request.
    next_pending_completion_id: usize,
    /// Completion currently on offer, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection state for this provider. Its inner `choice` is `None` when data
    /// collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// Time of the most recent request, used for throttling.
    last_request_timestamp: Instant,
}
1479
impl ZetaInlineCompletionProvider {
    /// Minimum spacing enforced between two consecutive completion requests.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current completion.
    ///
    /// The throttle clock starts at construction time.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1494
impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider {
    /// Identifier of this provider.
    fn name() -> &'static str {
        "zed-predict"
    }

    /// Human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled, together with the open-source status of the
    /// project.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Flips the user's data-collection choice.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    /// Forwards to the usage information held by `zeta`.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// Always enabled, regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }

    /// Terms acceptance is needed until `zeta` records the terms of service as accepted.
    fn needs_terms_acceptance(&self, cx: &App) -> bool {
        !self.zeta.read(cx).tos_accepted
    }

    /// A refresh is in progress while at least one request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Requests a new completion for `buffer` at `position`.
    ///
    /// Nothing is requested when the terms of service are not accepted, when an update is
    /// required, when the account is too young or has overdue invoices, or when the current
    /// completion can still be interpolated against the buffer. Requests are spaced at least
    /// `THROTTLE_TIMEOUT` apart.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if !self.zeta.read(cx).tos_accepted {
            return;
        }

        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current completion for as long as it still applies to the buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out whatever remains of the timeout since the previous request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(
                        project.as_ref(),
                        &buffer,
                        position,
                        can_collect_data,
                        cx,
                    )
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentInlineCompletion {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // A failed request is logged; a failure or an empty result only updates the
            // pending-request bookkeeping.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this request is the oldest pending one, remove just it; otherwise
                    // drop every pending request.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same bookkeeping as in the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Adopt the new completion unless an existing one should be kept.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(&old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: inline_completion::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports acceptance of the current completion (if any) to `zeta` and drops all pending
    /// requests.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current completion.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the part of the current completion to offer at `cursor_position`.
    ///
    /// The current completion is dropped, and `None` returned, when it belongs to a different
    /// buffer or can no longer be interpolated against `buffer`. Otherwise the edit whose range
    /// is closest to the cursor row is selected, together with the neighbouring edits that lie
    /// at most one row away from that closest edit.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<inline_completion::InlineCompletion> {
        let CurrentInlineCompletion {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        // Find the edit whose start or end row is nearest to the cursor row.
        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend the selection backwards over edits ending within one row of the closest edit.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend the selection forwards over edits starting within one row of the closest edit.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(inline_completion::InlineCompletion {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1778
/// Estimates the number of model tokens that `bytes` bytes of text amount to.
///
/// The estimate divides by a fixed bytes-per-token guess and rounds down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Guessed number of string bytes per token, used when limiting model input. The value is
    /// deliberately low so that limits are underestimated rather than overshot.
    const ASSUMED_BYTES_PER_TOKEN: usize = 3;

    let estimated_tokens = bytes / ASSUMED_BYTES_PER_TOKEN;
    estimated_tokens
}
1785
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Checks that `InlineCompletion::interpolate` tracks buffer edits: typing text that matches
    /// the predicted edits shrinks them, undoing restores them, and a conflicting edit makes
    /// interpolation return `None`.
    #[gpui::test]
    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = InlineCompletion {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: InlineCompletionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the edits come back unchanged.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the text the first edit would replace turns it into an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo brings back the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing the predicted text piece by piece consumes the first edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Removing the last typed character brings the remaining insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion consumes the second edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that does not match the prediction invalidates the completion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Checks the edits derived from a model response for two small buffers.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        // Two separate insertions of "_1" are expected.
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        // The string continuation and the trailing semicolon are expected as separate edits.
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// Checks that a predicted line appended after the buffer's last line is applied correctly.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request is answered with the canned prediction above.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                            .unwrap(),
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // Serve the RPC traffic the request triggers: a user lookup, then the LLM token.
        server.receive::<proto::GetUsers>().await.unwrap();
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Requests a completion for `buffer_content` from a fake server that always answers with
    /// `completion_response`, and returns the resulting edits as point ranges.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |_| {
            let completion = completion_response.clone();
            async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::new_v4(),
                            output_excerpt: completion,
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap())
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // Serve the RPC traffic the request triggers: a user lookup, then the LLM token.
        server.receive::<proto::GetUsers>().await.unwrap();
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .into_iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset-based edits for `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Initializes test logging through the `ctor` attribute.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}