1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
39pub use provider::ZetaEditPredictionProvider;
40
/// Edits to the same buffer that follow each other within this window are merged into a single
/// `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);
42
/// Upper bound on the number of buffer-change events kept per project.
///
/// When the limit is reached, the oldest half of the history is dropped in one go.
const MAX_EVENT_COUNT: usize = 16;
45
46pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
47 max_bytes: 512,
48 min_bytes: 128,
49 target_before_cursor_over_total_bytes: 0.5,
50};
51
52pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
53 excerpt: DEFAULT_EXCERPT_OPTIONS,
54 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
55 max_diagnostic_bytes: 2048,
56 prompt_format: PromptFormat::MarkedExcerpt,
57};
58
/// Newtype that stores the app-wide `Zeta` entity as a gpui global.
///
/// Written by `Zeta::global` and read by `Zeta::try_global` / `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
63
/// Edit-prediction engine: records buffer edits per project and requests predicted edits from
/// the cloud predict-edits endpoint.
pub struct Zeta {
    /// Client used for HTTP requests and LLM token acquisition/refresh.
    client: Arc<Client>,
    /// Receives prediction usage reported by the server.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential of prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription created in `Zeta::new` registered.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to every new prediction request.
    options: ZetaOptions,
    /// Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When installed via `Zeta::debug_info`, receives debug data (or an error message) for each
    /// prediction request.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
74
75#[derive(Debug, Clone, PartialEq)]
76pub struct ZetaOptions {
77 pub excerpt: EditPredictionExcerptOptions,
78 pub max_prompt_bytes: usize,
79 pub max_diagnostic_bytes: usize,
80 pub prompt_format: predict_edits_v3::PromptFormat,
81}
82
/// Debug data emitted for a prediction request while a receiver from `Zeta::debug_info` is
/// installed.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for the request.
    pub context: EditPredictionContext,
    /// Time from the start of context gathering until the request was built (before it is
    /// sent).
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the prediction response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}
90
/// Server-provided debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
92
93struct ZetaProject {
94 syntax_index: Entity<SyntaxIndex>,
95 events: VecDeque<Event>,
96 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
97}
98
99struct RegisteredBuffer {
100 snapshot: BufferSnapshot,
101 _subscriptions: [gpui::Subscription; 2],
102}
103
/// A change recorded as context for later predictions.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    ///
    /// Consecutive edits to the same buffer within `BUFFER_CHANGE_GROUPING_INTERVAL` are merged
    /// into one event by advancing `new_snapshot` and `timestamp`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent edit covered by this event was recorded.
        timestamp: Instant,
    },
}
112
113impl Zeta {
114 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
115 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
116 }
117
118 pub fn global(
119 client: &Arc<Client>,
120 user_store: &Entity<UserStore>,
121 cx: &mut App,
122 ) -> Entity<Self> {
123 cx.try_global::<ZetaGlobal>()
124 .map(|global| global.0.clone())
125 .unwrap_or_else(|| {
126 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
127 cx.set_global(ZetaGlobal(zeta.clone()));
128 zeta
129 })
130 }
131
132 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
133 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
134
135 Self {
136 projects: HashMap::new(),
137 client,
138 user_store,
139 options: DEFAULT_OPTIONS,
140 llm_token: LlmApiToken::default(),
141 _llm_token_subscription: cx.subscribe(
142 &refresh_llm_token_listener,
143 |this, _listener, _event, cx| {
144 let client = this.client.clone();
145 let llm_token = this.llm_token.clone();
146 cx.spawn(async move |_this, _cx| {
147 llm_token.refresh(&client).await?;
148 anyhow::Ok(())
149 })
150 .detach_and_log_err(cx);
151 },
152 ),
153 update_required: false,
154 debug_tx: None,
155 }
156 }
157
158 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
159 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
160 self.debug_tx = Some(debug_watch_tx);
161 debug_watch_rx
162 }
163
164 pub fn options(&self) -> &ZetaOptions {
165 &self.options
166 }
167
168 pub fn set_options(&mut self, options: ZetaOptions) {
169 self.options = options;
170 }
171
172 pub fn clear_history(&mut self) {
173 for zeta_project in self.projects.values_mut() {
174 zeta_project.events.clear();
175 }
176 }
177
178 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
179 self.user_store.read(cx).edit_prediction_usage()
180 }
181
182 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
183 self.get_or_init_zeta_project(project, cx);
184 }
185
186 pub fn register_buffer(
187 &mut self,
188 buffer: &Entity<Buffer>,
189 project: &Entity<Project>,
190 cx: &mut Context<Self>,
191 ) {
192 let zeta_project = self.get_or_init_zeta_project(project, cx);
193 Self::register_buffer_impl(zeta_project, buffer, project, cx);
194 }
195
196 fn get_or_init_zeta_project(
197 &mut self,
198 project: &Entity<Project>,
199 cx: &mut App,
200 ) -> &mut ZetaProject {
201 self.projects
202 .entry(project.entity_id())
203 .or_insert_with(|| ZetaProject {
204 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
205 events: VecDeque::new(),
206 registered_buffers: HashMap::new(),
207 })
208 }
209
210 fn register_buffer_impl<'a>(
211 zeta_project: &'a mut ZetaProject,
212 buffer: &Entity<Buffer>,
213 project: &Entity<Project>,
214 cx: &mut Context<Self>,
215 ) -> &'a mut RegisteredBuffer {
216 let buffer_id = buffer.entity_id();
217 match zeta_project.registered_buffers.entry(buffer_id) {
218 hash_map::Entry::Occupied(entry) => entry.into_mut(),
219 hash_map::Entry::Vacant(entry) => {
220 let snapshot = buffer.read(cx).snapshot();
221 let project_entity_id = project.entity_id();
222 entry.insert(RegisteredBuffer {
223 snapshot,
224 _subscriptions: [
225 cx.subscribe(buffer, {
226 let project = project.downgrade();
227 move |this, buffer, event, cx| {
228 if let language::BufferEvent::Edited = event
229 && let Some(project) = project.upgrade()
230 {
231 this.report_changes_for_buffer(&buffer, &project, cx);
232 }
233 }
234 }),
235 cx.observe_release(buffer, move |this, _buffer, _cx| {
236 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
237 else {
238 return;
239 };
240 zeta_project.registered_buffers.remove(&buffer_id);
241 }),
242 ],
243 })
244 }
245 }
246 }
247
248 fn report_changes_for_buffer(
249 &mut self,
250 buffer: &Entity<Buffer>,
251 project: &Entity<Project>,
252 cx: &mut Context<Self>,
253 ) -> BufferSnapshot {
254 let zeta_project = self.get_or_init_zeta_project(project, cx);
255 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
256
257 let new_snapshot = buffer.read(cx).snapshot();
258 if new_snapshot.version != registered_buffer.snapshot.version {
259 let old_snapshot =
260 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
261 Self::push_event(
262 zeta_project,
263 Event::BufferChange {
264 old_snapshot,
265 new_snapshot: new_snapshot.clone(),
266 timestamp: Instant::now(),
267 },
268 );
269 }
270
271 new_snapshot
272 }
273
274 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
275 let events = &mut zeta_project.events;
276
277 if let Some(Event::BufferChange {
278 new_snapshot: last_new_snapshot,
279 timestamp: last_timestamp,
280 ..
281 }) = events.back_mut()
282 {
283 // Coalesce edits for the same buffer when they happen one after the other.
284 let Event::BufferChange {
285 old_snapshot,
286 new_snapshot,
287 timestamp,
288 } = &event;
289
290 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
291 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
292 && old_snapshot.version == last_new_snapshot.version
293 {
294 *last_new_snapshot = new_snapshot.clone();
295 *last_timestamp = *timestamp;
296 return;
297 }
298 }
299
300 if events.len() >= MAX_EVENT_COUNT {
301 // These are halved instead of popping to improve prompt caching.
302 events.drain(..MAX_EVENT_COUNT / 2);
303 }
304
305 events.push_back(event);
306 }
307
308 pub fn request_prediction(
309 &mut self,
310 project: &Entity<Project>,
311 buffer: &Entity<Buffer>,
312 position: language::Anchor,
313 cx: &mut Context<Self>,
314 ) -> Task<Result<Option<EditPrediction>>> {
315 let project_state = self.projects.get(&project.entity_id());
316
317 let index_state = project_state.map(|state| {
318 state
319 .syntax_index
320 .read_with(cx, |index, _cx| index.state().clone())
321 });
322 let options = self.options.clone();
323 let snapshot = buffer.read(cx).snapshot();
324 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
325 return Task::ready(Err(anyhow!("No file path for excerpt")));
326 };
327 let client = self.client.clone();
328 let llm_token = self.llm_token.clone();
329 let app_version = AppVersion::global(cx);
330 let worktree_snapshots = project
331 .read(cx)
332 .worktrees(cx)
333 .map(|worktree| worktree.read(cx).snapshot())
334 .collect::<Vec<_>>();
335 let debug_tx = self.debug_tx.clone();
336
337 let events = project_state
338 .map(|state| {
339 state
340 .events
341 .iter()
342 .filter_map(|event| match event {
343 Event::BufferChange {
344 old_snapshot,
345 new_snapshot,
346 ..
347 } => {
348 let path = new_snapshot.file().map(|f| f.full_path(cx));
349
350 let old_path = old_snapshot.file().and_then(|f| {
351 let old_path = f.full_path(cx);
352 if Some(&old_path) != path.as_ref() {
353 Some(old_path)
354 } else {
355 None
356 }
357 });
358
359 // TODO [zeta2] move to bg?
360 let diff =
361 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
362
363 if path == old_path && diff.is_empty() {
364 None
365 } else {
366 Some(predict_edits_v3::Event::BufferChange {
367 old_path,
368 path,
369 diff,
370 //todo: Actually detect if this edit was predicted or not
371 predicted: false,
372 })
373 }
374 }
375 })
376 .collect::<Vec<_>>()
377 })
378 .unwrap_or_default();
379
380 let diagnostics = snapshot.diagnostic_sets().clone();
381
382 let request_task = cx.background_spawn({
383 let snapshot = snapshot.clone();
384 let buffer = buffer.clone();
385 async move {
386 let index_state = if let Some(index_state) = index_state {
387 Some(index_state.lock_owned().await)
388 } else {
389 None
390 };
391
392 let cursor_offset = position.to_offset(&snapshot);
393 let cursor_point = cursor_offset.to_point(&snapshot);
394
395 let before_retrieval = chrono::Utc::now();
396
397 let Some(context) = EditPredictionContext::gather_context(
398 cursor_point,
399 &snapshot,
400 &options.excerpt,
401 index_state.as_deref(),
402 ) else {
403 return Ok(None);
404 };
405
406 let debug_context = if let Some(debug_tx) = debug_tx {
407 Some((debug_tx, context.clone()))
408 } else {
409 None
410 };
411
412 let (diagnostic_groups, diagnostic_groups_truncated) =
413 Self::gather_nearby_diagnostics(
414 cursor_offset,
415 &diagnostics,
416 &snapshot,
417 options.max_diagnostic_bytes,
418 );
419
420 let request = make_cloud_request(
421 excerpt_path,
422 context,
423 events,
424 // TODO data collection
425 false,
426 diagnostic_groups,
427 diagnostic_groups_truncated,
428 None,
429 debug_context.is_some(),
430 &worktree_snapshots,
431 index_state.as_deref(),
432 Some(options.max_prompt_bytes),
433 options.prompt_format,
434 );
435
436 let retrieval_time = chrono::Utc::now() - before_retrieval;
437 let response = Self::perform_request(client, llm_token, app_version, request).await;
438
439 if let Some((debug_tx, context)) = debug_context {
440 debug_tx
441 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
442 |response| {
443 let Some(request) =
444 some_or_debug_panic(response.0.debug_info.clone())
445 else {
446 return Err("Missing debug info".to_string());
447 };
448 Ok(PredictionDebugInfo {
449 context,
450 request,
451 retrieval_time,
452 buffer: buffer.downgrade(),
453 position,
454 })
455 },
456 ))
457 .ok();
458 }
459
460 let (response, usage) = response?;
461 let edits = edits_from_response(&response.edits, &snapshot);
462
463 anyhow::Ok(Some((response.request_id, edits, usage)))
464 }
465 });
466
467 let buffer = buffer.clone();
468
469 cx.spawn(async move |this, cx| {
470 match request_task.await {
471 Ok(Some((id, edits, usage))) => {
472 if let Some(usage) = usage {
473 this.update(cx, |this, cx| {
474 this.user_store.update(cx, |user_store, cx| {
475 user_store.update_edit_prediction_usage(usage, cx);
476 });
477 })
478 .ok();
479 }
480
481 // TODO telemetry: duration, etc
482 let Some((edits, snapshot, edit_preview_task)) =
483 buffer.read_with(cx, |buffer, cx| {
484 let new_snapshot = buffer.snapshot();
485 let edits: Arc<[_]> =
486 interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
487 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
488 })?
489 else {
490 return Ok(None);
491 };
492
493 Ok(Some(EditPrediction {
494 id: id.into(),
495 edits,
496 snapshot,
497 edit_preview: edit_preview_task.await,
498 }))
499 }
500 Ok(None) => Ok(None),
501 Err(err) => {
502 if err.is::<ZedUpdateRequiredError>() {
503 cx.update(|cx| {
504 this.update(cx, |this, _cx| {
505 this.update_required = true;
506 })
507 .ok();
508
509 let error_message: SharedString = err.to_string().into();
510 show_app_notification(
511 NotificationId::unique::<ZedUpdateRequiredError>(),
512 cx,
513 move |cx| {
514 cx.new(|cx| {
515 ErrorMessagePrompt::new(error_message.clone(), cx)
516 .with_link_button(
517 "Update Zed",
518 "https://zed.dev/releases",
519 )
520 })
521 },
522 );
523 })
524 .ok();
525 }
526
527 Err(err)
528 }
529 }
530 })
531 }
532
533 async fn perform_request(
534 client: Arc<Client>,
535 llm_token: LlmApiToken,
536 app_version: SemanticVersion,
537 request: predict_edits_v3::PredictEditsRequest,
538 ) -> Result<(
539 predict_edits_v3::PredictEditsResponse,
540 Option<EditPredictionUsage>,
541 )> {
542 let http_client = client.http_client();
543 let mut token = llm_token.acquire(&client).await?;
544 let mut did_retry = false;
545
546 loop {
547 let request_builder = http_client::Request::builder().method(Method::POST);
548 let request_builder =
549 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
550 request_builder.uri(predict_edits_url)
551 } else {
552 request_builder.uri(
553 http_client
554 .build_zed_llm_url("/predict_edits/v3", &[])?
555 .as_ref(),
556 )
557 };
558 let request = request_builder
559 .header("Content-Type", "application/json")
560 .header("Authorization", format!("Bearer {}", token))
561 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
562 .body(serde_json::to_string(&request)?.into())?;
563
564 let mut response = http_client.send(request).await?;
565
566 if let Some(minimum_required_version) = response
567 .headers()
568 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
569 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
570 {
571 anyhow::ensure!(
572 app_version >= minimum_required_version,
573 ZedUpdateRequiredError {
574 minimum_version: minimum_required_version
575 }
576 );
577 }
578
579 if response.status().is_success() {
580 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
581
582 let mut body = Vec::new();
583 response.body_mut().read_to_end(&mut body).await?;
584 return Ok((serde_json::from_slice(&body)?, usage));
585 } else if !did_retry
586 && response
587 .headers()
588 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
589 .is_some()
590 {
591 did_retry = true;
592 token = llm_token.refresh(&client).await?;
593 } else {
594 let mut body = String::new();
595 response.body_mut().read_to_string(&mut body).await?;
596 anyhow::bail!(
597 "error predicting edits.\nStatus: {:?}\nBody: {}",
598 response.status(),
599 body
600 );
601 }
602 }
603 }
604
605 fn gather_nearby_diagnostics(
606 cursor_offset: usize,
607 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
608 snapshot: &BufferSnapshot,
609 max_diagnostics_bytes: usize,
610 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
611 // TODO: Could make this more efficient
612 let mut diagnostic_groups = Vec::new();
613 for (language_server_id, diagnostics) in diagnostic_sets {
614 let mut groups = Vec::new();
615 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
616 diagnostic_groups.extend(
617 groups
618 .into_iter()
619 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
620 );
621 }
622
623 // sort by proximity to cursor
624 diagnostic_groups.sort_by_key(|group| {
625 let range = &group.entries[group.primary_ix].range;
626 if range.start >= cursor_offset {
627 range.start - cursor_offset
628 } else if cursor_offset >= range.end {
629 cursor_offset - range.end
630 } else {
631 (cursor_offset - range.start).min(range.end - cursor_offset)
632 }
633 });
634
635 let mut results = Vec::new();
636 let mut diagnostic_groups_truncated = false;
637 let mut diagnostics_byte_count = 0;
638 for group in diagnostic_groups {
639 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
640 diagnostics_byte_count += raw_value.get().len();
641 if diagnostics_byte_count > max_diagnostics_bytes {
642 diagnostic_groups_truncated = true;
643 break;
644 }
645 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
646 }
647
648 (results, diagnostic_groups_truncated)
649 }
650
651 // TODO: Dedupe with similar code in request_prediction?
652 pub fn cloud_request_for_zeta_cli(
653 &mut self,
654 project: &Entity<Project>,
655 buffer: &Entity<Buffer>,
656 position: language::Anchor,
657 cx: &mut Context<Self>,
658 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
659 let project_state = self.projects.get(&project.entity_id());
660
661 let index_state = project_state.map(|state| {
662 state
663 .syntax_index
664 .read_with(cx, |index, _cx| index.state().clone())
665 });
666 let options = self.options.clone();
667 let snapshot = buffer.read(cx).snapshot();
668 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
669 return Task::ready(Err(anyhow!("No file path for excerpt")));
670 };
671 let worktree_snapshots = project
672 .read(cx)
673 .worktrees(cx)
674 .map(|worktree| worktree.read(cx).snapshot())
675 .collect::<Vec<_>>();
676
677 cx.background_spawn(async move {
678 let index_state = if let Some(index_state) = index_state {
679 Some(index_state.lock_owned().await)
680 } else {
681 None
682 };
683
684 let cursor_point = position.to_point(&snapshot);
685
686 let debug_info = true;
687 EditPredictionContext::gather_context(
688 cursor_point,
689 &snapshot,
690 &options.excerpt,
691 index_state.as_deref(),
692 )
693 .context("Failed to select excerpt")
694 .map(|context| {
695 make_cloud_request(
696 excerpt_path.into(),
697 context,
698 // TODO pass everything
699 Vec::new(),
700 false,
701 Vec::new(),
702 false,
703 None,
704 debug_info,
705 &worktree_snapshots,
706 index_state.as_deref(),
707 Some(options.max_prompt_bytes),
708 options.prompt_format,
709 )
710 })
711 })
712 }
713}
714
/// Returned when the prediction server reports a minimum required Zed version newer than the
/// running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version taken from the server's response header.
    minimum_version: SemanticVersion,
}
722
723fn make_cloud_request(
724 excerpt_path: Arc<Path>,
725 context: EditPredictionContext,
726 events: Vec<predict_edits_v3::Event>,
727 can_collect_data: bool,
728 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
729 diagnostic_groups_truncated: bool,
730 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
731 debug_info: bool,
732 worktrees: &Vec<worktree::Snapshot>,
733 index_state: Option<&SyntaxIndexState>,
734 prompt_max_bytes: Option<usize>,
735 prompt_format: PromptFormat,
736) -> predict_edits_v3::PredictEditsRequest {
737 let mut signatures = Vec::new();
738 let mut declaration_to_signature_index = HashMap::default();
739 let mut referenced_declarations = Vec::new();
740
741 for snippet in context.declarations {
742 let project_entry_id = snippet.declaration.project_entry_id();
743 let Some(path) = worktrees.iter().find_map(|worktree| {
744 worktree.entry_for_id(project_entry_id).map(|entry| {
745 let mut full_path = RelPathBuf::new();
746 full_path.push(worktree.root_name());
747 full_path.push(&entry.path);
748 full_path
749 })
750 }) else {
751 continue;
752 };
753
754 let parent_index = index_state.and_then(|index_state| {
755 snippet.declaration.parent().and_then(|parent| {
756 add_signature(
757 parent,
758 &mut declaration_to_signature_index,
759 &mut signatures,
760 index_state,
761 )
762 })
763 });
764
765 let (text, text_is_truncated) = snippet.declaration.item_text();
766 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
767 path: path.as_std_path().into(),
768 text: text.into(),
769 range: snippet.declaration.item_range(),
770 text_is_truncated,
771 signature_range: snippet.declaration.signature_range_in_item_text(),
772 parent_index,
773 score_components: snippet.score_components,
774 signature_score: snippet.scores.signature,
775 declaration_score: snippet.scores.declaration,
776 });
777 }
778
779 let excerpt_parent = index_state.and_then(|index_state| {
780 context
781 .excerpt
782 .parent_declarations
783 .last()
784 .and_then(|(parent, _)| {
785 add_signature(
786 *parent,
787 &mut declaration_to_signature_index,
788 &mut signatures,
789 index_state,
790 )
791 })
792 });
793
794 predict_edits_v3::PredictEditsRequest {
795 excerpt_path,
796 excerpt: context.excerpt_text.body,
797 excerpt_range: context.excerpt.range,
798 cursor_offset: context.cursor_offset_in_excerpt,
799 referenced_declarations,
800 signatures,
801 excerpt_parent,
802 events,
803 can_collect_data,
804 diagnostic_groups,
805 diagnostic_groups_truncated,
806 git_info,
807 debug_info,
808 prompt_max_bytes,
809 prompt_format,
810 }
811}
812
813fn add_signature(
814 declaration_id: DeclarationId,
815 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
816 signatures: &mut Vec<Signature>,
817 index: &SyntaxIndexState,
818) -> Option<usize> {
819 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
820 return Some(*signature_index);
821 }
822 let Some(parent_declaration) = index.declaration(declaration_id) else {
823 log::error!("bug: missing parent declaration");
824 return None;
825 };
826 let parent_index = parent_declaration.parent().and_then(|parent| {
827 add_signature(parent, declaration_to_signature_index, signatures, index)
828 });
829 let (text, text_is_truncated) = parent_declaration.signature_text();
830 let signature_index = signatures.len();
831 signatures.push(Signature {
832 text: text.into(),
833 text_is_truncated,
834 parent_index,
835 range: parent_declaration.signature_range(),
836 });
837 declaration_to_signature_index.insert(declaration_id, signature_index);
838 Some(signature_index)
839}
840
// Tests for `Zeta` that run against a fake HTTP client. Requests sent to `/predict_edits/v3` are
// handed to the test through a channel and answered manually (see `init_test`).
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::Zeta;

    // A request for an unedited buffer carries the excerpt path and the cursor offset. The
    // server's whole-file replacement comes back as a single edit inserting " are you?" at the
    // cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        // Point (1, 3) is the end of "How": 7 bytes for "Hello!\n" plus 3.
        assert_eq!(request.cursor_offset, 10);

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // An edit to a registered buffer is sent with the next request as a unified-diff
    // `BufferChange` event.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Fill the empty second line with "How".
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Diagnostics published for the buffer are attached to the request as serialized groups.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // The end column lies past the end of "Bye"; the expected resolved range below is 8..10
        // (presumably clipped to the line end — confirm if this assertion changes).
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // Only the outgoing request is inspected; the response is never sent.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Initializes settings, language and project support, and a fake HTTP client:
    // - `/client/llm_tokens` answers with a test token;
    // - `/predict_edits/v3` forwards the decoded request, together with a oneshot sender for the
    //   response, to the returned receiver and replies with whatever the test sends back;
    // - any other path panics.
    // Returns the global `Zeta` built on that client and the request receiver.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Block this response until the test answers via `res_tx`.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);
            (zeta, req_rx)
        })
    }
}