1mod completion_diff_element;
2mod init;
3mod input_excerpt;
4mod license_detection;
5mod onboarding_modal;
6mod onboarding_telemetry;
7mod rate_completion_modal;
8
9pub(crate) use completion_diff_element::*;
10use db::kvp::{Dismissable, KEY_VALUE_STORE};
11pub use init::*;
12use inline_completion::DataCollectionState;
13use license_detection::LICENSE_FILES_TO_CHECK;
14pub use license_detection::is_license_eligible_for_data_collection;
15pub use rate_completion_modal::*;
16
17use anyhow::{Context as _, Result, anyhow};
18use arrayvec::ArrayVec;
19use client::{Client, EditPredictionUsage, UserStore};
20use collections::{HashMap, HashSet, VecDeque};
21use futures::AsyncReadExt;
22use gpui::{
23 App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
24 Subscription, Task, WeakEntity, actions,
25};
26use http_client::{AsyncBody, HttpClient, Method, Request, Response};
27use input_excerpt::excerpt_for_cursor_position;
28use language::{
29 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
30};
31use language_model::{LlmApiToken, RefreshLlmTokenListener};
32use postage::watch;
33use project::Project;
34use release_channel::AppVersion;
35use settings::WorktreeId;
36use std::str::FromStr;
37use std::{
38 borrow::Cow,
39 cmp,
40 fmt::Write,
41 future::Future,
42 mem,
43 ops::Range,
44 path::Path,
45 rc::Rc,
46 sync::Arc,
47 time::{Duration, Instant},
48};
49use telemetry_events::InlineCompletionRating;
50use thiserror::Error;
51use util::ResultExt;
52use uuid::Uuid;
53use workspace::Workspace;
54use workspace::notifications::{ErrorMessagePrompt, NotificationId};
55use worktree::Worktree;
56use zed_llm_client::{
57 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
58 PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
59};
60
/// Marker for the user's cursor position. It is stripped from the model output
/// before the predicted edits are parsed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker for the start of the file. A model response may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker opening the editable region. A model response must contain it exactly once.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker closing the editable region. A model response must contain it exactly once.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that happen within this interval are
/// coalesced into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key under which the data collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

/// Context token budget handed to `excerpt_for_cursor_position`.
const MAX_CONTEXT_TOKENS: usize = 150;
/// Rewrite token budget handed to `excerpt_for_cursor_position`.
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget handed to `prompt_for_events`.
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
74
// Declares the `ClearHistory` action through the `actions!` macro, grouped
// under `edit_prediction`.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
82
83#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
84pub struct InlineCompletionId(Uuid);
85
86impl From<InlineCompletionId> for gpui::ElementId {
87 fn from(value: InlineCompletionId) -> Self {
88 gpui::ElementId::Uuid(value.0)
89 }
90}
91
92impl std::fmt::Display for InlineCompletionId {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}", self.0)
95 }
96}
97
98struct ZedPredictUpsell;
99
100impl Dismissable for ZedPredictUpsell {
101 const KEY: &'static str = "dismissed-edit-predict-upsell";
102
103 fn dismissed() -> bool {
104 // To make this backwards compatible with older versions of Zed, we
105 // check if the user has seen the previous Edit Prediction Onboarding
106 // before, by checking the data collection choice which was written to
107 // the database once the user clicked on "Accept and Enable"
108 if KEY_VALUE_STORE
109 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
110 .log_err()
111 .map_or(false, |s| s.is_some())
112 {
113 return true;
114 }
115
116 KEY_VALUE_STORE
117 .read_kvp(Self::KEY)
118 .log_err()
119 .map_or(false, |s| s.is_some())
120 }
121}
122
123pub fn should_show_upsell_modal(user_store: &Entity<UserStore>, cx: &App) -> bool {
124 match user_store.read(cx).current_user_has_accepted_terms() {
125 Some(true) => !ZedPredictUpsell::dismissed(),
126 Some(false) | None => true,
127 }
128}
129
/// Newtype holding the shared [`Zeta`] entity so it can be stored as a gpui
/// global (see `Zeta::global` and `Zeta::register`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
134
135#[derive(Clone)]
136pub struct InlineCompletion {
137 id: InlineCompletionId,
138 path: Arc<Path>,
139 excerpt_range: Range<usize>,
140 cursor_offset: usize,
141 edits: Arc<[(Range<Anchor>, String)]>,
142 snapshot: BufferSnapshot,
143 edit_preview: EditPreview,
144 input_outline: Arc<str>,
145 input_events: Arc<str>,
146 input_excerpt: Arc<str>,
147 output_excerpt: Arc<str>,
148 request_sent_at: Instant,
149 response_received_at: Instant,
150}
151
152impl InlineCompletion {
153 fn latency(&self) -> Duration {
154 self.response_received_at
155 .duration_since(self.request_sent_at)
156 }
157
158 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
159 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
160 }
161}
162
163fn interpolate(
164 old_snapshot: &BufferSnapshot,
165 new_snapshot: &BufferSnapshot,
166 current_edits: Arc<[(Range<Anchor>, String)]>,
167) -> Option<Vec<(Range<Anchor>, String)>> {
168 let mut edits = Vec::new();
169
170 let mut model_edits = current_edits.into_iter().peekable();
171 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
172 while let Some((model_old_range, _)) = model_edits.peek() {
173 let model_old_range = model_old_range.to_offset(old_snapshot);
174 if model_old_range.end < user_edit.old.start {
175 let (model_old_range, model_new_text) = model_edits.next().unwrap();
176 edits.push((model_old_range.clone(), model_new_text.clone()));
177 } else {
178 break;
179 }
180 }
181
182 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
183 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
184 if user_edit.old == model_old_offset_range {
185 let user_new_text = new_snapshot
186 .text_for_range(user_edit.new.clone())
187 .collect::<String>();
188
189 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
190 if !model_suffix.is_empty() {
191 let anchor = old_snapshot.anchor_after(user_edit.old.end);
192 edits.push((anchor..anchor, model_suffix.to_string()));
193 }
194
195 model_edits.next();
196 continue;
197 }
198 }
199 }
200
201 return None;
202 }
203
204 edits.extend(model_edits.cloned());
205
206 if edits.is_empty() { None } else { Some(edits) }
207}
208
209impl std::fmt::Debug for InlineCompletion {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("InlineCompletion")
212 .field("id", &self.id)
213 .field("path", &self.path)
214 .field("edits", &self.edits)
215 .finish_non_exhaustive()
216 }
217}
218
/// State backing edit predictions: buffer change history, shown and rated
/// completions, the LLM token, and per-worktree license watchers.
pub struct Zeta {
    /// Workspace used to surface the "update required" notification, if any.
    workspace: Option<WeakEntity<Workspace>>,
    /// Client used for HTTP requests and telemetry.
    client: Arc<Client>,
    /// Recent buffer change events, rendered into the prediction prompt.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions that were shown, most recent first.
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of shown completions that have been rated.
    rated_completions: HashSet<InlineCompletionId>,
    /// Data collection choice, loaded from the key-value store at creation.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token used to authorize requests against the LLM service.
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the refresh listener emits an event.
    _llm_token_subscription: Subscription,
    /// Whether the terms of service have been accepted.
    tos_accepted: bool,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Source of the terms-acceptance state and sink for usage updates.
    user_store: Entity<UserStore>,
    /// Keeps `tos_accepted` in sync with private user info updates.
    _user_store_subscription: Subscription,
    /// One license watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
237
238impl Zeta {
239 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
240 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
241 }
242
243 pub fn register(
244 workspace: Option<WeakEntity<Workspace>>,
245 worktree: Option<Entity<Worktree>>,
246 client: Arc<Client>,
247 user_store: Entity<UserStore>,
248 cx: &mut App,
249 ) -> Entity<Self> {
250 let this = Self::global(cx).unwrap_or_else(|| {
251 let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
252 cx.set_global(ZetaGlobal(entity.clone()));
253 entity
254 });
255
256 this.update(cx, move |this, cx| {
257 if let Some(worktree) = worktree {
258 worktree.update(cx, |worktree, cx| {
259 this.license_detection_watchers
260 .entry(worktree.id())
261 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
262 });
263 }
264 });
265
266 this
267 }
268
269 pub fn clear_history(&mut self) {
270 self.events.clear();
271 }
272
273 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
274 self.user_store.read(cx).edit_prediction_usage()
275 }
276
277 fn new(
278 workspace: Option<WeakEntity<Workspace>>,
279 client: Arc<Client>,
280 user_store: Entity<UserStore>,
281 cx: &mut Context<Self>,
282 ) -> Self {
283 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
284
285 let data_collection_choice = Self::load_data_collection_choices();
286 let data_collection_choice = cx.new(|_| data_collection_choice);
287
288 Self {
289 workspace,
290 client,
291 events: VecDeque::new(),
292 shown_completions: VecDeque::new(),
293 rated_completions: HashSet::default(),
294 registered_buffers: HashMap::default(),
295 data_collection_choice,
296 llm_token: LlmApiToken::default(),
297 _llm_token_subscription: cx.subscribe(
298 &refresh_llm_token_listener,
299 |this, _listener, _event, cx| {
300 let client = this.client.clone();
301 let llm_token = this.llm_token.clone();
302 cx.spawn(async move |_this, _cx| {
303 llm_token.refresh(&client).await?;
304 anyhow::Ok(())
305 })
306 .detach_and_log_err(cx);
307 },
308 ),
309 tos_accepted: user_store
310 .read(cx)
311 .current_user_has_accepted_terms()
312 .unwrap_or(false),
313 update_required: false,
314 _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
315 match event {
316 client::user::Event::PrivateUserInfoUpdated => {
317 this.tos_accepted = user_store
318 .read(cx)
319 .current_user_has_accepted_terms()
320 .unwrap_or(false);
321 }
322 _ => {}
323 }
324 }),
325 license_detection_watchers: HashMap::default(),
326 user_store,
327 }
328 }
329
330 fn push_event(&mut self, event: Event) {
331 if let Some(Event::BufferChange {
332 new_snapshot: last_new_snapshot,
333 timestamp: last_timestamp,
334 ..
335 }) = self.events.back_mut()
336 {
337 // Coalesce edits for the same buffer when they happen one after the other.
338 let Event::BufferChange {
339 old_snapshot,
340 new_snapshot,
341 timestamp,
342 } = &event;
343
344 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
345 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
346 && old_snapshot.version == last_new_snapshot.version
347 {
348 *last_new_snapshot = new_snapshot.clone();
349 *last_timestamp = *timestamp;
350 return;
351 }
352 }
353
354 self.events.push_back(event);
355 if self.events.len() >= MAX_EVENT_COUNT {
356 // These are halved instead of popping to improve prompt caching.
357 self.events.drain(..MAX_EVENT_COUNT / 2);
358 }
359 }
360
361 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
362 let buffer_id = buffer.entity_id();
363 let weak_buffer = buffer.downgrade();
364
365 if let std::collections::hash_map::Entry::Vacant(entry) =
366 self.registered_buffers.entry(buffer_id)
367 {
368 let snapshot = buffer.read(cx).snapshot();
369
370 entry.insert(RegisteredBuffer {
371 snapshot,
372 _subscriptions: [
373 cx.subscribe(buffer, move |this, buffer, event, cx| {
374 this.handle_buffer_event(buffer, event, cx);
375 }),
376 cx.observe_release(buffer, move |this, _buffer, _cx| {
377 this.registered_buffers.remove(&weak_buffer.entity_id());
378 }),
379 ],
380 });
381 };
382 }
383
384 fn handle_buffer_event(
385 &mut self,
386 buffer: Entity<Buffer>,
387 event: &language::BufferEvent,
388 cx: &mut Context<Self>,
389 ) {
390 if let language::BufferEvent::Edited = event {
391 self.report_changes_for_buffer(&buffer, cx);
392 }
393 }
394
    /// Shared implementation of a prediction request.
    ///
    /// Records any pending changes for `buffer`, builds the prompt inputs
    /// (events, excerpt around `cursor`, outline) on a background task, sends
    /// them through `perform_predict_edits`, and turns the response into an
    /// [`InlineCompletion`].
    ///
    /// `perform_predict_edits` is injected so callers can supply either the
    /// real network request or a canned response. If it fails with a
    /// `ZedUpdateRequiredError`, `update_required` is set and, when
    /// `workspace` is present, a notification linking to the releases page is
    /// shown before the error is returned.
    ///
    /// Resolves to `Ok(None)` when the response yields no applicable edits.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<InlineCompletion>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        // Capture everything that needs `self`/`cx` up front; the rest runs in
        // the spawned task.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let diagnostic_groups = snapshot.diagnostic_groups(None);
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let events = self.events.clone();
        // Buffers without a file are reported under the path "untitled".
        let path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));

        let zeta = cx.entity();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        let buffer = buffer.clone();

        // Diagnostics are only attached when the project has a local LSP
        // store, and only for language servers that are currently running.
        let local_lsp_store =
            project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
        let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
            Some(
                diagnostic_groups
                    .into_iter()
                    .filter_map(|(language_server_id, diagnostic_group)| {
                        let language_server =
                            local_lsp_store.running_language_server_for_id(language_server_id)?;

                        Some((
                            language_server.name(),
                            diagnostic_group.resolve::<usize>(&snapshot),
                        ))
                    })
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };

        cx.spawn(async move |this, cx| {
            let request_sent_at = Instant::now();

            // Prompt inputs computed off the main thread.
            struct BackgroundValues {
                input_events: String,
                input_excerpt: String,
                speculated_output: String,
                editable_range: Range<usize>,
                input_outline: String,
            }

            let values = cx
                .background_spawn({
                    let snapshot = snapshot.clone();
                    let path = path.clone();
                    async move {
                        let path = path.to_string_lossy();
                        let input_excerpt = excerpt_for_cursor_position(
                            cursor_point,
                            &path,
                            &snapshot,
                            MAX_REWRITE_TOKENS,
                            MAX_CONTEXT_TOKENS,
                        );
                        let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
                        let input_outline = prompt_for_outline(&snapshot);

                        anyhow::Ok(BackgroundValues {
                            input_events,
                            input_excerpt: input_excerpt.prompt,
                            speculated_output: input_excerpt.speculated_output,
                            editable_range: input_excerpt.editable_range.to_offset(&snapshot),
                            input_outline,
                        })
                    }
                })
                .await?;

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                values.input_events,
                values.input_excerpt
            );

            let body = PredictEditsBody {
                input_events: values.input_events.clone(),
                input_excerpt: values.input_excerpt.clone(),
                speculated_output: Some(values.speculated_output),
                outline: Some(values.input_outline.clone()),
                can_collect_data,
                // If any diagnostic group fails to serialize, the error is
                // logged and no diagnostics are sent at all.
                diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
                    diagnostic_groups
                        .into_iter()
                        .map(|(name, diagnostic_group)| {
                            Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
                        })
                        .collect::<Result<Vec<_>>>()
                        .log_err()
                }),
            };

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // An outdated client is recorded on the entity and
                    // surfaced to the user; any other error just propagates.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage reported alongside the response is forwarded to the user
            // store; a failed update is ignored.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                values.editable_range,
                cursor_offset,
                path,
                values.input_outline,
                values.input_events,
                values.input_excerpt,
                request_sent_at,
                &cx,
            )
            .await
        })
    }
578
    /// Generates several example completions of various states to fill the Zeta completion modal.
    ///
    /// Seven canned responses are requested against a throwaway local buffer.
    /// Once all of them have resolved, the edits of the shown completions at
    /// indices 2 and 3 are emptied.
    ///
    /// The returned task panics if a fake request fails or if fewer than four
    /// shown completions exist at that point.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 And maybe a short line

 Then a few lines

 and then another
 "#};

        // No project is involved, so no diagnostics are attached to the requests.
        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
 "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await the tasks in order; any failed fake request panics here.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Blank out the edits of two shown completions so that the list
            // contains entries without edits as well.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
731
732 #[cfg(any(test, feature = "test-support"))]
733 pub fn fake_completion(
734 &mut self,
735 project: Option<&Entity<Project>>,
736 buffer: &Entity<Buffer>,
737 position: language::Anchor,
738 response: PredictEditsResponse,
739 cx: &mut Context<Self>,
740 ) -> Task<Result<Option<InlineCompletion>>> {
741 use std::future::ready;
742
743 self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
744 ready(Ok((response, None)))
745 })
746 }
747
748 pub fn request_completion(
749 &mut self,
750 project: Option<&Entity<Project>>,
751 buffer: &Entity<Buffer>,
752 position: language::Anchor,
753 can_collect_data: bool,
754 cx: &mut Context<Self>,
755 ) -> Task<Result<Option<InlineCompletion>>> {
756 let workspace = self
757 .workspace
758 .as_ref()
759 .and_then(|workspace| workspace.upgrade());
760 self.request_completion_impl(
761 workspace,
762 project,
763 buffer,
764 position,
765 can_collect_data,
766 cx,
767 Self::perform_predict_edits,
768 )
769 }
770
771 fn perform_predict_edits(
772 params: PerformPredictEditsParams,
773 ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
774 async move {
775 let PerformPredictEditsParams {
776 client,
777 llm_token,
778 app_version,
779 body,
780 ..
781 } = params;
782
783 let http_client = client.http_client();
784 let mut token = llm_token.acquire(&client).await?;
785 let mut did_retry = false;
786
787 loop {
788 let request_builder = http_client::Request::builder().method(Method::POST);
789 let request_builder =
790 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
791 request_builder.uri(predict_edits_url)
792 } else {
793 request_builder.uri(
794 http_client
795 .build_zed_llm_url("/predict_edits/v2", &[])?
796 .as_ref(),
797 )
798 };
799 let request = request_builder
800 .header("Content-Type", "application/json")
801 .header("Authorization", format!("Bearer {}", token))
802 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
803 .body(serde_json::to_string(&body)?.into())?;
804
805 let mut response = http_client.send(request).await?;
806
807 if let Some(minimum_required_version) = response
808 .headers()
809 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
810 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
811 {
812 anyhow::ensure!(
813 app_version >= minimum_required_version,
814 ZedUpdateRequiredError {
815 minimum_version: minimum_required_version
816 }
817 );
818 }
819
820 if response.status().is_success() {
821 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
822
823 let mut body = String::new();
824 response.body_mut().read_to_string(&mut body).await?;
825 return Ok((serde_json::from_str(&body)?, usage));
826 } else if !did_retry
827 && response
828 .headers()
829 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
830 .is_some()
831 {
832 did_retry = true;
833 token = llm_token.refresh(&client).await?;
834 } else {
835 let mut body = String::new();
836 response.body_mut().read_to_string(&mut body).await?;
837 anyhow::bail!(
838 "error predicting edits.\nStatus: {:?}\nBody: {}",
839 response.status(),
840 body
841 );
842 }
843 }
844 }
845 }
846
847 fn accept_edit_prediction(
848 &mut self,
849 request_id: InlineCompletionId,
850 cx: &mut Context<Self>,
851 ) -> Task<Result<()>> {
852 let client = self.client.clone();
853 let llm_token = self.llm_token.clone();
854 let app_version = AppVersion::global(cx);
855 cx.spawn(async move |this, cx| {
856 let http_client = client.http_client();
857 let mut response = llm_token_retry(&llm_token, &client, |token| {
858 let request_builder = http_client::Request::builder().method(Method::POST);
859 let request_builder =
860 if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
861 request_builder.uri(accept_prediction_url)
862 } else {
863 request_builder.uri(
864 http_client
865 .build_zed_llm_url("/predict_edits/accept", &[])?
866 .as_ref(),
867 )
868 };
869 Ok(request_builder
870 .header("Content-Type", "application/json")
871 .header("Authorization", format!("Bearer {}", token))
872 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
873 .body(
874 serde_json::to_string(&AcceptEditPredictionBody {
875 request_id: request_id.0,
876 })?
877 .into(),
878 )?)
879 })
880 .await?;
881
882 if let Some(minimum_required_version) = response
883 .headers()
884 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
885 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
886 {
887 if app_version < minimum_required_version {
888 return Err(anyhow!(ZedUpdateRequiredError {
889 minimum_version: minimum_required_version
890 }));
891 }
892 }
893
894 if response.status().is_success() {
895 if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
896 this.update(cx, |this, cx| {
897 this.user_store.update(cx, |user_store, cx| {
898 user_store.update_edit_prediction_usage(usage, cx);
899 });
900 })?;
901 }
902
903 Ok(())
904 } else {
905 let mut body = String::new();
906 response.body_mut().read_to_string(&mut body).await?;
907 Err(anyhow!(
908 "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
909 response.status(),
910 body
911 ))
912 }
913 })
914 }
915
916 fn process_completion_response(
917 prediction_response: PredictEditsResponse,
918 buffer: Entity<Buffer>,
919 snapshot: &BufferSnapshot,
920 editable_range: Range<usize>,
921 cursor_offset: usize,
922 path: Arc<Path>,
923 input_outline: String,
924 input_events: String,
925 input_excerpt: String,
926 request_sent_at: Instant,
927 cx: &AsyncApp,
928 ) -> Task<Result<Option<InlineCompletion>>> {
929 let snapshot = snapshot.clone();
930 let request_id = prediction_response.request_id;
931 let output_excerpt = prediction_response.output_excerpt;
932 cx.spawn(async move |cx| {
933 let output_excerpt: Arc<str> = output_excerpt.into();
934
935 let edits: Arc<[(Range<Anchor>, String)]> = cx
936 .background_spawn({
937 let output_excerpt = output_excerpt.clone();
938 let editable_range = editable_range.clone();
939 let snapshot = snapshot.clone();
940 async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
941 })
942 .await?
943 .into();
944
945 let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
946 let edits = edits.clone();
947 |buffer, cx| {
948 let new_snapshot = buffer.snapshot();
949 let edits: Arc<[(Range<Anchor>, String)]> =
950 interpolate(&snapshot, &new_snapshot, edits)?.into();
951 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
952 }
953 })?
954 else {
955 return anyhow::Ok(None);
956 };
957
958 let edit_preview = edit_preview.await;
959
960 Ok(Some(InlineCompletion {
961 id: InlineCompletionId(request_id),
962 path,
963 excerpt_range: editable_range,
964 cursor_offset,
965 edits,
966 edit_preview,
967 snapshot,
968 input_outline: input_outline.into(),
969 input_events: input_events.into(),
970 input_excerpt: input_excerpt.into(),
971 output_excerpt,
972 request_sent_at,
973 response_received_at: Instant::now(),
974 }))
975 })
976 }
977
978 fn parse_edits(
979 output_excerpt: Arc<str>,
980 editable_range: Range<usize>,
981 snapshot: &BufferSnapshot,
982 ) -> Result<Vec<(Range<Anchor>, String)>> {
983 let content = output_excerpt.replace(CURSOR_MARKER, "");
984
985 let start_markers = content
986 .match_indices(EDITABLE_REGION_START_MARKER)
987 .collect::<Vec<_>>();
988 anyhow::ensure!(
989 start_markers.len() == 1,
990 "expected exactly one start marker, found {}",
991 start_markers.len()
992 );
993
994 let end_markers = content
995 .match_indices(EDITABLE_REGION_END_MARKER)
996 .collect::<Vec<_>>();
997 anyhow::ensure!(
998 end_markers.len() == 1,
999 "expected exactly one end marker, found {}",
1000 end_markers.len()
1001 );
1002
1003 let sof_markers = content
1004 .match_indices(START_OF_FILE_MARKER)
1005 .collect::<Vec<_>>();
1006 anyhow::ensure!(
1007 sof_markers.len() <= 1,
1008 "expected at most one start-of-file marker, found {}",
1009 sof_markers.len()
1010 );
1011
1012 let codefence_start = start_markers[0].0;
1013 let content = &content[codefence_start..];
1014
1015 let newline_ix = content.find('\n').context("could not find newline")?;
1016 let content = &content[newline_ix + 1..];
1017
1018 let codefence_end = content
1019 .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1020 .context("could not find end marker")?;
1021 let new_text = &content[..codefence_end];
1022
1023 let old_text = snapshot
1024 .text_for_range(editable_range.clone())
1025 .collect::<String>();
1026
1027 Ok(Self::compute_edits(
1028 old_text,
1029 new_text,
1030 editable_range.start,
1031 &snapshot,
1032 ))
1033 }
1034
1035 pub fn compute_edits(
1036 old_text: String,
1037 new_text: &str,
1038 offset: usize,
1039 snapshot: &BufferSnapshot,
1040 ) -> Vec<(Range<Anchor>, String)> {
1041 text_diff(&old_text, &new_text)
1042 .into_iter()
1043 .map(|(mut old_range, new_text)| {
1044 old_range.start += offset;
1045 old_range.end += offset;
1046
1047 let prefix_len = common_prefix(
1048 snapshot.chars_for_range(old_range.clone()),
1049 new_text.chars(),
1050 );
1051 old_range.start += prefix_len;
1052
1053 let suffix_len = common_prefix(
1054 snapshot.reversed_chars_for_range(old_range.clone()),
1055 new_text[prefix_len..].chars().rev(),
1056 );
1057 old_range.end = old_range.end.saturating_sub(suffix_len);
1058
1059 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1060 let range = if old_range.is_empty() {
1061 let anchor = snapshot.anchor_after(old_range.start);
1062 anchor..anchor
1063 } else {
1064 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1065 };
1066 (range, new_text)
1067 })
1068 .collect()
1069 }
1070
1071 pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
1072 self.rated_completions.contains(&completion_id)
1073 }
1074
1075 pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
1076 self.shown_completions.push_front(completion.clone());
1077 if self.shown_completions.len() > 50 {
1078 let completion = self.shown_completions.pop_back().unwrap();
1079 self.rated_completions.remove(&completion.id);
1080 }
1081 cx.notify();
1082 }
1083
    /// Records a user rating for `completion` and reports it through telemetry.
    ///
    /// The completion's id is added to the set of rated completions, an
    /// "Edit Prediction Rated" event is emitted carrying the rating, the completion's
    /// input events/excerpt/outline, its output excerpt and the free-form `feedback`,
    /// and telemetry events are then flushed in a detached task before observers are
    /// notified.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush right away instead of waiting for the next regular flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1104
    /// Iterates over the retained shown completions, front to back.
    ///
    /// `completion_shown` pushes new entries at the front, so the most recently shown
    /// completion comes first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
        self.shown_completions.iter()
    }
1108
    /// Returns the number of shown completions currently retained.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1112
1113 fn report_changes_for_buffer(
1114 &mut self,
1115 buffer: &Entity<Buffer>,
1116 cx: &mut Context<Self>,
1117 ) -> BufferSnapshot {
1118 self.register_buffer(buffer, cx);
1119
1120 let registered_buffer = self
1121 .registered_buffers
1122 .get_mut(&buffer.entity_id())
1123 .unwrap();
1124 let new_snapshot = buffer.read(cx).snapshot();
1125
1126 if new_snapshot.version != registered_buffer.snapshot.version {
1127 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1128 self.push_event(Event::BufferChange {
1129 old_snapshot,
1130 new_snapshot: new_snapshot.clone(),
1131 timestamp: Instant::now(),
1132 });
1133 }
1134
1135 new_snapshot
1136 }
1137
1138 fn load_data_collection_choices() -> DataCollectionChoice {
1139 let choice = KEY_VALUE_STORE
1140 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1141 .log_err()
1142 .flatten();
1143
1144 match choice.as_deref() {
1145 Some("true") => DataCollectionChoice::Enabled,
1146 Some("false") => DataCollectionChoice::Disabled,
1147 Some(_) => {
1148 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1149 DataCollectionChoice::NotAnswered
1150 }
1151 None => DataCollectionChoice::NotAnswered,
1152 }
1153 }
1154}
1155
/// Bundle of values needed for a predict-edits request: the client, the LLM API token,
/// the application version and the request body.
///
/// NOTE(review): the consumer of this struct is outside this section; presumably it is the
/// function that performs the predict-edits HTTP request.
struct PerformPredictEditsParams {
    pub client: Arc<Client>,
    pub llm_token: LlmApiToken,
    pub app_version: SemanticVersion,
    pub body: PredictEditsBody,
}
1162
/// Error whose message tells the user they must update Zed to at least `minimum_version`
/// to keep using edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Version interpolated into the error message.
    minimum_version: SemanticVersion,
}
1170
/// Tracks whether a worktree was detected as open source based on its license file.
struct LicenseDetectionWatcher {
    /// Receiving side of the watch channel; starts as `false` and is set to `true` by the
    /// detection task when an eligible license is found.
    is_open_source_rx: watch::Receiver<bool>,
    /// The detection task, stored so it is owned by the watcher.
    _is_open_source_task: Task<()>,
}
1175
1176impl LicenseDetectionWatcher {
1177 pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
1178 let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
1179
1180 // Check if worktree is a single file, if so we do not need to check for a LICENSE file
1181 let task = if worktree.abs_path().is_file() {
1182 Task::ready(())
1183 } else {
1184 let loaded_files = LICENSE_FILES_TO_CHECK
1185 .iter()
1186 .map(Path::new)
1187 .map(|file| worktree.load_file(file, cx))
1188 .collect::<ArrayVec<_, { LICENSE_FILES_TO_CHECK.len() }>>();
1189
1190 cx.background_spawn(async move {
1191 for loaded_file in loaded_files.into_iter() {
1192 let Ok(loaded_file) = loaded_file.await else {
1193 continue;
1194 };
1195
1196 let path = &loaded_file.file.path;
1197 if is_license_eligible_for_data_collection(&loaded_file.text) {
1198 log::info!("detected '{path:?}' as open source license");
1199 *is_open_source_tx.borrow_mut() = true;
1200 } else {
1201 log::info!("didn't detect '{path:?}' as open source license");
1202 }
1203
1204 // stop on the first license that successfully read
1205 return;
1206 }
1207
1208 log::debug!("didn't find a license file to check, assuming closed source");
1209 })
1210 };
1211
1212 Self {
1213 is_open_source_rx,
1214 _is_open_source_task: task,
1215 }
1216 }
1217
1218 /// Answers false until we find out it's open source
1219 pub fn is_project_open_source(&self) -> bool {
1220 *self.is_open_source_rx.borrow()
1221 }
1222}
1223
/// Returns the UTF-8 byte length of the longest common prefix of the two character
/// sequences `a` and `b`.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1230
1231fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
1232 let mut input_outline = String::new();
1233
1234 writeln!(
1235 input_outline,
1236 "```{}",
1237 snapshot
1238 .file()
1239 .map_or(Cow::Borrowed("untitled"), |file| file
1240 .path()
1241 .to_string_lossy())
1242 )
1243 .unwrap();
1244
1245 if let Some(outline) = snapshot.outline(None) {
1246 for item in &outline.items {
1247 let spacing = " ".repeat(item.depth);
1248 writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1249 }
1250 }
1251
1252 writeln!(input_outline, "```").unwrap();
1253
1254 input_outline
1255}
1256
1257fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1258 let mut result = String::new();
1259 for event in events.iter().rev() {
1260 let event_string = event.to_prompt();
1261 let event_tokens = tokens_for_bytes(event_string.len());
1262 if event_tokens > remaining_tokens {
1263 break;
1264 }
1265
1266 if !result.is_empty() {
1267 result.insert_str(0, "\n\n");
1268 }
1269 result.insert_str(0, &event_string);
1270 remaining_tokens -= event_tokens;
1271 }
1272 result
1273}
1274
/// Per-buffer state kept for a registered buffer.
struct RegisteredBuffer {
    /// Last snapshot recorded for the buffer; `report_changes_for_buffer` compares its
    /// version against the buffer's current snapshot.
    snapshot: BufferSnapshot,
    /// Subscriptions stored alongside the snapshot so they are owned by this entry.
    _subscriptions: [gpui::Subscription; 2],
}
1279
/// An editing event that can be rendered into the prompt (see `Event::to_prompt`).
#[derive(Clone)]
enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// Time at which the change was recorded.
        timestamp: Instant,
    },
}
1288
1289impl Event {
1290 fn to_prompt(&self) -> String {
1291 match self {
1292 Event::BufferChange {
1293 old_snapshot,
1294 new_snapshot,
1295 ..
1296 } => {
1297 let mut prompt = String::new();
1298
1299 let old_path = old_snapshot
1300 .file()
1301 .map(|f| f.path().as_ref())
1302 .unwrap_or(Path::new("untitled"));
1303 let new_path = new_snapshot
1304 .file()
1305 .map(|f| f.path().as_ref())
1306 .unwrap_or(Path::new("untitled"));
1307 if old_path != new_path {
1308 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1309 }
1310
1311 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1312 if !diff.is_empty() {
1313 write!(
1314 prompt,
1315 "User edited {:?}:\n```diff\n{}\n```",
1316 new_path, diff
1317 )
1318 .unwrap();
1319 }
1320
1321 prompt
1322 }
1323 }
1324 }
1325}
1326
/// A completion together with the id of the buffer entity it was produced for.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion was requested for.
    buffer_id: EntityId,
    completion: InlineCompletion,
}
1332
1333impl CurrentInlineCompletion {
1334 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1335 if self.buffer_id != old_completion.buffer_id {
1336 return true;
1337 }
1338
1339 let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1340 return true;
1341 };
1342 let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1343 return false;
1344 };
1345
1346 if old_edits.len() == 1 && new_edits.len() == 1 {
1347 let (old_range, old_text) = &old_edits[0];
1348 let (new_range, new_text) = &new_edits[0];
1349 new_range == old_range && new_text.starts_with(old_text)
1350 } else {
1351 true
1352 }
1353 }
1354}
1355
/// A completion request that has been started but has not finished yet.
struct PendingCompletion {
    /// Identifier assigned from the provider's `next_pending_completion_id` counter.
    id: usize,
    /// The task performing the request, owned by this entry.
    _task: Task<()>,
}
1360
/// The user's answer to whether data collection is allowed.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` when an explicit choice (enabled or disabled) has been made.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice: `Enabled` becomes `Disabled`, while both `Disabled`
    /// and `NotAnswered` become `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1400
/// Data-collection state attached to an edit prediction provider.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License watcher for the buffer's worktree; `None` whenever `choice` is `None`
    /// (both are set together in `ProviderDataCollection::new`).
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1406
1407impl ProviderDataCollection {
1408 pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1409 let choice_and_watcher = buffer.and_then(|buffer| {
1410 let file = buffer.read(cx).file()?;
1411
1412 if !file.is_local() || file.is_private() {
1413 return None;
1414 }
1415
1416 let zeta = zeta.read(cx);
1417 let choice = zeta.data_collection_choice.clone();
1418
1419 let license_detection_watcher = zeta
1420 .license_detection_watchers
1421 .get(&file.worktree_id(cx))
1422 .cloned()?;
1423
1424 Some((choice, license_detection_watcher))
1425 });
1426
1427 if let Some((choice, watcher)) = choice_and_watcher {
1428 ProviderDataCollection {
1429 choice: Some(choice),
1430 license_detection_watcher: Some(watcher),
1431 }
1432 } else {
1433 ProviderDataCollection {
1434 choice: None,
1435 license_detection_watcher: None,
1436 }
1437 }
1438 }
1439
1440 pub fn can_collect_data(&self, cx: &App) -> bool {
1441 self.is_data_collection_enabled(cx) && self.is_project_open_source()
1442 }
1443
1444 pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1445 self.choice
1446 .as_ref()
1447 .is_some_and(|choice| choice.read(cx).is_enabled())
1448 }
1449
1450 fn is_project_open_source(&self) -> bool {
1451 self.license_detection_watcher
1452 .as_ref()
1453 .is_some_and(|watcher| watcher.is_project_open_source())
1454 }
1455
1456 pub fn toggle(&mut self, cx: &mut App) {
1457 if let Some(choice) = self.choice.as_mut() {
1458 let new_choice = choice.update(cx, |choice, _cx| {
1459 let new_choice = choice.toggle();
1460 *choice = new_choice;
1461 new_choice
1462 });
1463
1464 db::write_and_log(cx, move || {
1465 KEY_VALUE_STORE.write_kvp(
1466 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1467 new_choice.is_enabled().to_string(),
1468 )
1469 });
1470 }
1471 }
1472}
1473
1474async fn llm_token_retry(
1475 llm_token: &LlmApiToken,
1476 client: &Arc<Client>,
1477 build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1478) -> Result<Response<AsyncBody>> {
1479 let mut did_retry = false;
1480 let http_client = client.http_client();
1481 let mut token = llm_token.acquire(client).await?;
1482 loop {
1483 let request = build_request(token.clone())?;
1484 let response = http_client.send(request).await?;
1485
1486 if !did_retry
1487 && !response.status().is_success()
1488 && response
1489 .headers()
1490 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1491 .is_some()
1492 {
1493 did_retry = true;
1494 token = llm_token.refresh(client).await?;
1495 continue;
1496 }
1497
1498 return Ok(response);
1499 }
1500}
1501
/// Edit prediction provider backed by a [`Zeta`] entity.
pub struct ZetaInlineCompletionProvider {
    zeta: Entity<Zeta>,
    /// In-flight completion requests; holds at most two entries.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id that will be assigned to the next pending completion.
    next_pending_completion_id: usize,
    /// The completion currently offered, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection state for this provider. Its inner fields are `None` when data
    /// collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// Time of the last completion request, used to throttle new requests.
    last_request_timestamp: Instant,
}
1511
impl ZetaInlineCompletionProvider {
    /// Minimum delay enforced between the previous request's timestamp and the start of
    /// a new completion request (see `refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current completion; the last-request
    /// timestamp is initialized to the current time.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1526
1527impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider {
    /// Returns the provider's identifier string, `"zed-predict"`.
    fn name() -> &'static str {
        "zed-predict"
    }
1531
    /// Returns the human-readable provider name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }
1535
    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
1539
    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
1543
1544 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1545 let is_project_open_source = self.provider_data_collection.is_project_open_source();
1546
1547 if self.provider_data_collection.is_data_collection_enabled(cx) {
1548 DataCollectionState::Enabled {
1549 is_project_open_source,
1550 }
1551 } else {
1552 DataCollectionState::Disabled {
1553 is_project_open_source,
1554 }
1555 }
1556 }
1557
    /// Toggles the data-collection choice by delegating to `ProviderDataCollection::toggle`.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }
1561
    /// Returns the edit prediction usage reported by the underlying `Zeta` entity.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
1565
    /// Always returns `true`; the buffer, cursor position and context are ignored.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
1574
    /// Returns `true` while the `Zeta` entity's `tos_accepted` flag is not set.
    fn needs_terms_acceptance(&self, cx: &App) -> bool {
        !self.zeta.read(cx).tos_accepted
    }
1578
    /// Returns `true` while at least one completion request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }
1582
1583 fn refresh(
1584 &mut self,
1585 project: Option<Entity<Project>>,
1586 buffer: Entity<Buffer>,
1587 position: language::Anchor,
1588 _debounce: bool,
1589 cx: &mut Context<Self>,
1590 ) {
1591 if !self.zeta.read(cx).tos_accepted {
1592 return;
1593 }
1594
1595 if self.zeta.read(cx).update_required {
1596 return;
1597 }
1598
1599 if self
1600 .zeta
1601 .read(cx)
1602 .user_store
1603 .read_with(cx, |user_store, _| {
1604 user_store.account_too_young() || user_store.has_overdue_invoices()
1605 })
1606 {
1607 return;
1608 }
1609
1610 if let Some(current_completion) = self.current_completion.as_ref() {
1611 let snapshot = buffer.read(cx).snapshot();
1612 if current_completion
1613 .completion
1614 .interpolate(&snapshot)
1615 .is_some()
1616 {
1617 return;
1618 }
1619 }
1620
1621 let pending_completion_id = self.next_pending_completion_id;
1622 self.next_pending_completion_id += 1;
1623 let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1624 let last_request_timestamp = self.last_request_timestamp;
1625
1626 let task = cx.spawn(async move |this, cx| {
1627 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1628 .checked_duration_since(Instant::now())
1629 {
1630 cx.background_executor().timer(timeout).await;
1631 }
1632
1633 let completion_request = this.update(cx, |this, cx| {
1634 this.last_request_timestamp = Instant::now();
1635 this.zeta.update(cx, |zeta, cx| {
1636 zeta.request_completion(
1637 project.as_ref(),
1638 &buffer,
1639 position,
1640 can_collect_data,
1641 cx,
1642 )
1643 })
1644 });
1645
1646 let completion = match completion_request {
1647 Ok(completion_request) => {
1648 let completion_request = completion_request.await;
1649 completion_request.map(|c| {
1650 c.map(|completion| CurrentInlineCompletion {
1651 buffer_id: buffer.entity_id(),
1652 completion,
1653 })
1654 })
1655 }
1656 Err(error) => Err(error),
1657 };
1658 let Some(new_completion) = completion
1659 .context("edit prediction failed")
1660 .log_err()
1661 .flatten()
1662 else {
1663 this.update(cx, |this, cx| {
1664 if this.pending_completions[0].id == pending_completion_id {
1665 this.pending_completions.remove(0);
1666 } else {
1667 this.pending_completions.clear();
1668 }
1669
1670 cx.notify();
1671 })
1672 .ok();
1673 return;
1674 };
1675
1676 this.update(cx, |this, cx| {
1677 if this.pending_completions[0].id == pending_completion_id {
1678 this.pending_completions.remove(0);
1679 } else {
1680 this.pending_completions.clear();
1681 }
1682
1683 if let Some(old_completion) = this.current_completion.as_ref() {
1684 let snapshot = buffer.read(cx).snapshot();
1685 if new_completion.should_replace_completion(&old_completion, &snapshot) {
1686 this.zeta.update(cx, |zeta, cx| {
1687 zeta.completion_shown(&new_completion.completion, cx);
1688 });
1689 this.current_completion = Some(new_completion);
1690 }
1691 } else {
1692 this.zeta.update(cx, |zeta, cx| {
1693 zeta.completion_shown(&new_completion.completion, cx);
1694 });
1695 this.current_completion = Some(new_completion);
1696 }
1697
1698 cx.notify();
1699 })
1700 .ok();
1701 });
1702
1703 // We always maintain at most two pending completions. When we already
1704 // have two, we replace the newest one.
1705 if self.pending_completions.len() <= 1 {
1706 self.pending_completions.push(PendingCompletion {
1707 id: pending_completion_id,
1708 _task: task,
1709 });
1710 } else if self.pending_completions.len() == 2 {
1711 self.pending_completions.pop();
1712 self.pending_completions.push(PendingCompletion {
1713 id: pending_completion_id,
1714 _task: task,
1715 });
1716 }
1717 }
1718
    /// Does nothing: cycling between alternative completions is not supported, so all
    /// arguments are ignored.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: inline_completion::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }
1728
1729 fn accept(&mut self, cx: &mut Context<Self>) {
1730 let completion_id = self
1731 .current_completion
1732 .as_ref()
1733 .map(|completion| completion.completion.id);
1734 if let Some(completion_id) = completion_id {
1735 self.zeta
1736 .update(cx, |zeta, cx| {
1737 zeta.accept_edit_prediction(completion_id, cx)
1738 })
1739 .detach();
1740 }
1741 self.pending_completions.clear();
1742 }
1743
1744 fn discard(&mut self, _cx: &mut Context<Self>) {
1745 self.pending_completions.clear();
1746 self.current_completion.take();
1747 }
1748
1749 fn suggest(
1750 &mut self,
1751 buffer: &Entity<Buffer>,
1752 cursor_position: language::Anchor,
1753 cx: &mut Context<Self>,
1754 ) -> Option<inline_completion::InlineCompletion> {
1755 let CurrentInlineCompletion {
1756 buffer_id,
1757 completion,
1758 ..
1759 } = self.current_completion.as_mut()?;
1760
1761 // Invalidate previous completion if it was generated for a different buffer.
1762 if *buffer_id != buffer.entity_id() {
1763 self.current_completion.take();
1764 return None;
1765 }
1766
1767 let buffer = buffer.read(cx);
1768 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1769 self.current_completion.take();
1770 return None;
1771 };
1772
1773 let cursor_row = cursor_position.to_point(buffer).row;
1774 let (closest_edit_ix, (closest_edit_range, _)) =
1775 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1776 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1777 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1778 cmp::min(distance_from_start, distance_from_end)
1779 })?;
1780
1781 let mut edit_start_ix = closest_edit_ix;
1782 for (range, _) in edits[..edit_start_ix].iter().rev() {
1783 let distance_from_closest_edit =
1784 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1785 if distance_from_closest_edit <= 1 {
1786 edit_start_ix -= 1;
1787 } else {
1788 break;
1789 }
1790 }
1791
1792 let mut edit_end_ix = closest_edit_ix + 1;
1793 for (range, _) in &edits[edit_end_ix..] {
1794 let distance_from_closest_edit =
1795 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1796 if distance_from_closest_edit <= 1 {
1797 edit_end_ix += 1;
1798 } else {
1799 break;
1800 }
1801 }
1802
1803 Some(inline_completion::InlineCompletion {
1804 id: Some(completion.id.to_string().into()),
1805 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1806 edit_preview: Some(completion.edit_preview.clone()),
1807 })
1808 }
1809}
1810
/// Estimates how many tokens a string of `bytes` bytes occupies, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    let estimated_tokens = bytes / BYTES_PER_TOKEN_GUESS;
    estimated_tokens
}
1817
1818#[cfg(test)]
1819mod tests {
1820 use client::test::FakeServer;
1821 use clock::FakeSystemClock;
1822 use gpui::TestAppContext;
1823 use http_client::FakeHttpClient;
1824 use indoc::indoc;
1825 use language::Point;
1826 use rpc::proto;
1827 use settings::SettingsStore;
1828
1829 use super::*;
1830
    /// Checks how `InlineCompletion::interpolate` adjusts a completion's edits as the
    /// buffer is edited: matching user edits shrink the remaining edits, undoing them
    /// restores the edits, and a conflicting edit makes interpolation return `None`.
    #[gpui::test]
    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // The completion replaces 2..5 with "REM" and deletes 9..11.
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = InlineCompletion {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: InlineCompletionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Unmodified buffer: the edits come back unchanged.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the replaced range turns the first edit into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Replacing the range with the first predicted character leaves "EM" to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // Typing the next predicted character leaves "M" to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Typing the last predicted character completes the first edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the typed "M" brings the remaining insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion by hand completes the second edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // After this deletion the completion can no longer be interpolated.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }
1947
    /// Checks that the edits derived from a model response are reduced to minimal
    /// insertions (common prefixes and suffixes are not part of the edits).
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        // Renaming `word` to `word_1` twice on one line yields two "_1" insertions.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        // Extending a string literal and appending `;` yields two separate insertions
        // around the unchanged closing quote.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }
2011
    /// Checks that a completion appending text at the end of the buffer can be applied:
    /// with the buffer "lorem\n" and a response adding "ipsum", applying the returned
    /// edits yields "lorem\nipsum".
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake HTTP client that answers every request with the canned response above.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                            .unwrap(),
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // Serve the requests the client makes to the fake server before the completion
        // can resolve.
        server.receive::<proto::GetUsers>().await.unwrap();
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }
2075
    /// Test helper: requests a completion for a buffer containing `buffer_content` (cursor
    /// at row 1, column 0) from a `Zeta` whose fake HTTP client always answers with
    /// `completion_response`, and returns the resulting edits with their ranges converted
    /// to points in the buffer's initial snapshot.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |_| {
            let completion = completion_response.clone();
            async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::new_v4(),
                            output_excerpt: completion,
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap())
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // Serve the requests the client makes to the fake server before the completion
        // can resolve.
        server.receive::<proto::GetUsers>().await.unwrap();
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .into_iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }
2128
2129 fn to_completion_edits(
2130 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2131 buffer: &Entity<Buffer>,
2132 cx: &App,
2133 ) -> Vec<(Range<Anchor>, String)> {
2134 let buffer = buffer.read(cx);
2135 iterator
2136 .into_iter()
2137 .map(|(range, text)| {
2138 (
2139 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2140 text,
2141 )
2142 })
2143 .collect()
2144 }
2145
2146 fn from_completion_edits(
2147 editor_edits: &[(Range<Anchor>, String)],
2148 buffer: &Entity<Buffer>,
2149 cx: &App,
2150 ) -> Vec<(Range<usize>, String)> {
2151 let buffer = buffer.read(cx);
2152 editor_edits
2153 .iter()
2154 .map(|(range, text)| {
2155 (
2156 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2157 text.clone(),
2158 )
2159 })
2160 .collect()
2161 }
2162
    /// Initializes test logging via `zlog::init_test`; registered with `#[ctor::ctor]`.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2167}