1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
39pub use provider::ZetaEditPredictionProvider;
40
/// Window within which a new change to a buffer may be merged into the
/// previous change event for that same buffer (the new change must also start
/// from the version the previous event ended at).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of change events retained per project; once it is
/// reached, the oldest half of the events is discarded.
const MAX_EVENT_COUNT: usize = 16;
45
46pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
47 max_bytes: 512,
48 min_bytes: 128,
49 target_before_cursor_over_total_bytes: 0.5,
50};
51
52pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
53 excerpt: DEFAULT_EXCERPT_OPTIONS,
54 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
55 max_diagnostic_bytes: 2048,
56 prompt_format: PromptFormat::MarkedExcerpt,
57};
58
59#[derive(Clone)]
60struct ZetaGlobal(Entity<Zeta>);
61
62impl Global for ZetaGlobal {}
63
/// Edit-prediction engine that talks to the cloud `predict_edits/v3` endpoint.
///
/// Keeps per-project state (a syntax index and recent buffer change events)
/// and issues prediction requests on behalf of the editor.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token used to authenticate prediction requests.
    llm_token: LlmApiToken,
    /// Held so the token-refresh subscription created in `new` stays registered.
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server has reported that this Zed version is too old.
    update_required: bool,
    /// When set, `request_prediction` sends its debug info (or an error string) here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
74
/// Tunable parameters used when building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Size limits for the excerpt captured around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Budget for the cumulative serialized size of diagnostic groups in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
82
/// Debug information sent on the channel returned by `Zeta::debug_info` for
/// each prediction request whose response succeeded.
pub struct PredictionDebugInfo {
    /// Context gathered locally for the request.
    pub context: EditPredictionContext,
    /// Time spent gathering context and building the request. The network
    /// round-trip is not included.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the prediction response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}
90
91pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
92
/// State tracked for each registered project.
struct ZetaProject {
    /// Syntax index whose state is used when gathering context and resolving signatures.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer change events, oldest first, holding at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
98
/// A buffer whose edits are being tracked.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next change event is diffed against it.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's edit events and to its release.
    _subscriptions: [gpui::Subscription; 2],
}
103
/// A change recorded for a tracked buffer.
#[derive(Clone)]
pub enum Event {
    /// The buffer changed from `old_snapshot` to `new_snapshot`.
    ///
    /// Consecutive changes to the same buffer may be coalesced into a single event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent change folded into this event was recorded.
        timestamp: Instant,
    },
}
112
impl Zeta {
    /// Returns the global `Zeta` entity if one has already been installed via `Zeta::global`.
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

    /// Returns the global `Zeta` entity.
    ///
    /// On first use, the entity is created and registered as a gpui global.
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    /// Creates a `Zeta` with `DEFAULT_OPTIONS` and no registered projects.
    ///
    /// Subscribes to the global `RefreshLlmTokenListener`. On each event, the
    /// cached LLM token is refreshed in a detached task; errors are logged.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::new(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }

    /// Installs a new debug channel and returns its receiver.
    ///
    /// Any previously installed sender is replaced, so only the most recently
    /// returned receiver gets messages from subsequent `request_prediction` calls.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// Returns the current request options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the request options.
    ///
    /// The new options apply to requests started after this call, because each
    /// request clones the options when it starts.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Drops the recorded change events of every registered project.
    ///
    /// Buffer registrations and their snapshots are left untouched.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    /// Returns the edit prediction usage currently stored in the user store.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state (including a `SyntaxIndex`) exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking edits to `buffer` within `project`.
    ///
    /// The project's state is created if needed. Registering an
    /// already-registered buffer is a no-op.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    /// Returns the state for `project`.
    ///
    /// On first access, the state is created with a fresh `SyntaxIndex`.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
                events: VecDeque::new(),
                registered_buffers: HashMap::new(),
            })
    }

    /// Returns the registration for `buffer`, creating one on first call.
    ///
    /// A new registration stores the buffer's current snapshot and subscribes
    /// to two things:
    /// - `BufferEvent::Edited`, which reports the change while the project is alive;
    /// - the buffer's release, which removes the registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly so the subscription does not keep it alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    /// Records a `BufferChange` event if the buffer's version has moved past
    /// the last stored snapshot, and returns the buffer's current snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        // Registering first guarantees a baseline snapshot. A brand-new
        // registration snapshots the buffer as it is now, so no event results.
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

    /// Appends `event` to the project's history.
    ///
    /// The event is merged into the previous one when it continues a change to
    /// the same buffer within `BUFFER_CHANGE_GROUPING_INTERVAL`. Once
    /// `MAX_EVENT_COUNT` events are stored, the oldest half is dropped first.
    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
        let events = &mut zeta_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            timestamp: last_timestamp,
            ..
        }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // Merge only when this change starts exactly where the previous event
            // ended (same buffer, same version) and arrived soon enough.
            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    /// Requests an edit prediction for `buffer` at `position`.
    ///
    /// A background task gathers:
    /// - the excerpt around the cursor (using the project's syntax index if the project is registered);
    /// - the project's recent change events as diffs;
    /// - nearby diagnostics.
    ///
    /// It then sends the request. The returned edits are interpolated onto the
    /// buffer's snapshot at the time the response arrives.
    ///
    /// Resolves to `Ok(None)` in two cases:
    /// - no context could be gathered around the cursor;
    /// - `interpolate_edits` yields `None` for the buffer's current contents.
    ///
    /// # Errors
    ///
    /// Fails immediately if the buffer has no file. Request failures are
    /// propagated. If the failure is a `ZedUpdateRequiredError`,
    /// `update_required` is set and an "Update Zed" notification is shown
    /// before the error is returned.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // An unregistered project yields no syntax index and no events.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert recorded change events into request events carrying unified diffs.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .map(|event| match event {
                        Event::BufferChange {
                            old_snapshot,
                            new_snapshot,
                            ..
                        } => {
                            let path = new_snapshot.file().map(|f| f.full_path(cx));

                            // Only report an old path when the file was renamed.
                            let old_path = old_snapshot.file().and_then(|f| {
                                let old_path = f.full_path(cx);
                                if Some(&old_path) != path.as_ref() {
                                    Some(old_path)
                                } else {
                                    None
                                }
                            });

                            predict_edits_v3::Event::BufferChange {
                                old_path,
                                path,
                                diff: language::unified_diff(
                                    &old_snapshot.text(),
                                    &new_snapshot.text(),
                                ),
                                //todo: Actually detect if this edit was predicted or not
                                predicted: false,
                            }
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                // The index lock is held while gathering context and building the request.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                let Some(context) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &options.excerpt,
                    index_state.as_deref(),
                ) else {
                    return Ok(None);
                };

                // Keep a copy of the context only when someone is listening for debug info.
                let debug_context = if let Some(debug_tx) = debug_tx {
                    Some((debug_tx, context.clone()))
                } else {
                    None
                };

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = make_cloud_request(
                    excerpt_path,
                    context,
                    events,
                    // TODO data collection
                    false,
                    diagnostic_groups,
                    diagnostic_groups_truncated,
                    None,
                    debug_context.is_some(),
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                );

                // Measured up to request construction; excludes the network round-trip.
                let retrieval_time = chrono::Utc::now() - before_retrieval;
                let response = Self::perform_request(client, llm_token, app_version, request).await;

                if let Some((debug_tx, context)) = debug_context {
                    debug_tx
                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
                            |response| {
                                let Some(request) =
                                    some_or_debug_panic(response.0.debug_info.clone())
                                else {
                                    return Err("Missing debug info".to_string());
                                };
                                Ok(PredictionDebugInfo {
                                    context,
                                    request,
                                    retrieval_time,
                                    buffer: buffer.downgrade(),
                                    position,
                                })
                            },
                        ))
                        // A closed debug receiver is not an error for the prediction itself.
                        .ok();
                }

                let (response, usage) = response?;
                let edits = edits_from_response(&response.edits, &snapshot);

                anyhow::Ok(Some((response.request_id, edits, usage)))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((id, edits, usage))) => {
                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc
                    // The edits were computed against `snapshot`. Rebase them onto
                    // the buffer as it is now, since it may have changed while the
                    // request was in flight.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: id.into(),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }

    /// Sends `request` to the prediction endpoint and decodes the response.
    ///
    /// The endpoint is `/predict_edits/v3` on the Zed LLM host, unless the
    /// `ZED_PREDICT_EDITS_URL` environment variable overrides it. If the server
    /// flags the token as expired, the request is retried once with a refreshed
    /// token.
    ///
    /// Returns the decoded response together with any usage parsed from the
    /// response headers.
    ///
    /// # Errors
    ///
    /// - `ZedUpdateRequiredError` if the server advertises a minimum version
    ///   newer than `app_version`. This is checked before the status code.
    /// - An error containing the status and body for any other non-success response.
    /// - Transport or JSON errors.
    async fn perform_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: predict_edits_v3::PredictEditsRequest,
    ) -> Result<(
        predict_edits_v3::PredictEditsResponse,
        Option<EditPredictionUsage>,
    )> {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder =
                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    request_builder.uri(predict_edits_url)
                } else {
                    request_builder.uri(
                        http_client
                            .build_zed_llm_url("/predict_edits/v3", &[])?
                            .as_ref(),
                    )
                };
            // `request` is shadowed by the HTTP request only for the rest of this
            // iteration; a retry re-serializes the original `PredictEditsRequest`.
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", token))
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                .body(serde_json::to_string(&request)?.into())?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "error predicting edits.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    /// Collects diagnostic groups from every language server, nearest to the
    /// cursor first, and serializes them to JSON.
    ///
    /// Groups are added until their cumulative serialized size would exceed
    /// `max_diagnostics_bytes`. At the first group that would exceed the budget,
    /// collection stops and every later group is dropped, even smaller ones.
    ///
    /// Returns the serialized groups and whether any were dropped.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        // The distance is measured from the group's primary entry. When the
        // cursor is inside that range, the distance to the nearer edge is used.
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    /// Builds the request that `request_prediction` would send, without sending it.
    ///
    /// Debug info is always requested. Events and diagnostics are not yet included.
    ///
    /// # Errors
    ///
    /// Fails if the buffer has no file or if no excerpt could be selected
    /// around `position`.
    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                &options.excerpt,
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
}
709
/// Returned when the prediction server requires a newer Zed version than the
/// one making the request.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server's minimum-required-version header.
    minimum_version: SemanticVersion,
}
717
718fn make_cloud_request(
719 excerpt_path: Arc<Path>,
720 context: EditPredictionContext,
721 events: Vec<predict_edits_v3::Event>,
722 can_collect_data: bool,
723 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
724 diagnostic_groups_truncated: bool,
725 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
726 debug_info: bool,
727 worktrees: &Vec<worktree::Snapshot>,
728 index_state: Option<&SyntaxIndexState>,
729 prompt_max_bytes: Option<usize>,
730 prompt_format: PromptFormat,
731) -> predict_edits_v3::PredictEditsRequest {
732 let mut signatures = Vec::new();
733 let mut declaration_to_signature_index = HashMap::default();
734 let mut referenced_declarations = Vec::new();
735
736 for snippet in context.snippets {
737 let project_entry_id = snippet.declaration.project_entry_id();
738 let Some(path) = worktrees.iter().find_map(|worktree| {
739 worktree.entry_for_id(project_entry_id).map(|entry| {
740 let mut full_path = RelPathBuf::new();
741 full_path.push(worktree.root_name());
742 full_path.push(&entry.path);
743 full_path
744 })
745 }) else {
746 continue;
747 };
748
749 let parent_index = index_state.and_then(|index_state| {
750 snippet.declaration.parent().and_then(|parent| {
751 add_signature(
752 parent,
753 &mut declaration_to_signature_index,
754 &mut signatures,
755 index_state,
756 )
757 })
758 });
759
760 let (text, text_is_truncated) = snippet.declaration.item_text();
761 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
762 path: path.as_std_path().into(),
763 text: text.into(),
764 range: snippet.declaration.item_range(),
765 text_is_truncated,
766 signature_range: snippet.declaration.signature_range_in_item_text(),
767 parent_index,
768 score_components: snippet.score_components,
769 signature_score: snippet.scores.signature,
770 declaration_score: snippet.scores.declaration,
771 });
772 }
773
774 let excerpt_parent = index_state.and_then(|index_state| {
775 context
776 .excerpt
777 .parent_declarations
778 .last()
779 .and_then(|(parent, _)| {
780 add_signature(
781 *parent,
782 &mut declaration_to_signature_index,
783 &mut signatures,
784 index_state,
785 )
786 })
787 });
788
789 predict_edits_v3::PredictEditsRequest {
790 excerpt_path,
791 excerpt: context.excerpt_text.body,
792 excerpt_range: context.excerpt.range,
793 cursor_offset: context.cursor_offset_in_excerpt,
794 referenced_declarations,
795 signatures,
796 excerpt_parent,
797 events,
798 can_collect_data,
799 diagnostic_groups,
800 diagnostic_groups_truncated,
801 git_info,
802 debug_info,
803 prompt_max_bytes,
804 prompt_format,
805 }
806}
807
808fn add_signature(
809 declaration_id: DeclarationId,
810 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
811 signatures: &mut Vec<Signature>,
812 index: &SyntaxIndexState,
813) -> Option<usize> {
814 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
815 return Some(*signature_index);
816 }
817 let Some(parent_declaration) = index.declaration(declaration_id) else {
818 log::error!("bug: missing parent declaration");
819 return None;
820 };
821 let parent_index = parent_declaration.parent().and_then(|parent| {
822 add_signature(parent, declaration_to_signature_index, signatures, index)
823 });
824 let (text, text_is_truncated) = parent_declaration.signature_text();
825 let signature_index = signatures.len();
826 signatures.push(Signature {
827 text: text.into(),
828 text_is_truncated,
829 parent_index,
830 range: parent_declaration.signature_range(),
831 });
832 declaration_to_signature_index.insert(declaration_id, signature_index);
833 Some(signature_index)
834}
835
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::Zeta;

    /// Performs one prediction round-trip.
    ///
    /// Checks the request's excerpt path and cursor offset. Then checks that a
    /// whole-file replacement in the response becomes a single insertion of
    /// `" are you?"` at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        // Row 1, column 3: the 7 bytes of the first line and its newline, plus 3.
        assert_eq!(request.cursor_offset, 10);

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// An edit to a registered buffer is sent with the next request as a
    /// `BufferChange` event carrying a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Fill the empty second line; this edit should be recorded as an event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Diagnostics published for the buffer's file are included in the request
    /// as serialized diagnostic groups.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // Position (1, 5) lies past the end of line 1 ("Bye"); the expected
        // resolved range below is 8..10.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // The response is never sent; only the outgoing request is inspected.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Sets up test settings, a fake HTTP client, and the global `Zeta`.
    ///
    /// The fake client behaves as follows:
    /// - `/client/llm_tokens` returns a fixed token.
    /// - Each `/predict_edits/v3` request is decoded and forwarded over the
    ///   returned channel, together with a oneshot sender; the HTTP response is
    ///   whatever the test sends back on it.
    /// - Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for its response.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);
            (zeta, req_rx)
        })
    }
}