1use anyhow::{Context as _, Result, anyhow, bail};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::{Project, ProjectPath};
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33
34use std::ops::Range;
35use std::path::Path;
36use std::str::FromStr as _;
37use std::sync::{Arc, LazyLock};
38use std::time::{Duration, Instant};
39use std::{env, mem};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod assemble_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49mod sweep_ai;
50pub mod udiff;
51mod xml_edits;
52
53use crate::assemble_excerpts::assemble_excerpts;
54pub use crate::prediction::EditPrediction;
55pub use crate::prediction::EditPredictionId;
56pub use provider::ZetaEditPredictionProvider;
57
/// Maximum number of buffer-change events retained when the Sweep model is active.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
/// Maximum number of buffer-change events retained when the Zed Cloud model is active.
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Consecutive edits whose end rows are within this many lines of each other are
/// coalesced into a single `Event::BufferChange` (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
62
/// Feature flag gating the Sweep AI edit-prediction backend.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
/// Default sizing of the excerpt window selected around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place roughly half of the excerpt before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Agentic context gathering is the default mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for the syntax-index-driven context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used when a [`Zeta`] is constructed.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
99
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value; routes model
/// traffic to a local Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for context retrieval; overridable via `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit prediction; overridable via `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional override of the prediction endpoint. Defaults to the local Ollama
/// chat-completions URL when `ZED_ZETA2_OLLAMA` is set, otherwise `None`.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
127
/// Feature flag gating the zeta2 edit-prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff do not receive this flag implicitly; it must be enabled per-user.
    fn enabled_for_staff() -> bool {
        false
    }
}
137
/// Global handle to the singleton [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
142
/// Coordinates edit predictions across projects: records per-project edit
/// history, requests predictions from the configured backend, and holds the
/// resulting current prediction per project.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token used to authenticate requests to the Zed LLM endpoints.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports the client version
    // is too old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm where
    // this is written outside this chunk.
    update_required: bool,
    // When present, debug events are forwarded to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    // Bearer token for the Sweep API, read from `SWEEP_AI_TOKEN`.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}

/// Which backend serves edit predictions.
#[derive(PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    ZedCloud,
    Sweep,
}
164
/// Tunable parameters for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    // Upper bound on the rendered prompt size, in bytes.
    pub max_prompt_bytes: usize,
    // Upper bound on diagnostics included with a request, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

/// Strategy used to gather context around the cursor.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context is retrieved via model-driven search tool calls.
    Agentic(AgenticContextOptions),
    /// Context is gathered locally from the syntax index.
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
185
186impl ContextMode {
187 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
188 match self {
189 ContextMode::Agentic(options) => &options.excerpt,
190 ContextMode::Syntax(options) => &options.excerpt,
191 }
192 }
193}
194
/// Events emitted on the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context before this request was issued.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally-rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and the round-trip time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
235
/// Per-project prediction state.
struct ZetaProject {
    // Present only when `ContextMode::Syntax` was active at creation time.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Recent buffer-change events, oldest first, capped per model.
    events: VecDeque<Event>,
    // Recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Context ranges gathered per buffer for upcoming requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}

/// The prediction currently offered to the user, plus which buffer asked for it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
254
impl CurrentEditPrediction {
    /// Decides whether this (new) prediction should replace `old_prediction`.
    ///
    /// Returns `false` when the new prediction no longer interpolates onto its
    /// buffer's current contents (it is stale). Otherwise the new prediction
    /// wins, except in the single-edit/same-range case below.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // A new prediction that no longer applies is never installed.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // Predictions for a different buffer always replace the old one.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // A stale old prediction is always replaced.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace a single-edit prediction when the new edit covers
            // the same range and merely extends the old text.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
291
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// This buffer requested the prediction, but it targets another buffer.
    Jump { prediction: &'a EditPrediction },
}

/// Snapshot plus subscriptions for a buffer tracked for edit history.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in the per-project edit history.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        // End of the last edit in this change; used for coalescing proximity.
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}
313
impl Event {
    /// Converts this event into the request wire format, or `None` when the
    /// event carries no observable change.
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                // Current path of the buffer, if it is backed by a file.
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // Old path is reported only when it differs from the current
                // one (i.e. the buffer was renamed or moved).
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): since `old_path` is `Some` only when it differs
                // from `path`, `path == old_path` can only hold when both are
                // `None` (untitled buffers). Named buffers therefore produce an
                // event even when `diff` is empty — confirm whether the intended
                // condition was `old_path.is_none() && diff.is_empty()`.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }

    /// Project path of the buffer this event refers to (after the change).
    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
        match self {
            Event::BufferChange { new_snapshot, .. } => new_snapshot
                .file()
                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
        }
    }
}
359
360impl Zeta {
361 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
362 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
363 }
364
365 pub fn global(
366 client: &Arc<Client>,
367 user_store: &Entity<UserStore>,
368 cx: &mut App,
369 ) -> Entity<Self> {
370 cx.try_global::<ZetaGlobal>()
371 .map(|global| global.0.clone())
372 .unwrap_or_else(|| {
373 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
374 cx.set_global(ZetaGlobal(zeta.clone()));
375 zeta
376 })
377 }
378
    /// Creates a `Zeta` with [`DEFAULT_OPTIONS`], subscribing to LLM-token
    /// refresh events so the cached token stays usable.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    // Refresh the token in the background whenever the listener
                    // signals that it may have changed/expired.
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            sweep_api_token: None,
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }
409
    /// Switches the backing prediction model. Selecting `Sweep` reads the
    /// `SWEEP_AI_TOKEN` environment variable; when unset, the error is logged
    /// and subsequent Sweep requests resolve to `None`.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        if model == ZetaEditPredictionModel::Sweep {
            self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err();
        }
        self.edit_prediction_model = model;
    }
418
    /// Installs a cache used by evals to avoid repeated model round-trips.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
423
    /// Returns a receiver for debug events. As a side effect, debug reporting
    /// is enabled from this point on (any previous receiver is disconnected).
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Drops the recorded edit history of every registered project.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }
443
444 pub fn history_for_project(
445 &self,
446 project: &Entity<Project>,
447 ) -> impl DoubleEndedIterator<Item = &Event> {
448 self.projects
449 .get(&project.entity_id())
450 .map(|project| project.events.iter())
451 .into_iter()
452 .flatten()
453 }
454
455 pub fn context_for_project(
456 &self,
457 project: &Entity<Project>,
458 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
459 self.projects
460 .get(&project.entity_id())
461 .and_then(|project| {
462 Some(
463 project
464 .context
465 .as_ref()?
466 .iter()
467 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
468 )
469 })
470 .into_iter()
471 .flatten()
472 }
473
    /// Edit-prediction usage for the signed-in user, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project` so its edits are recorded in
    /// the prediction event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
491
    /// Returns the state for `project`, creating it on first use.
    ///
    /// Creation builds a syntax index only when `ContextMode::Syntax` is
    /// configured, and subscribes to project events to maintain the
    /// most-recently-active path list.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, |this, project, event, cx| {
                    // TODO [zeta2] init with recent paths
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
                            let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                            if let Some(path) = path {
                                // Move the path to the front, deduplicating any
                                // earlier occurrence.
                                if let Some(ix) = zeta_project
                                    .recent_paths
                                    .iter()
                                    .position(|probe| probe == &path)
                                {
                                    zeta_project.recent_paths.remove(ix);
                                }
                                zeta_project.recent_paths.push_front(path);
                            }
                        }
                    }
                }),
            })
    }
535
    /// Returns the registration entry for `buffer`, creating it on first use.
    ///
    /// Registration subscribes to buffer edits (to record change history) and
    /// to the buffer's release (to drop the entry again).
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly so the subscription does
                            // not keep it alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
573
574 fn report_changes_for_buffer(
575 &mut self,
576 buffer: &Entity<Buffer>,
577 project: &Entity<Project>,
578 cx: &mut Context<Self>,
579 ) {
580 let event_count_max = match self.edit_prediction_model {
581 ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
582 ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
583 };
584
585 let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
586 let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
587
588 let new_snapshot = buffer.read(cx).snapshot();
589 if new_snapshot.version == registered_buffer.snapshot.version {
590 return;
591 }
592
593 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
594 let end_edit_anchor = new_snapshot
595 .anchored_edits_since::<Point>(&old_snapshot.version)
596 .last()
597 .map(|(_, range)| range.end);
598 let events = &mut sweep_ai_project.events;
599
600 if let Some(Event::BufferChange {
601 new_snapshot: last_new_snapshot,
602 end_edit_anchor: last_end_edit_anchor,
603 ..
604 }) = events.back_mut()
605 {
606 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
607 == last_new_snapshot.remote_id()
608 && old_snapshot.version == last_new_snapshot.version;
609
610 let should_coalesce = is_next_snapshot_of_same_buffer
611 && end_edit_anchor
612 .as_ref()
613 .zip(last_end_edit_anchor.as_ref())
614 .is_some_and(|(a, b)| {
615 let a = a.to_point(&new_snapshot);
616 let b = b.to_point(&new_snapshot);
617 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
618 });
619
620 if should_coalesce {
621 *last_end_edit_anchor = end_edit_anchor;
622 *last_new_snapshot = new_snapshot;
623 return;
624 }
625 }
626
627 if events.len() >= event_count_max {
628 events.pop_front();
629 }
630
631 events.push_back(Event::BufferChange {
632 old_snapshot,
633 new_snapshot,
634 end_edit_anchor,
635 timestamp: Instant::now(),
636 });
637 }
638
    /// Resolves the project's current prediction from `buffer`'s point of
    /// view: `Local` if the prediction edits this buffer, `Jump` if this
    /// buffer requested a prediction targeting another buffer, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }
660
    /// Consumes the current prediction and reports its acceptance to the
    /// server in the background. The endpoint can be overridden with the
    /// `ZED_ACCEPT_PREDICTION_URL` environment variable.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
705
706 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
707 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
708 project_state.current_prediction.take();
709 };
710 }
711
712 pub fn refresh_prediction(
713 &mut self,
714 project: &Entity<Project>,
715 buffer: &Entity<Buffer>,
716 position: language::Anchor,
717 cx: &mut Context<Self>,
718 ) -> Task<Result<()>> {
719 let request_task = self.request_prediction(project, buffer, position, cx);
720 let buffer = buffer.clone();
721 let project = project.clone();
722
723 cx.spawn(async move |this, cx| {
724 if let Some(prediction) = request_task.await? {
725 this.update(cx, |this, cx| {
726 let project_state = this
727 .projects
728 .get_mut(&project.entity_id())
729 .context("Project not found")?;
730
731 let new_prediction = CurrentEditPrediction {
732 requested_by_buffer_id: buffer.entity_id(),
733 prediction: prediction,
734 };
735
736 if project_state
737 .current_prediction
738 .as_ref()
739 .is_none_or(|old_prediction| {
740 new_prediction.should_replace_prediction(&old_prediction, cx)
741 })
742 {
743 project_state.current_prediction = Some(new_prediction);
744 }
745 anyhow::Ok(())
746 })??;
747 }
748 Ok(())
749 })
750 }
751
    /// Requests an edit prediction for `position` in `active_buffer`,
    /// dispatching to whichever backend model is currently configured.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => {
                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
            }
            ZetaEditPredictionModel::Sweep => {
                self.request_prediction_with_sweep(project, active_buffer, position, cx)
            }
        }
    }
768
    /// Requests a prediction from the Sweep autocomplete API.
    ///
    /// Returns `Ok(None)` when no API token is configured, when the response
    /// yields no edits, or when the buffer changed such that the edits can no
    /// longer be interpolated onto it.
    fn request_prediction_with_sweep(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = active_buffer.read(cx).snapshot();
        let debug_info = self.sweep_ai_debug_info.clone();
        let Some(api_token) = self.sweep_api_token.clone() else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let project_file = project::File::from_dyn(snapshot.file());
        let repo_name = project_file
            .map(|file| file.worktree.read(cx).root_name_str())
            .unwrap_or("untitled")
            .into();
        let offset = position.to_offset(&snapshot);

        let project_state = self.get_or_init_zeta_project(project, cx);
        let events = project_state.events.clone();
        let recent_buffers = project_state.recent_paths.iter().cloned();
        let http_client = cx.http_client();

        // Up to 3 other recently-active open buffers serve as extra context.
        let recent_buffer_snapshots = recent_buffers
            .filter_map(|project_path| {
                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
                if active_buffer == &buffer {
                    None
                } else {
                    Some(buffer.read(cx).snapshot())
                }
            })
            .take(3)
            .collect::<Vec<_>>();

        let result = cx.background_spawn(async move {
            let text = snapshot.text();

            // Serialize the recorded edit history into the request's
            // recent-changes format.
            let mut recent_changes = String::new();
            for event in events {
                sweep_ai::write_event(event, &mut recent_changes).unwrap();
            }

            // Take the head (up to row 30) of each recent buffer as a chunk.
            let file_chunks = recent_buffer_snapshots
                .into_iter()
                .map(|snapshot| {
                    let end_point = language::Point::new(30, 0).min(snapshot.max_point());
                    sweep_ai::FileChunk {
                        content: snapshot
                            .text_for_range(language::Point::zero()..end_point)
                            .collect(),
                        file_path: snapshot
                            .file()
                            .map(|f| f.path().as_unix_str())
                            .unwrap_or("untitled")
                            .to_string(),
                        start_line: 0,
                        end_line: end_point.row as usize,
                        timestamp: snapshot.file().and_then(|file| {
                            Some(
                                file.disk_state()
                                    .mtime()?
                                    .to_seconds_and_nanos_for_persistence()?
                                    .0,
                            )
                        }),
                    }
                })
                .collect();

            let request_body = sweep_ai::AutocompleteRequest {
                debug_info,
                repo_name,
                file_path: full_path.clone(),
                file_contents: text.clone(),
                original_file_contents: text,
                cursor_position: offset,
                recent_changes: recent_changes.clone(),
                changes_above_cursor: true,
                multiple_suggestions: false,
                branch: None,
                file_chunks,
                retrieval_chunks: vec![],
                recent_user_actions: vec![],
                // TODO
                privacy_mode_enabled: false,
            };

            // The body is Brotli-compressed JSON (hence `Content-Encoding: br`).
            let mut buf: Vec<u8> = Vec::new();
            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
            serde_json::to_writer(writer, &request_body)?;
            let body: AsyncBody = buf.into();

            const SWEEP_API_URL: &str =
                "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

            let request = http_client::Request::builder()
                .uri(SWEEP_API_URL)
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", api_token))
                .header("Connection", "keep-alive")
                .header("Content-Encoding", "br")
                .method(Method::POST)
                .body(body)?;

            let mut response = http_client.send(request).await?;

            let mut body: Vec<u8> = Vec::new();
            response.body_mut().read_to_end(&mut body).await?;

            if !response.status().is_success() {
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    String::from_utf8_lossy(&body),
                );
            };

            let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

            // Diff the suggested completion against the text it replaces to
            // obtain minimal anchored edits.
            let old_text = snapshot
                .text_for_range(response.start_index..response.end_index)
                .collect::<String>();
            let edits = language::text_diff(&old_text, &response.completion)
                .into_iter()
                .map(|(range, text)| {
                    (
                        snapshot.anchor_after(response.start_index + range.start)
                            ..snapshot.anchor_before(response.start_index + range.end),
                        text,
                    )
                })
                .collect::<Vec<_>>();

            anyhow::Ok((response.autocomplete_id, edits, snapshot))
        });

        let buffer = active_buffer.clone();

        cx.spawn(async move |_, cx| {
            let (id, edits, old_snapshot) = result.await?;

            if edits.is_empty() {
                return anyhow::Ok(None);
            }

            // Re-anchor the edits onto the buffer's current snapshot; give up
            // if intervening edits invalidated them.
            let Some((edits, new_snapshot, preview_task)) =
                buffer.read_with(cx, |buffer, cx| {
                    let new_snapshot = buffer.snapshot();

                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
                            .into();
                    let preview_task = buffer.preview_edits(edits.clone(), cx);

                    Some((edits, new_snapshot, preview_task))
                })?
            else {
                return anyhow::Ok(None);
            };

            let prediction = EditPrediction {
                id: EditPredictionId(id.into()),
                edits,
                snapshot: new_snapshot,
                edit_preview: preview_task.await,
                buffer,
            };

            anyhow::Ok(Some(prediction))
        })
    }
948
949 fn request_prediction_with_zed_cloud(
950 &mut self,
951 project: &Entity<Project>,
952 active_buffer: &Entity<Buffer>,
953 position: language::Anchor,
954 cx: &mut Context<Self>,
955 ) -> Task<Result<Option<EditPrediction>>> {
956 let project_state = self.projects.get(&project.entity_id());
957
958 let index_state = project_state.and_then(|state| {
959 state
960 .syntax_index
961 .as_ref()
962 .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
963 });
964 let options = self.options.clone();
965 let active_snapshot = active_buffer.read(cx).snapshot();
966 let Some(excerpt_path) = active_snapshot
967 .file()
968 .map(|path| -> Arc<Path> { path.full_path(cx).into() })
969 else {
970 return Task::ready(Err(anyhow!("No file path for excerpt")));
971 };
972 let client = self.client.clone();
973 let llm_token = self.llm_token.clone();
974 let app_version = AppVersion::global(cx);
975 let worktree_snapshots = project
976 .read(cx)
977 .worktrees(cx)
978 .map(|worktree| worktree.read(cx).snapshot())
979 .collect::<Vec<_>>();
980 let debug_tx = self.debug_tx.clone();
981
982 let events = project_state
983 .map(|state| {
984 state
985 .events
986 .iter()
987 .filter_map(|event| event.to_request_event(cx))
988 .collect::<Vec<_>>()
989 })
990 .unwrap_or_default();
991
992 let diagnostics = active_snapshot.diagnostic_sets().clone();
993
994 let parent_abs_path =
995 project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
996 let mut path = f.worktree.read(cx).absolutize(&f.path);
997 if path.pop() { Some(path) } else { None }
998 });
999
1000 // TODO data collection
1001 let can_collect_data = cx.is_staff();
1002
1003 let empty_context_files = HashMap::default();
1004 let context_files = project_state
1005 .and_then(|project_state| project_state.context.as_ref())
1006 .unwrap_or(&empty_context_files);
1007
1008 #[cfg(feature = "eval-support")]
1009 let parsed_fut = futures::future::join_all(
1010 context_files
1011 .keys()
1012 .map(|buffer| buffer.read(cx).parsing_idle()),
1013 );
1014
1015 let mut included_files = context_files
1016 .iter()
1017 .filter_map(|(buffer_entity, ranges)| {
1018 let buffer = buffer_entity.read(cx);
1019 Some((
1020 buffer_entity.clone(),
1021 buffer.snapshot(),
1022 buffer.file()?.full_path(cx).into(),
1023 ranges.clone(),
1024 ))
1025 })
1026 .collect::<Vec<_>>();
1027
1028 included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1029 (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1030 });
1031
1032 #[cfg(feature = "eval-support")]
1033 let eval_cache = self.eval_cache.clone();
1034
1035 let request_task = cx.background_spawn({
1036 let active_buffer = active_buffer.clone();
1037 async move {
1038 #[cfg(feature = "eval-support")]
1039 parsed_fut.await;
1040
1041 let index_state = if let Some(index_state) = index_state {
1042 Some(index_state.lock_owned().await)
1043 } else {
1044 None
1045 };
1046
1047 let cursor_offset = position.to_offset(&active_snapshot);
1048 let cursor_point = cursor_offset.to_point(&active_snapshot);
1049
1050 let before_retrieval = chrono::Utc::now();
1051
1052 let (diagnostic_groups, diagnostic_groups_truncated) =
1053 Self::gather_nearby_diagnostics(
1054 cursor_offset,
1055 &diagnostics,
1056 &active_snapshot,
1057 options.max_diagnostic_bytes,
1058 );
1059
1060 let cloud_request = match options.context {
1061 ContextMode::Agentic(context_options) => {
1062 let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1063 cursor_point,
1064 &active_snapshot,
1065 &context_options.excerpt,
1066 index_state.as_deref(),
1067 ) else {
1068 return Ok((None, None));
1069 };
1070
1071 let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1072 ..active_snapshot.anchor_before(excerpt.range.end);
1073
1074 if let Some(buffer_ix) =
1075 included_files.iter().position(|(_, snapshot, _, _)| {
1076 snapshot.remote_id() == active_snapshot.remote_id()
1077 })
1078 {
1079 let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1080 ranges.push(excerpt_anchor_range);
1081 retrieval_search::merge_anchor_ranges(ranges, buffer);
1082 let last_ix = included_files.len() - 1;
1083 included_files.swap(buffer_ix, last_ix);
1084 } else {
1085 included_files.push((
1086 active_buffer.clone(),
1087 active_snapshot.clone(),
1088 excerpt_path.clone(),
1089 vec![excerpt_anchor_range],
1090 ));
1091 }
1092
1093 let included_files = included_files
1094 .iter()
1095 .map(|(_, snapshot, path, ranges)| {
1096 let ranges = ranges
1097 .iter()
1098 .map(|range| {
1099 let point_range = range.to_point(&snapshot);
1100 Line(point_range.start.row)..Line(point_range.end.row)
1101 })
1102 .collect::<Vec<_>>();
1103 let excerpts = assemble_excerpts(&snapshot, ranges);
1104 predict_edits_v3::IncludedFile {
1105 path: path.clone(),
1106 max_row: Line(snapshot.max_point().row),
1107 excerpts,
1108 }
1109 })
1110 .collect::<Vec<_>>();
1111
1112 predict_edits_v3::PredictEditsRequest {
1113 excerpt_path,
1114 excerpt: String::new(),
1115 excerpt_line_range: Line(0)..Line(0),
1116 excerpt_range: 0..0,
1117 cursor_point: predict_edits_v3::Point {
1118 line: predict_edits_v3::Line(cursor_point.row),
1119 column: cursor_point.column,
1120 },
1121 included_files,
1122 referenced_declarations: vec![],
1123 events,
1124 can_collect_data,
1125 diagnostic_groups,
1126 diagnostic_groups_truncated,
1127 debug_info: debug_tx.is_some(),
1128 prompt_max_bytes: Some(options.max_prompt_bytes),
1129 prompt_format: options.prompt_format,
1130 // TODO [zeta2]
1131 signatures: vec![],
1132 excerpt_parent: None,
1133 git_info: None,
1134 }
1135 }
1136 ContextMode::Syntax(context_options) => {
1137 let Some(context) = EditPredictionContext::gather_context(
1138 cursor_point,
1139 &active_snapshot,
1140 parent_abs_path.as_deref(),
1141 &context_options,
1142 index_state.as_deref(),
1143 ) else {
1144 return Ok((None, None));
1145 };
1146
1147 make_syntax_context_cloud_request(
1148 excerpt_path,
1149 context,
1150 events,
1151 can_collect_data,
1152 diagnostic_groups,
1153 diagnostic_groups_truncated,
1154 None,
1155 debug_tx.is_some(),
1156 &worktree_snapshots,
1157 index_state.as_deref(),
1158 Some(options.max_prompt_bytes),
1159 options.prompt_format,
1160 )
1161 }
1162 };
1163
1164 let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1165
1166 let retrieval_time = chrono::Utc::now() - before_retrieval;
1167
1168 let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1169 let (response_tx, response_rx) = oneshot::channel();
1170
1171 debug_tx
1172 .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1173 ZetaEditPredictionDebugInfo {
1174 request: cloud_request.clone(),
1175 retrieval_time,
1176 buffer: active_buffer.downgrade(),
1177 local_prompt: match prompt_result.as_ref() {
1178 Ok((prompt, _)) => Ok(prompt.clone()),
1179 Err(err) => Err(err.to_string()),
1180 },
1181 position,
1182 response_rx,
1183 },
1184 ))
1185 .ok();
1186 Some(response_tx)
1187 } else {
1188 None
1189 };
1190
1191 if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1192 if let Some(debug_response_tx) = debug_response_tx {
1193 debug_response_tx
1194 .send((Err("Request skipped".to_string()), TimeDelta::zero()))
1195 .ok();
1196 }
1197 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1198 }
1199
1200 let (prompt, _) = prompt_result?;
1201 let request = open_ai::Request {
1202 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1203 messages: vec![open_ai::RequestMessage::User {
1204 content: open_ai::MessageContent::Plain(prompt),
1205 }],
1206 stream: false,
1207 max_completion_tokens: None,
1208 stop: Default::default(),
1209 temperature: 0.7,
1210 tool_choice: None,
1211 parallel_tool_calls: None,
1212 tools: vec![],
1213 prompt_cache_key: None,
1214 reasoning_effort: None,
1215 };
1216
1217 log::trace!("Sending edit prediction request");
1218
1219 let before_request = chrono::Utc::now();
1220 let response = Self::send_raw_llm_request(
1221 request,
1222 client,
1223 llm_token,
1224 app_version,
1225 #[cfg(feature = "eval-support")]
1226 eval_cache,
1227 #[cfg(feature = "eval-support")]
1228 EvalCacheEntryKind::Prediction,
1229 )
1230 .await;
1231 let request_time = chrono::Utc::now() - before_request;
1232
1233 log::trace!("Got edit prediction response");
1234
1235 if let Some(debug_response_tx) = debug_response_tx {
1236 debug_response_tx
1237 .send((
1238 response
1239 .as_ref()
1240 .map_err(|err| err.to_string())
1241 .map(|response| response.0.clone()),
1242 request_time,
1243 ))
1244 .ok();
1245 }
1246
1247 let (res, usage) = response?;
1248 let request_id = EditPredictionId(res.id.clone().into());
1249 let Some(mut output_text) = text_from_response(res) else {
1250 return Ok((None, usage));
1251 };
1252
1253 if output_text.contains(CURSOR_MARKER) {
1254 log::trace!("Stripping out {CURSOR_MARKER} from response");
1255 output_text = output_text.replace(CURSOR_MARKER, "");
1256 }
1257
1258 let get_buffer_from_context = |path: &Path| {
1259 included_files
1260 .iter()
1261 .find_map(|(_, buffer, probe_path, ranges)| {
1262 if probe_path.as_ref() == path {
1263 Some((buffer, ranges.as_slice()))
1264 } else {
1265 None
1266 }
1267 })
1268 };
1269
1270 let (edited_buffer_snapshot, edits) = match options.prompt_format {
1271 PromptFormat::NumLinesUniDiff => {
1272 // TODO: Implement parsing of multi-file diffs
1273 crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1274 }
1275 PromptFormat::Minimal | PromptFormat::MinimalQwen => {
1276 if output_text.contains("--- a/\n+++ b/\nNo edits") {
1277 let edits = vec![];
1278 (&active_snapshot, edits)
1279 } else {
1280 crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1281 }
1282 }
1283 PromptFormat::OldTextNewText => {
1284 crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1285 .await?
1286 }
1287 _ => {
1288 bail!("unsupported prompt format {}", options.prompt_format)
1289 }
1290 };
1291
1292 let edited_buffer = included_files
1293 .iter()
1294 .find_map(|(buffer, snapshot, _, _)| {
1295 if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1296 Some(buffer.clone())
1297 } else {
1298 None
1299 }
1300 })
1301 .context("Failed to find buffer in included_buffers")?;
1302
1303 anyhow::Ok((
1304 Some((
1305 request_id,
1306 edited_buffer,
1307 edited_buffer_snapshot.clone(),
1308 edits,
1309 )),
1310 usage,
1311 ))
1312 }
1313 });
1314
1315 cx.spawn({
1316 async move |this, cx| {
1317 let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1318 Self::handle_api_response(&this, request_task.await, cx)?
1319 else {
1320 return Ok(None);
1321 };
1322
1323 // TODO telemetry: duration, etc
1324 Ok(
1325 EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1326 .await,
1327 )
1328 }
1329 })
1330 }
1331
    /// Sends `request` to the raw edit-prediction endpoint and deserializes
    /// the OpenAI-style response, along with any usage info parsed from
    /// response headers.
    ///
    /// The endpoint is taken from `PREDICT_EDITS_URL` when set, otherwise
    /// derived from the client's zed.dev LLM URL.
    ///
    /// With the `eval-support` feature, responses are cached keyed by a hash
    /// of the URL and pretty-printed request JSON, so evals can replay
    /// without network calls.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // On a cache hit, return the cached response immediately; usage info
        // is reported as `None` since no request was made.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Cache miss: record the response for future eval runs.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1389
    /// Unwraps an API response pair, recording edit-prediction usage on
    /// success.
    ///
    /// On error, if the failure is a `ZedUpdateRequiredError`, sets
    /// `update_required` on `this` and shows an app notification with an
    /// update link. The error is propagated to the caller either way.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    // Best-effort: the entity may already be released.
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1433
    /// Sends an authenticated POST request constructed by `build` and
    /// deserializes the JSON response body into `Res`.
    ///
    /// Adds content-type, bearer-auth, and Zed-version headers; enforces the
    /// server's minimum-required-version header (failing with
    /// `ZedUpdateRequiredError` when this build is too old); and retries once
    /// with a refreshed token when the server reports an expired LLM token.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the only repeat path is the single
        // expired-token retry guarded by `did_retry`.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server can demand a minimum client version; surface a typed
            // error so callers can prompt the user to update.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh it and retry the request once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1497
    /// Gap between edits after which the user is considered to have been idle,
    /// triggering an immediate (non-debounced) context refresh on the next edit.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce delay before refreshing context when the user pauses editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1500
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
1503 fn refresh_context_if_needed(
1504 &mut self,
1505 project: &Entity<Project>,
1506 buffer: &Entity<language::Buffer>,
1507 cursor_position: language::Anchor,
1508 cx: &mut Context<Self>,
1509 ) {
1510 if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1511 return;
1512 }
1513
1514 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1515 return;
1516 };
1517
1518 let now = Instant::now();
1519 let was_idle = zeta_project
1520 .refresh_context_timestamp
1521 .map_or(true, |timestamp| {
1522 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1523 });
1524 zeta_project.refresh_context_timestamp = Some(now);
1525 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1526 let buffer = buffer.clone();
1527 let project = project.clone();
1528 async move |this, cx| {
1529 if was_idle {
1530 log::debug!("refetching edit prediction context after idle");
1531 } else {
1532 cx.background_executor()
1533 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1534 .await;
1535 log::debug!("refetching edit prediction context after pause");
1536 }
1537 this.update(cx, |this, cx| {
1538 let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1539
1540 if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1541 zeta_project.refresh_context_task = Some(task.log_err());
1542 };
1543 })
1544 .ok()
1545 }
1546 }));
1547 }
1548
1549 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1550 // and avoid spawning more than one concurrent task.
    /// Asks the LLM to plan retrieval searches for the excerpt around
    /// `cursor_position`, runs the resulting searches, and stores the
    /// retrieved excerpts on the project's state.
    ///
    /// Returns a ready task (no-op) when the project isn't registered, the
    /// context mode isn't agentic, or no excerpt can be selected.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Excerpt around the cursor that the model will plan searches for.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Buffers with no backing file are labeled "untitled" in the prompt.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema + description for the search tool, computed once per
        // process and cloned per request.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // The model is expected to reply with tool calls carrying search
            // queries rather than plain text.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Gather queries from every call to the search tool, skipping
            // calls to tools we didn't offer.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the retrieved excerpts on the project and clear the
            // in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1764
1765 pub fn set_context(
1766 &mut self,
1767 project: Entity<Project>,
1768 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1769 ) {
1770 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1771 zeta_project.context = Some(context);
1772 }
1773 }
1774
1775 fn gather_nearby_diagnostics(
1776 cursor_offset: usize,
1777 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1778 snapshot: &BufferSnapshot,
1779 max_diagnostics_bytes: usize,
1780 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1781 // TODO: Could make this more efficient
1782 let mut diagnostic_groups = Vec::new();
1783 for (language_server_id, diagnostics) in diagnostic_sets {
1784 let mut groups = Vec::new();
1785 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1786 diagnostic_groups.extend(
1787 groups
1788 .into_iter()
1789 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1790 );
1791 }
1792
1793 // sort by proximity to cursor
1794 diagnostic_groups.sort_by_key(|group| {
1795 let range = &group.entries[group.primary_ix].range;
1796 if range.start >= cursor_offset {
1797 range.start - cursor_offset
1798 } else if cursor_offset >= range.end {
1799 cursor_offset - range.end
1800 } else {
1801 (cursor_offset - range.start).min(range.end - cursor_offset)
1802 }
1803 });
1804
1805 let mut results = Vec::new();
1806 let mut diagnostic_groups_truncated = false;
1807 let mut diagnostics_byte_count = 0;
1808 for group in diagnostic_groups {
1809 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1810 diagnostics_byte_count += raw_value.get().len();
1811 if diagnostics_byte_count > max_diagnostics_bytes {
1812 diagnostic_groups_truncated = true;
1813 break;
1814 }
1815 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1816 }
1817
1818 (results, diagnostic_groups_truncated)
1819 }
1820
1821 // TODO: Dedupe with similar code in request_prediction?
1822 pub fn cloud_request_for_zeta_cli(
1823 &mut self,
1824 project: &Entity<Project>,
1825 buffer: &Entity<Buffer>,
1826 position: language::Anchor,
1827 cx: &mut Context<Self>,
1828 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1829 let project_state = self.projects.get(&project.entity_id());
1830
1831 let index_state = project_state.and_then(|state| {
1832 state
1833 .syntax_index
1834 .as_ref()
1835 .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
1836 });
1837 let options = self.options.clone();
1838 let snapshot = buffer.read(cx).snapshot();
1839 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1840 return Task::ready(Err(anyhow!("No file path for excerpt")));
1841 };
1842 let worktree_snapshots = project
1843 .read(cx)
1844 .worktrees(cx)
1845 .map(|worktree| worktree.read(cx).snapshot())
1846 .collect::<Vec<_>>();
1847
1848 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1849 let mut path = f.worktree.read(cx).absolutize(&f.path);
1850 if path.pop() { Some(path) } else { None }
1851 });
1852
1853 cx.background_spawn(async move {
1854 let index_state = if let Some(index_state) = index_state {
1855 Some(index_state.lock_owned().await)
1856 } else {
1857 None
1858 };
1859
1860 let cursor_point = position.to_point(&snapshot);
1861
1862 let debug_info = true;
1863 EditPredictionContext::gather_context(
1864 cursor_point,
1865 &snapshot,
1866 parent_abs_path.as_deref(),
1867 match &options.context {
1868 ContextMode::Agentic(_) => {
1869 // TODO
1870 panic!("Llm mode not supported in zeta cli yet");
1871 }
1872 ContextMode::Syntax(edit_prediction_context_options) => {
1873 edit_prediction_context_options
1874 }
1875 },
1876 index_state.as_deref(),
1877 )
1878 .context("Failed to select excerpt")
1879 .map(|context| {
1880 make_syntax_context_cloud_request(
1881 excerpt_path.into(),
1882 context,
1883 // TODO pass everything
1884 Vec::new(),
1885 false,
1886 Vec::new(),
1887 false,
1888 None,
1889 debug_info,
1890 &worktree_snapshots,
1891 index_state.as_deref(),
1892 Some(options.max_prompt_bytes),
1893 options.prompt_format,
1894 )
1895 })
1896 })
1897 }
1898
1899 pub fn wait_for_initial_indexing(
1900 &mut self,
1901 project: &Entity<Project>,
1902 cx: &mut Context<Self>,
1903 ) -> Task<Result<()>> {
1904 let zeta_project = self.get_or_init_zeta_project(project, cx);
1905 if let Some(syntax_index) = &zeta_project.syntax_index {
1906 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
1907 } else {
1908 Task::ready(Ok(()))
1909 }
1910 }
1911}
1912
1913pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1914 let choice = res.choices.pop()?;
1915 let output_text = match choice.message {
1916 open_ai::RequestMessage::Assistant {
1917 content: Some(open_ai::MessageContent::Plain(content)),
1918 ..
1919 } => content,
1920 open_ai::RequestMessage::Assistant {
1921 content: Some(open_ai::MessageContent::Multipart(mut content)),
1922 ..
1923 } => {
1924 if content.is_empty() {
1925 log::error!("No output from Baseten completion response");
1926 return None;
1927 }
1928
1929 match content.remove(0) {
1930 open_ai::MessagePart::Text { text } => text,
1931 open_ai::MessagePart::Image { .. } => {
1932 log::error!("Expected text, got an image");
1933 return None;
1934 }
1935 }
1936 }
1937 _ => {
1938 log::error!("Invalid response message: {:?}", choice.message);
1939 return None;
1940 }
1941 };
1942 Some(output_text)
1943}
1944
/// Error raised when the server's minimum-required-version response header
/// indicates this Zed build is too old for edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1952
1953fn make_syntax_context_cloud_request(
1954 excerpt_path: Arc<Path>,
1955 context: EditPredictionContext,
1956 events: Vec<predict_edits_v3::Event>,
1957 can_collect_data: bool,
1958 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1959 diagnostic_groups_truncated: bool,
1960 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1961 debug_info: bool,
1962 worktrees: &Vec<worktree::Snapshot>,
1963 index_state: Option<&SyntaxIndexState>,
1964 prompt_max_bytes: Option<usize>,
1965 prompt_format: PromptFormat,
1966) -> predict_edits_v3::PredictEditsRequest {
1967 let mut signatures = Vec::new();
1968 let mut declaration_to_signature_index = HashMap::default();
1969 let mut referenced_declarations = Vec::new();
1970
1971 for snippet in context.declarations {
1972 let project_entry_id = snippet.declaration.project_entry_id();
1973 let Some(path) = worktrees.iter().find_map(|worktree| {
1974 worktree.entry_for_id(project_entry_id).map(|entry| {
1975 let mut full_path = RelPathBuf::new();
1976 full_path.push(worktree.root_name());
1977 full_path.push(&entry.path);
1978 full_path
1979 })
1980 }) else {
1981 continue;
1982 };
1983
1984 let parent_index = index_state.and_then(|index_state| {
1985 snippet.declaration.parent().and_then(|parent| {
1986 add_signature(
1987 parent,
1988 &mut declaration_to_signature_index,
1989 &mut signatures,
1990 index_state,
1991 )
1992 })
1993 });
1994
1995 let (text, text_is_truncated) = snippet.declaration.item_text();
1996 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1997 path: path.as_std_path().into(),
1998 text: text.into(),
1999 range: snippet.declaration.item_line_range(),
2000 text_is_truncated,
2001 signature_range: snippet.declaration.signature_range_in_item_text(),
2002 parent_index,
2003 signature_score: snippet.score(DeclarationStyle::Signature),
2004 declaration_score: snippet.score(DeclarationStyle::Declaration),
2005 score_components: snippet.components,
2006 });
2007 }
2008
2009 let excerpt_parent = index_state.and_then(|index_state| {
2010 context
2011 .excerpt
2012 .parent_declarations
2013 .last()
2014 .and_then(|(parent, _)| {
2015 add_signature(
2016 *parent,
2017 &mut declaration_to_signature_index,
2018 &mut signatures,
2019 index_state,
2020 )
2021 })
2022 });
2023
2024 predict_edits_v3::PredictEditsRequest {
2025 excerpt_path,
2026 excerpt: context.excerpt_text.body,
2027 excerpt_line_range: context.excerpt.line_range,
2028 excerpt_range: context.excerpt.range,
2029 cursor_point: predict_edits_v3::Point {
2030 line: predict_edits_v3::Line(context.cursor_point.row),
2031 column: context.cursor_point.column,
2032 },
2033 referenced_declarations,
2034 included_files: vec![],
2035 signatures,
2036 excerpt_parent,
2037 events,
2038 can_collect_data,
2039 diagnostic_groups,
2040 diagnostic_groups_truncated,
2041 git_info,
2042 debug_info,
2043 prompt_max_bytes,
2044 prompt_format,
2045 }
2046}
2047
2048fn add_signature(
2049 declaration_id: DeclarationId,
2050 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2051 signatures: &mut Vec<Signature>,
2052 index: &SyntaxIndexState,
2053) -> Option<usize> {
2054 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2055 return Some(*signature_index);
2056 }
2057 let Some(parent_declaration) = index.declaration(declaration_id) else {
2058 log::error!("bug: missing parent declaration");
2059 return None;
2060 };
2061 let parent_index = parent_declaration.parent().and_then(|parent| {
2062 add_signature(parent, declaration_to_signature_index, signatures, index)
2063 });
2064 let (text, text_is_truncated) = parent_declaration.signature_text();
2065 let signature_index = signatures.len();
2066 signatures.push(Signature {
2067 text: text.into(),
2068 text_is_truncated,
2069 parent_index,
2070 range: parent_declaration.signature_line_range(),
2071 });
2072 declaration_to_signature_index.insert(declaration_id, signature_index);
2073 Some(signature_index)
2074}
2075
/// Cache key for eval caching: the request kind plus a hash of the URL and
/// serialized request body.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// Which stage of edit prediction produced a cached entry.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2086
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of the entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2097
/// Storage backend used during evals to cache LLM responses keyed by request
/// hash, so runs can be replayed without network access.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response string for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (the serialized response) for `key`; `input` is the
    /// serialized request that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2103
2104#[cfg(test)]
2105mod tests {
2106 use std::{path::Path, sync::Arc};
2107
2108 use client::UserStore;
2109 use clock::FakeSystemClock;
2110 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2111 use futures::{
2112 AsyncReadExt, StreamExt,
2113 channel::{mpsc, oneshot},
2114 };
2115 use gpui::{
2116 Entity, TestAppContext,
2117 http_client::{FakeHttpClient, Response},
2118 prelude::*,
2119 };
2120 use indoc::indoc;
2121 use language::OffsetRangeExt as _;
2122 use open_ai::Usage;
2123 use pretty_assertions::{assert_eq, assert_matches};
2124 use project::{FakeFs, Project};
2125 use serde_json::json;
2126 use settings::SettingsStore;
2127 use util::path;
2128 use uuid::Uuid;
2129
2130 use crate::{BufferEditPrediction, Zeta};
2131
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    // End-to-end test of Zeta's per-project prediction state machine:
    // 1. a prediction for the cursor's own buffer surfaces as `Local`,
    // 2. a context-refresh round trip (model answers with a search tool call),
    // 3. a prediction targeting a *different* file surfaces as `Jump` in the
    //    cursor's buffer but as `Local` once that other buffer is opened.
    let (zeta, mut req_rx) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    zeta.update(cx, |zeta, cx| {
        zeta.register_project(&project, cx);
    });

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at row 1, col 3: just after "How" — the line the fake model edits.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    // Prediction for current file

    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    // init_test's fake HTTP server forwards each /predict_edits/raw request
    // here, together with a oneshot sender used to supply the model response.
    let (_request, respond_tx) = req_rx.next().await.unwrap();

    respond_tx
        .send(model_response(indoc! {r"
            --- a/root/1.txt
            +++ b/root/1.txt
            @@ ... @@
             Hello!
            -How
            +How are you?
             Bye
        "}))
        .unwrap();
    prediction_task.await.unwrap();

    // Edit targets the buffer the cursor is in, so it's reported as Local.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Context refresh
    let refresh_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
    });
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    // Respond with an assistant message carrying a search tool call (no plain
    // content), exercising the retrieval path rather than the edit path.
    respond_tx
        .send(open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: None,
                    tool_calls: vec![open_ai::ToolCall {
                        id: "search".into(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                    .to_string(),
                                arguments: serde_json::to_string(&SearchToolInput {
                                    queries: Box::new([SearchToolQuery {
                                        glob: "root/2.txt".to_string(),
                                        syntax_node: vec![],
                                        content: Some(".".into()),
                                    }]),
                                })
                                .unwrap(),
                            },
                        },
                    }],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        })
        .unwrap();
    refresh_task.await.unwrap();

    // Clear the first prediction so the next one is observed fresh.
    zeta.update(cx, |zeta, _cx| {
        zeta.discard_current_prediction(&project);
    });

    // Prediction for another file
    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    // NOTE(review): unlike the other fixtures this diff has no `@@ ... @@`
    // hunk header — presumably the parser tolerates its absence; confirm.
    respond_tx
        .send(model_response(indoc! {r#"
            --- a/root/2.txt
            +++ b/root/2.txt
             Hola!
            -Como
            +Como estas?
             Adios
        "#}))
        .unwrap();
    prediction_task.await.unwrap();
    // From buffer1's point of view the edit lives elsewhere → Jump, and the
    // jump target must resolve to root/2.txt.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // The same pending prediction, viewed from the target buffer, is Local.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer2, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
2276
2277 #[gpui::test]
2278 async fn test_simple_request(cx: &mut TestAppContext) {
2279 let (zeta, mut req_rx) = init_test(cx);
2280 let fs = FakeFs::new(cx.executor());
2281 fs.insert_tree(
2282 "/root",
2283 json!({
2284 "foo.md": "Hello!\nHow\nBye\n"
2285 }),
2286 )
2287 .await;
2288 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2289
2290 let buffer = project
2291 .update(cx, |project, cx| {
2292 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2293 project.open_buffer(path, cx)
2294 })
2295 .await
2296 .unwrap();
2297 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2298 let position = snapshot.anchor_before(language::Point::new(1, 3));
2299
2300 let prediction_task = zeta.update(cx, |zeta, cx| {
2301 zeta.request_prediction(&project, &buffer, position, cx)
2302 });
2303
2304 let (_, respond_tx) = req_rx.next().await.unwrap();
2305
2306 // TODO Put back when we have a structured request again
2307 // assert_eq!(
2308 // request.excerpt_path.as_ref(),
2309 // Path::new(path!("root/foo.md"))
2310 // );
2311 // assert_eq!(
2312 // request.cursor_point,
2313 // Point {
2314 // line: Line(1),
2315 // column: 3
2316 // }
2317 // );
2318
2319 respond_tx
2320 .send(model_response(indoc! { r"
2321 --- a/root/foo.md
2322 +++ b/root/foo.md
2323 @@ ... @@
2324 Hello!
2325 -How
2326 +How are you?
2327 Bye
2328 "}))
2329 .unwrap();
2330
2331 let prediction = prediction_task.await.unwrap().unwrap();
2332
2333 assert_eq!(prediction.edits.len(), 1);
2334 assert_eq!(
2335 prediction.edits[0].0.to_point(&snapshot).start,
2336 language::Point::new(1, 3)
2337 );
2338 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
2339 }
2340
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    // Verifies that edits made to a *registered* buffer before a prediction
    // request are reported to the model as a diff of recent events inside the
    // prompt text.
    let (zeta, mut req_rx) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registration must happen before the edit so Zeta observes it as an event.
    zeta.update(cx, |zeta, cx| {
        zeta.register_buffer(&buffer, &project, cx);
    });

    // Type "How" into the empty second line (offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.request_prediction(&project, &buffer, position, cx)
    });

    let (request, respond_tx) = req_rx.next().await.unwrap();

    // The captured prompt must embed the recent edit as a unified diff.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(indoc! {r#"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are you?
             Bye
        "#}))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap();

    // The model's diff resolves to a single insertion at the cursor.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
2414
2415 // Skipped until we start including diagnostics in prompt
2416 // #[gpui::test]
2417 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2418 // let (zeta, mut req_rx) = init_test(cx);
2419 // let fs = FakeFs::new(cx.executor());
2420 // fs.insert_tree(
2421 // "/root",
2422 // json!({
2423 // "foo.md": "Hello!\nBye"
2424 // }),
2425 // )
2426 // .await;
2427 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2428
2429 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2430 // let diagnostic = lsp::Diagnostic {
2431 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2432 // severity: Some(lsp::DiagnosticSeverity::ERROR),
2433 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2434 // ..Default::default()
2435 // };
2436
2437 // project.update(cx, |project, cx| {
2438 // project.lsp_store().update(cx, |lsp_store, cx| {
2439 // // Create some diagnostics
2440 // lsp_store
2441 // .update_diagnostics(
2442 // LanguageServerId(0),
2443 // lsp::PublishDiagnosticsParams {
2444 // uri: path_to_buffer_uri.clone(),
2445 // diagnostics: vec![diagnostic],
2446 // version: None,
2447 // },
2448 // None,
2449 // language::DiagnosticSourceKind::Pushed,
2450 // &[],
2451 // cx,
2452 // )
2453 // .unwrap();
2454 // });
2455 // });
2456
2457 // let buffer = project
2458 // .update(cx, |project, cx| {
2459 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2460 // project.open_buffer(path, cx)
2461 // })
2462 // .await
2463 // .unwrap();
2464
2465 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2466 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2467
2468 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2469 // zeta.request_prediction(&project, &buffer, position, cx)
2470 // });
2471
2472 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2473
2474 // assert_eq!(request.diagnostic_groups.len(), 1);
2475 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2476 // .unwrap();
2477 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2478 // assert_eq!(
2479 // value,
2480 // json!({
2481 // "entries": [{
2482 // "range": {
2483 // "start": 8,
2484 // "end": 10
2485 // },
2486 // "diagnostic": {
2487 // "source": null,
2488 // "code": null,
2489 // "code_description": null,
2490 // "severity": 1,
2491 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2492 // "markdown": null,
2493 // "group_id": 0,
2494 // "is_primary": true,
2495 // "is_disk_based": false,
2496 // "is_unnecessary": false,
2497 // "source_kind": "Pushed",
2498 // "data": null,
2499 // "underline": true
2500 // }
2501 // }],
2502 // "primary_ix": 0
2503 // })
2504 // );
2505 // }
2506
2507 fn model_response(text: &str) -> open_ai::Response {
2508 open_ai::Response {
2509 id: Uuid::new_v4().to_string(),
2510 object: "response".into(),
2511 created: 0,
2512 model: "model".into(),
2513 choices: vec![open_ai::Choice {
2514 index: 0,
2515 message: open_ai::RequestMessage::Assistant {
2516 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2517 tool_calls: vec![],
2518 },
2519 finish_reason: None,
2520 }],
2521 usage: Usage {
2522 prompt_tokens: 0,
2523 completion_tokens: 0,
2524 total_tokens: 0,
2525 },
2526 }
2527 }
2528
2529 fn prompt_from_request(request: &open_ai::Request) -> &str {
2530 assert_eq!(request.messages.len(), 1);
2531 let open_ai::RequestMessage::User {
2532 content: open_ai::MessageContent::Plain(content),
2533 ..
2534 } = &request.messages[0]
2535 else {
2536 panic!(
2537 "Request does not have single user message of type Plain. {:#?}",
2538 request
2539 );
2540 };
2541 content
2542 }
2543
/// Shared test harness: installs global settings, wires a fake HTTP client
/// that captures `/predict_edits/raw` requests, and returns the global
/// `Zeta` entity plus a channel of `(captured request, response sender)`
/// pairs so each test can script the model's replies.
fn init_test(
    cx: &mut TestAppContext,
) -> (
    Entity<Zeta>,
    mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (req_tx, req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let req_tx = req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh endpoint: always succeed with a stub token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction endpoint: deserialize the captured request,
                        // hand it (with a oneshot responder) to the test, then
                        // block until the test sends the fake model response.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Any other route is a bug in the code under test.
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // Pre-authenticate so no sign-in flow is triggered during tests.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = Zeta::global(&client, &user_store, cx);

        (zeta, req_rx)
    })
}
2598}