use anyhow::{Context as _, Result, anyhow, bail};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use std::{env, mem};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod assemble_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
mod sweep_ai;
pub mod udiff;
mod xml_edits;

use crate::assemble_excerpts::assemble_excerpts;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

// Maximum number of events to track, per edit prediction model.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
const EVENT_COUNT_MAX_ZETA: usize = 16;
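// Consecutive edits are coalesced into a single `Event::BufferChange` when
// they land within this many lines of each other. Illustrative case (assumed,
// not from the source): an edit ending on row 10 followed by one ending on
// row 17 is grouped (|10 - 17| <= 8), while a later edit on row 40 starts a
// new event.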
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;

pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
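
/// Default sizing for the excerpt captured around the cursor: at most 512
/// bytes and at least 128, aiming to place roughly half of the excerpt before
/// the cursor.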
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
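
// A minimal usage sketch (hypothetical call site; assumes `zeta: Entity<Zeta>`
// and `cx: &mut App` are in scope): derive custom options from the defaults
// with struct-update syntax.
//
//     zeta.update(cx, |zeta, _cx| {
//         zeta.set_options(ZetaOptions {
//             max_diagnostic_bytes: 4096,
//             ..DEFAULT_OPTIONS
//         });
//     });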

static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}

#[derive(PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    ZedCloud,
    Sweep,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Option<Entity<SyntaxIndex>>,
    events: VecDeque<Event>,
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
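        // Illustrative case (assumed, not from the source): if the old single
        // edit inserted "foo" and the new single edit inserts "foobar" over
        // the same range, the user typed through the prediction, so the new
        // prediction seamlessly replaces the old; any other pair of single
        // edits keeps the one already displayed.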
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}

impl Event {
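    /// Converts this event into its wire representation, rendering the buffer
    /// change as a unified diff. Returns `None` when the event carries no
    /// information (an empty diff and no path change recorded).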
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }

    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
        match self {
            Event::BufferChange { new_snapshot, .. } => new_snapshot
                .file()
                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

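    /// Returns the shared [`Zeta`] entity, creating and registering it as a
    /// GPUI global on first use. A hedged usage sketch (call-site names
    /// assumed):
    ///
    /// ```ignore
    /// let zeta = Zeta::global(&client, &user_store, cx);
    /// zeta.update(cx, |zeta, cx| zeta.register_project(&project, cx));
    /// ```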
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }

    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl DoubleEndedIterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
            self.user_store.read(cx).edit_prediction_usage()
        } else {
            None
        }
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, |this, project, event, cx| {
                    // TODO [zeta2] init with recent paths
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
                            let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                            if let Some(path) = path {
                                if let Some(ix) = zeta_project
                                    .recent_paths
                                    .iter()
                                    .position(|probe| probe == &path)
                                {
                                    zeta_project.recent_paths.remove(ix);
                                }
                                zeta_project.recent_paths.push_front(path);
                            }
                        }
                    }
                }),
            })
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let event_count_max = match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
        };

        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut zeta_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = events.back_mut()
        {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        if events.len() >= event_count_max {
            events.pop_front();
        }

        events.push_back(Event::BufferChange {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            timestamp: Instant::now(),
        });
    }

    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
        };
    }

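    /// Requests a prediction and installs it as the project's current one,
    /// unless [`CurrentEditPrediction::should_replace_prediction`] decides the
    /// prediction already being displayed should be kept.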
    pub fn refresh_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let request_task = self.request_prediction(project, buffer, position, cx);
        let buffer = buffer.clone();
        let project = project.clone();

        cx.spawn(async move |this, cx| {
            if let Some(prediction) = request_task.await? {
                this.update(cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project.entity_id())
                        .context("Project not found")?;

                    let new_prediction = CurrentEditPrediction {
                        requested_by_buffer_id: buffer.entity_id(),
                        prediction,
                    };

                    if project_state
                        .current_prediction
                        .as_ref()
                        .is_none_or(|old_prediction| {
                            new_prediction.should_replace_prediction(&old_prediction, cx)
                        })
                    {
                        project_state.current_prediction = Some(new_prediction);
                    }
                    anyhow::Ok(())
                })??;
            }
            Ok(())
        })
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => {
                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
            }
            ZetaEditPredictionModel::Sweep => {
                self.request_prediction_with_sweep(project, active_buffer, position, cx)
            }
        }
    }

    fn request_prediction_with_sweep(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = active_buffer.read(cx).snapshot();
        let debug_info = self.sweep_ai_debug_info.clone();
        let Some(api_token) = self.sweep_api_token.clone() else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let project_file = project::File::from_dyn(snapshot.file());
        let repo_name = project_file
            .map(|file| file.worktree.read(cx).root_name_str())
            .unwrap_or("untitled")
            .into();
        let offset = position.to_offset(&snapshot);

        let project_state = self.get_or_init_zeta_project(project, cx);
        let events = project_state.events.clone();
        let recent_buffers = project_state.recent_paths.iter().cloned();
        let http_client = cx.http_client();

        let recent_buffer_snapshots = recent_buffers
            .filter_map(|project_path| {
                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
                if active_buffer == &buffer {
                    None
                } else {
                    Some(buffer.read(cx).snapshot())
                }
            })
            .take(3)
            .collect::<Vec<_>>();

        let result = cx.background_spawn(async move {
            let text = snapshot.text();

            let mut recent_changes = String::new();
            for event in events {
                sweep_ai::write_event(event, &mut recent_changes).unwrap();
            }

            let file_chunks = recent_buffer_snapshots
                .into_iter()
                .map(|snapshot| {
                    let end_point = language::Point::new(30, 0).min(snapshot.max_point());
                    sweep_ai::FileChunk {
                        content: snapshot
                            .text_for_range(language::Point::zero()..end_point)
                            .collect(),
                        file_path: snapshot
                            .file()
                            .map(|f| f.path().as_unix_str())
                            .unwrap_or("untitled")
                            .to_string(),
                        start_line: 0,
                        end_line: end_point.row as usize,
                        timestamp: snapshot.file().and_then(|file| {
                            Some(
                                file.disk_state()
                                    .mtime()?
                                    .to_seconds_and_nanos_for_persistence()?
                                    .0,
                            )
                        }),
                    }
                })
                .collect();

            let request_body = sweep_ai::AutocompleteRequest {
                debug_info,
                repo_name,
                file_path: full_path.clone(),
                file_contents: text.clone(),
                original_file_contents: text,
                cursor_position: offset,
                recent_changes: recent_changes.clone(),
                changes_above_cursor: true,
                multiple_suggestions: false,
                branch: None,
                file_chunks,
                retrieval_chunks: vec![],
                recent_user_actions: vec![],
                // TODO
                privacy_mode_enabled: false,
            };

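            // Serialize and brotli-compress the request body (4 KiB buffer,
            // quality 11, window 2^22); the `Content-Encoding: br` header
            // below tells the server how to decode it.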
            let mut buf: Vec<u8> = Vec::new();
            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
            serde_json::to_writer(writer, &request_body)?;
            let body: AsyncBody = buf.into();

            const SWEEP_API_URL: &str =
                "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

            let request = http_client::Request::builder()
                .uri(SWEEP_API_URL)
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", api_token))
                .header("Connection", "keep-alive")
                .header("Content-Encoding", "br")
                .method(Method::POST)
                .body(body)?;

            let mut response = http_client.send(request).await?;

            let mut body: Vec<u8> = Vec::new();
            response.body_mut().read_to_end(&mut body).await?;

            if !response.status().is_success() {
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    String::from_utf8_lossy(&body),
                );
            };

            let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

            let old_text = snapshot
                .text_for_range(response.start_index..response.end_index)
                .collect::<String>();
            let edits = language::text_diff(&old_text, &response.completion)
                .into_iter()
                .map(|(range, text)| {
                    (
                        snapshot.anchor_after(response.start_index + range.start)
                            ..snapshot.anchor_before(response.start_index + range.end),
                        text,
                    )
                })
                .collect::<Vec<_>>();

            anyhow::Ok((response.autocomplete_id, edits, snapshot))
        });

        let buffer = active_buffer.clone();

        cx.spawn(async move |_, cx| {
            let (id, edits, old_snapshot) = result.await?;

            if edits.is_empty() {
                return anyhow::Ok(None);
            }

            let Some((edits, new_snapshot, preview_task)) =
                buffer.read_with(cx, |buffer, cx| {
                    let new_snapshot = buffer.snapshot();

                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
                            .into();
                    let preview_task = buffer.preview_edits(edits.clone(), cx);

                    Some((edits, new_snapshot, preview_task))
                })?
            else {
                return anyhow::Ok(None);
            };

            let prediction = EditPrediction {
                id: EditPredictionId(id.into()),
                edits,
                snapshot: new_snapshot,
                edit_preview: preview_task.await,
                buffer,
            };

            anyhow::Ok(Some(prediction))
        })
    }

    fn request_prediction_with_zed_cloud(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal | PromptFormat::MinimalQwen => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

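    /// Sends an authenticated POST to the LLM endpoint, refreshing the LLM
    /// token and retrying once when the server reports it expired, and
    /// failing with [`ZedUpdateRequiredError`] when the response's
    /// minimum-version header is newer than this build.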
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
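    //
    // Illustrative timeline (assumed, derived from the constants above): after
    // more than 10s of inactivity, the first edit triggers a refresh
    // immediately; while typing continues, each edit restarts the 3s debounce,
    // so a burst of edits ending at t=2s yields the next refresh at t=5s.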
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }

    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

    pub fn set_context(
        &mut self,
        project: Entity<Project>,
        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
    ) {
        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
            zeta_project.context = Some(context);
        }
    }

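    /// Collects diagnostic groups from every language server, sorts them by
    /// distance from the cursor offset (for a group containing the cursor, the
    /// distance to the nearer edge of its primary range), and keeps groups
    /// until the serialized size would exceed `max_diagnostics_bytes`. The
    /// returned flag reports whether any groups were dropped.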
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        if let Some(syntax_index) = &zeta_project.syntax_index {
            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
        } else {
            Task::ready(Ok(()))
        }
    }
}

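/// Extracts the assistant's text from an OpenAI-style chat response, taking
/// the first text part of a multipart message and logging an error for
/// anything else.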
1922pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1923 let choice = res.choices.pop()?;
1924 let output_text = match choice.message {
1925 open_ai::RequestMessage::Assistant {
1926 content: Some(open_ai::MessageContent::Plain(content)),
1927 ..
1928 } => content,
1929 open_ai::RequestMessage::Assistant {
1930 content: Some(open_ai::MessageContent::Multipart(mut content)),
1931 ..
1932 } => {
1933 if content.is_empty() {
1934 log::error!("No output from Baseten completion response");
1935 return None;
1936 }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}

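/// Assembles a `PredictEditsRequest` from gathered syntax-based context: each
/// retrieved declaration is resolved to a worktree-relative path, and the
/// signatures of its enclosing declarations are serialized (once, via
/// `add_signature`) so the server can reconstruct the surrounding scopes.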
fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &[worktree::Snapshot],
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

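    // Resolve each retrieved declaration to a path rooted at its worktree's
    // name, skipping declarations whose project entry no longer exists.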
    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

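    // Serialize the signature of the excerpt's innermost (last) enclosing
    // declaration; `add_signature` recursively adds its ancestors.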
    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

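/// Appends the signature of `declaration_id` (and, recursively, those of its
/// ancestors) to `signatures`, memoizing indices in
/// `declaration_to_signature_index` so shared parents are serialized only
/// once. Returns the signature's index, or `None` if the declaration is
/// missing from the index.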
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}

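/// Key for a cached eval entry: the kind of request plus a 64-bit fingerprint
/// (typically a hash of the request input).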
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

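/// Cache used by evals to avoid re-running expensive requests.
///
/// A minimal in-memory sketch (hypothetical; real implementations would
/// typically persist entries and use `input` for invalidation):
///
/// ```ignore
/// use std::sync::Mutex;
///
/// struct MemoryCache(Mutex<Vec<(EvalCacheKey, String)>>);
///
/// impl EvalCache for MemoryCache {
///     fn read(&self, key: EvalCacheKey) -> Option<String> {
///         let entries = self.0.lock().unwrap();
///         entries.iter().find(|(k, _)| *k == key).map(|(_, v)| v.clone())
///     }
///
///     fn write(&self, key: EvalCacheKey, _input: &str, value: &str) {
///         self.0.lock().unwrap().push((key, value.to_string()));
///     }
/// }
/// ```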
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}

#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

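    /// Builds a minimal completion response containing a single plain-text
    /// assistant message.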
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

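    /// Extracts the prompt text from a request, panicking unless the request
    /// consists of a single plain user message.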
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

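    /// Sets up a `Zeta` instance backed by a fake HTTP client. Requests to
    /// `/predict_edits/raw` are forwarded to the returned channel, so tests
    /// can inspect them and send back canned responses.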
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}