1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use chrono::TimeDelta;
4use client::{Client, EditPredictionUsage, UserStore};
5use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
6use cloud_llm_client::{
7 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
8 ZED_VERSION_HEADER_NAME,
9};
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
12use collections::HashMap;
13use edit_prediction_context::{
14 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
15 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
16 SyntaxIndex, SyntaxIndexState,
17};
18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
19use futures::AsyncReadExt as _;
20use futures::channel::{mpsc, oneshot};
21use gpui::http_client::{AsyncBody, Method};
22use gpui::{
23 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity,
24 http_client, prelude::*,
25};
26use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
27use language::{BufferSnapshot, OffsetRangeExt};
28use language_model::{LlmApiToken, RefreshLlmTokenListener};
29use lsp::DiagnosticSeverity;
30use open_ai::FunctionDefinition;
31use project::{Project, ProjectPath};
32use release_channel::AppVersion;
33use semver::Version;
34use serde::de::DeserializeOwned;
35use std::collections::{VecDeque, hash_map};
36
37use std::fmt::Write;
38use std::ops::Range;
39use std::path::Path;
40use std::str::FromStr;
41use std::sync::{Arc, LazyLock};
42use std::time::{Duration, Instant};
43use std::{env, mem};
44use thiserror::Error;
45use util::rel_path::RelPathBuf;
46use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
47use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
48
49pub mod assemble_excerpts;
50mod prediction;
51mod provider;
52pub mod retrieval_search;
53mod sweep_ai;
54pub mod udiff;
55mod xml_edits;
56
57use crate::assemble_excerpts::assemble_excerpts;
58pub use crate::prediction::EditPrediction;
59pub use crate::prediction::EditPredictionId;
60pub use provider::ZetaEditPredictionProvider;
61
/// Maximum number of edit events kept per project when the Sweep model is active.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
/// Maximum number of edit events kept per project when the Zed Cloud model is active.
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Consecutive edits within this many lines of the previous edit are coalesced
/// into a single `Event::BufferChange` (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
66
67pub struct SweepFeatureFlag;
68
69impl FeatureFlag for SweepFeatureFlag {
70 const NAME: &str = "sweep-ai";
71}
/// Default sizing for the excerpt taken around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering strategy: agentic retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for agentic (model-driven search) context retrieval.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level configuration for [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
103
/// `ZED_ZETA2_OLLAMA` (non-empty) switches both models and the prediction URL
/// to a local Ollama server.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model used for the context-retrieval (search) step; overridable via
/// `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model used for edit prediction itself; overridable via `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Prediction endpoint override (`ZED_PREDICT_EDITS_URL`); `None` means use the
/// default Zed LLM endpoint.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
131
/// Feature flag gating the zeta2 edit-prediction engine.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Opt-in even for staff, unlike the default feature-flag behavior.
    fn enabled_for_staff() -> bool {
        false
    }
}
141
/// Newtype wrapper storing the singleton [`Zeta`] entity in the gpui global map.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
146
/// Central state for edit prediction: tracks per-project history, holds
/// credentials for the backend models, and serves prediction requests.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the token refreshed whenever the listener fires.
    _llm_token_subscription: Subscription,
    // Keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    // When present, debug events are streamed to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    // Read from the SWEEP_AI_TOKEN environment variable at construction.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}
162
/// Which backend serves edit predictions.
#[derive(PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    ZedCloud,
    Sweep,
}
168
/// Tunable configuration for [`Zeta`]; see [`DEFAULT_OPTIONS`] for defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Edits closer together in time than this are grouped into one event.
    pub buffer_change_grouping_interval: Duration,
}
178
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven search over the project.
    Agentic(AgenticContextOptions),
    /// Declaration retrieval from the syntax index.
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
189
190impl ContextMode {
191 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
192 match self {
193 ContextMode::Agentic(options) => &options.excerpt,
194 ContextMode::Syntax(options) => &options.excerpt,
195 }
196 }
197}
198
/// Events streamed to debug subscribers (see `Zeta::debug_info`).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Emitted when a context-retrieval pass begins.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Timestamped marker for a context-retrieval stage.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Snapshot of an outgoing prediction request plus a channel for its response.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error) and the round-trip time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

/// Search queries produced by the context-retrieval model.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
239
/// Per-project prediction state tracked by [`Zeta`].
struct ZetaProject {
    // Only populated in `ContextMode::Syntax`.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Recent buffer-change events, oldest first; bounded per model.
    events: VecDeque<Event>,
    // Most-recently-active paths, newest first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two in-flight refreshes are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Used to throttle refreshes triggered by the same entity.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}

/// The prediction currently offered to the user, plus what triggered it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
}
261
impl CurrentEditPrediction {
    /// Whether `self` (a newly arrived prediction) should displace
    /// `old_prediction` as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // content, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, always replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace a single-edit prediction when the new edit extends
            // the old one at the same range (e.g. the model predicted more text).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
300
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
306
307impl PredictionRequestedBy {
308 pub fn buffer_id(&self) -> Option<EntityId> {
309 match self {
310 PredictionRequestedBy::DiagnosticsUpdate => None,
311 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
312 }
313 }
314}
315
/// An in-flight prediction refresh; dropping it cancels the task.
struct PendingPrediction {
    id: usize,
    _task: Task<()>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump.
    Jump { prediction: &'a EditPrediction },
}

/// Tracks a buffer whose edits feed the event history.
struct RegisteredBuffer {
    // Snapshot as of the last reported change.
    snapshot: BufferSnapshot,
    // Edit subscription + release observer.
    _subscriptions: [gpui::Subscription; 2],
}
332
/// A recorded user action that provides editing context to the model.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        // Anchor at the end of the last edit, used for coalescing nearby edits.
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}
342
impl Event {
    /// Converts this event into the wire format, or `None` if it carries no
    /// information (no rename and no textual change).
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // `old_path` is only `Some` when the file was renamed.
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): since `old_path` is `Some` only when it differs
                // from `path`, `path == old_path` can hold only when both are
                // `None` (an unnamed buffer). A named file with an empty diff
                // still emits an event — confirm whether that is intended.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }

    /// Project path of the buffer this event refers to, if it has a file.
    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
        match self {
            Event::BufferChange { new_snapshot, .. } => new_snapshot
                .file()
                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
        }
    }
}
388
389impl Zeta {
390 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
391 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
392 }
393
394 pub fn global(
395 client: &Arc<Client>,
396 user_store: &Entity<UserStore>,
397 cx: &mut App,
398 ) -> Entity<Self> {
399 cx.try_global::<ZetaGlobal>()
400 .map(|global| global.0.clone())
401 .unwrap_or_else(|| {
402 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
403 cx.set_global(ZetaGlobal(zeta.clone()));
404 zeta
405 })
406 }
407
    /// Creates a fresh [`Zeta`] defaulting to the Zed Cloud model and
    /// [`DEFAULT_OPTIONS`].
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals expiry.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            // NOTE(review): `.log_err()` logs whenever SWEEP_AI_TOKEN is unset,
            // even for users not using Sweep — confirm that's acceptable.
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }
440
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep API token was found in the environment.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }

    /// Installs a cache used by the eval harness to memoize model calls.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Starts streaming debug events; replaces any previous subscriber.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// Current configuration.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the configuration; affects subsequent requests only.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
467
468 pub fn clear_history(&mut self) {
469 for zeta_project in self.projects.values_mut() {
470 zeta_project.events.clear();
471 }
472 }
473
474 pub fn history_for_project(
475 &self,
476 project: &Entity<Project>,
477 ) -> impl DoubleEndedIterator<Item = &Event> {
478 self.projects
479 .get(&project.entity_id())
480 .map(|project| project.events.iter())
481 .into_iter()
482 .flatten()
483 }
484
485 pub fn context_for_project(
486 &self,
487 project: &Entity<Project>,
488 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
489 self.projects
490 .get(&project.entity_id())
491 .and_then(|project| {
492 Some(
493 project
494 .context
495 .as_ref()?
496 .iter()
497 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
498 )
499 })
500 .into_iter()
501 .flatten()
502 }
503
504 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
505 if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
506 self.user_store.read(cx).edit_prediction_usage()
507 } else {
508 None
509 }
510 }
511
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
515
    /// Starts tracking `buffer` so its edits feed `project`'s event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
525
    /// Returns the [`ZetaProject`] for `project`, creating it (and, in syntax
    /// context mode, its syntax index) on first access.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // The syntax index is only needed for ContextMode::Syntax.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
555
    /// Reacts to project events: maintains the MRU list of recent paths and
    /// refreshes predictions when diagnostics change.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front of the MRU list,
                    // de-duplicating any earlier occurrence.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                self.refresh_prediction_from_diagnostics(project, cx);
            }
            _ => (),
        }
    }
586
    /// Returns the [`RegisteredBuffer`] for `buffer`, subscribing to its edit
    /// and release events on first registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record an event whenever the buffer is edited.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
624
625 fn report_changes_for_buffer(
626 &mut self,
627 buffer: &Entity<Buffer>,
628 project: &Entity<Project>,
629 cx: &mut Context<Self>,
630 ) {
631 let event_count_max = match self.edit_prediction_model {
632 ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
633 ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
634 };
635
636 let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
637 let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
638
639 let new_snapshot = buffer.read(cx).snapshot();
640 if new_snapshot.version == registered_buffer.snapshot.version {
641 return;
642 }
643
644 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
645 let end_edit_anchor = new_snapshot
646 .anchored_edits_since::<Point>(&old_snapshot.version)
647 .last()
648 .map(|(_, range)| range.end);
649 let events = &mut sweep_ai_project.events;
650
651 if let Some(Event::BufferChange {
652 new_snapshot: last_new_snapshot,
653 end_edit_anchor: last_end_edit_anchor,
654 ..
655 }) = events.back_mut()
656 {
657 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
658 == last_new_snapshot.remote_id()
659 && old_snapshot.version == last_new_snapshot.version;
660
661 let should_coalesce = is_next_snapshot_of_same_buffer
662 && end_edit_anchor
663 .as_ref()
664 .zip(last_end_edit_anchor.as_ref())
665 .is_some_and(|(a, b)| {
666 let a = a.to_point(&new_snapshot);
667 let b = b.to_point(&new_snapshot);
668 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
669 });
670
671 if should_coalesce {
672 *last_end_edit_anchor = end_edit_anchor;
673 *last_new_snapshot = new_snapshot;
674 return;
675 }
676 }
677
678 if events.len() >= event_count_max {
679 events.pop_front();
680 }
681
682 events.push_back(Event::BufferChange {
683 old_snapshot,
684 new_snapshot,
685 end_edit_anchor,
686 timestamp: Instant::now(),
687 });
688 }
689
    /// Returns the current prediction as seen from `buffer`: `Local` when the
    /// prediction edits this buffer, `Jump` when it targets another buffer but
    /// should be surfaced here, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only offer a jump from the buffer that requested the prediction,
            // or from anywhere when diagnostics triggered it.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
720
    /// Marks the current prediction as accepted and reports the acceptance to
    /// the Zed Cloud backend. No-op for other models.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any in-flight refreshes.
        project_state.pending_predictions.clear();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the default endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
770
771 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
772 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
773 project_state.current_prediction.take();
774 project_state.pending_predictions.clear();
775 };
776 }
777
778 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
779 self.projects
780 .get(&project.entity_id())
781 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
782 }
783
784 pub fn refresh_prediction_from_buffer(
785 &mut self,
786 project: Entity<Project>,
787 buffer: Entity<Buffer>,
788 position: language::Anchor,
789 cx: &mut Context<Self>,
790 ) {
791 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
792 let Some(request_task) = this
793 .update(cx, |this, cx| {
794 this.request_prediction(&project, &buffer, position, cx)
795 })
796 .log_err()
797 else {
798 return Task::ready(anyhow::Ok(()));
799 };
800
801 let project = project.clone();
802 cx.spawn(async move |cx| {
803 if let Some(prediction) = request_task.await? {
804 this.update(cx, |this, cx| {
805 let project_state = this
806 .projects
807 .get_mut(&project.entity_id())
808 .context("Project not found")?;
809
810 let new_prediction = CurrentEditPrediction {
811 requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
812 prediction: prediction,
813 };
814
815 if project_state
816 .current_prediction
817 .as_ref()
818 .is_none_or(|old_prediction| {
819 new_prediction.should_replace_prediction(&old_prediction, cx)
820 })
821 {
822 project_state.current_prediction = Some(new_prediction);
823 cx.notify();
824 }
825 anyhow::Ok(())
826 })??;
827 }
828 Ok(())
829 })
830 })
831 }
832
    /// After a diagnostics update, requests a prediction at the nearest
    /// diagnostic location — but only if no buffer-requested prediction is
    /// already being shown.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active buffer, if any.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(()));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find the diagnostic to jump to; bail if there is none.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(());
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(());
                };

                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Don't clobber a prediction that arrived in the
                        // meantime; only fill an empty slot.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                            }
                        });
                    }
                })?;

                anyhow::Ok(())
            })
        });
    }
903
    /// Minimum delay between consecutive refreshes triggered by the same
    /// entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero in tests so they don't wait on real timers.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
908
909 fn queue_prediction_refresh(
910 &mut self,
911 project: Entity<Project>,
912 throttle_entity: EntityId,
913 cx: &mut Context<Self>,
914 do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
915 ) {
916 let zeta_project = self.get_or_init_zeta_project(&project, cx);
917 let pending_prediction_id = zeta_project.next_pending_prediction_id;
918 zeta_project.next_pending_prediction_id += 1;
919 let last_request = zeta_project.last_prediction_refresh;
920
921 // TODO report cancelled requests like in zeta1
922 let task = cx.spawn(async move |this, cx| {
923 if let Some((last_entity, last_timestamp)) = last_request
924 && throttle_entity == last_entity
925 && let Some(timeout) =
926 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
927 {
928 cx.background_executor().timer(timeout).await;
929 }
930
931 do_refresh(this.clone(), cx).await.log_err();
932
933 this.update(cx, |this, cx| {
934 let zeta_project = this.get_or_init_zeta_project(&project, cx);
935
936 if zeta_project.pending_predictions[0].id == pending_prediction_id {
937 zeta_project.pending_predictions.remove(0);
938 } else {
939 zeta_project.pending_predictions.clear();
940 }
941
942 cx.notify();
943 })
944 .ok();
945 });
946
947 if zeta_project.pending_predictions.len() <= 1 {
948 zeta_project.pending_predictions.push(PendingPrediction {
949 id: pending_prediction_id,
950 _task: task,
951 });
952 } else if zeta_project.pending_predictions.len() == 2 {
953 zeta_project.pending_predictions.pop();
954 zeta_project.pending_predictions.push(PendingPrediction {
955 id: pending_prediction_id,
956 _task: task,
957 });
958 }
959 }
960
    /// Requests an edit prediction for `active_buffer` at `position`,
    /// dispatching to whichever backend model is currently selected.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => {
                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
            }
            ZetaEditPredictionModel::Sweep => {
                self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
            }
        }
    }
977
978 fn request_prediction_with_sweep(
979 &mut self,
980 project: &Entity<Project>,
981 active_buffer: &Entity<Buffer>,
982 position: language::Anchor,
983 allow_jump: bool,
984 cx: &mut Context<Self>,
985 ) -> Task<Result<Option<EditPrediction>>> {
986 let snapshot = active_buffer.read(cx).snapshot();
987 let debug_info = self.sweep_ai_debug_info.clone();
988 let Some(api_token) = self.sweep_api_token.clone() else {
989 return Task::ready(Ok(None));
990 };
991 let full_path: Arc<Path> = snapshot
992 .file()
993 .map(|file| file.full_path(cx))
994 .unwrap_or_else(|| "untitled".into())
995 .into();
996
997 let project_file = project::File::from_dyn(snapshot.file());
998 let repo_name = project_file
999 .map(|file| file.worktree.read(cx).root_name_str())
1000 .unwrap_or("untitled")
1001 .into();
1002 let offset = position.to_offset(&snapshot);
1003
1004 let project_state = self.get_or_init_zeta_project(project, cx);
1005 let events = project_state.events.clone();
1006 let has_events = !events.is_empty();
1007 let recent_buffers = project_state.recent_paths.iter().cloned();
1008 let http_client = cx.http_client();
1009
1010 let recent_buffer_snapshots = recent_buffers
1011 .filter_map(|project_path| {
1012 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
1013 if active_buffer == &buffer {
1014 None
1015 } else {
1016 Some(buffer.read(cx).snapshot())
1017 }
1018 })
1019 .take(3)
1020 .collect::<Vec<_>>();
1021
1022 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1023
1024 let cursor_point = position.to_point(&snapshot);
1025 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1026 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1027 let diagnostic_search_range =
1028 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1029
1030 let result = cx.background_spawn({
1031 let snapshot = snapshot.clone();
1032 let diagnostic_search_range = diagnostic_search_range.clone();
1033 async move {
1034 let text = snapshot.text();
1035
1036 let mut recent_changes = String::new();
1037 for event in events {
1038 sweep_ai::write_event(event, &mut recent_changes).unwrap();
1039 }
1040
1041 let mut file_chunks = recent_buffer_snapshots
1042 .into_iter()
1043 .map(|snapshot| {
1044 let end_point = Point::new(30, 0).min(snapshot.max_point());
1045 sweep_ai::FileChunk {
1046 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
1047 file_path: snapshot
1048 .file()
1049 .map(|f| f.path().as_unix_str())
1050 .unwrap_or("untitled")
1051 .to_string(),
1052 start_line: 0,
1053 end_line: end_point.row as usize,
1054 timestamp: snapshot.file().and_then(|file| {
1055 Some(
1056 file.disk_state()
1057 .mtime()?
1058 .to_seconds_and_nanos_for_persistence()?
1059 .0,
1060 )
1061 }),
1062 }
1063 })
1064 .collect::<Vec<_>>();
1065
1066 let diagnostic_entries =
1067 snapshot.diagnostics_in_range(diagnostic_search_range, false);
1068 let mut diagnostic_content = String::new();
1069 let mut diagnostic_count = 0;
1070
1071 for entry in diagnostic_entries {
1072 let start_point: Point = entry.range.start;
1073
1074 let severity = match entry.diagnostic.severity {
1075 DiagnosticSeverity::ERROR => "error",
1076 DiagnosticSeverity::WARNING => "warning",
1077 DiagnosticSeverity::INFORMATION => "info",
1078 DiagnosticSeverity::HINT => "hint",
1079 _ => continue,
1080 };
1081
1082 diagnostic_count += 1;
1083
1084 writeln!(
1085 &mut diagnostic_content,
1086 "{} at line {}: {}",
1087 severity,
1088 start_point.row + 1,
1089 entry.diagnostic.message
1090 )?;
1091 }
1092
1093 if !diagnostic_content.is_empty() {
1094 file_chunks.push(sweep_ai::FileChunk {
1095 file_path: format!("Diagnostics for {}", full_path.display()),
1096 start_line: 0,
1097 end_line: diagnostic_count,
1098 content: diagnostic_content,
1099 timestamp: None,
1100 });
1101 }
1102
1103 let request_body = sweep_ai::AutocompleteRequest {
1104 debug_info,
1105 repo_name,
1106 file_path: full_path.clone(),
1107 file_contents: text.clone(),
1108 original_file_contents: text,
1109 cursor_position: offset,
1110 recent_changes: recent_changes.clone(),
1111 changes_above_cursor: true,
1112 multiple_suggestions: false,
1113 branch: None,
1114 file_chunks,
1115 retrieval_chunks: vec![],
1116 recent_user_actions: vec![],
1117 // TODO
1118 privacy_mode_enabled: false,
1119 };
1120
1121 let mut buf: Vec<u8> = Vec::new();
1122 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
1123 serde_json::to_writer(writer, &request_body)?;
1124 let body: AsyncBody = buf.into();
1125
1126 const SWEEP_API_URL: &str =
1127 "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
1128
1129 let request = http_client::Request::builder()
1130 .uri(SWEEP_API_URL)
1131 .header("Content-Type", "application/json")
1132 .header("Authorization", format!("Bearer {}", api_token))
1133 .header("Connection", "keep-alive")
1134 .header("Content-Encoding", "br")
1135 .method(Method::POST)
1136 .body(body)?;
1137
1138 let mut response = http_client.send(request).await?;
1139
1140 let mut body: Vec<u8> = Vec::new();
1141 response.body_mut().read_to_end(&mut body).await?;
1142
1143 if !response.status().is_success() {
1144 anyhow::bail!(
1145 "Request failed with status: {:?}\nBody: {}",
1146 response.status(),
1147 String::from_utf8_lossy(&body),
1148 );
1149 };
1150
1151 let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
1152
1153 let old_text = snapshot
1154 .text_for_range(response.start_index..response.end_index)
1155 .collect::<String>();
1156 let edits = language::text_diff(&old_text, &response.completion)
1157 .into_iter()
1158 .map(|(range, text)| {
1159 (
1160 snapshot.anchor_after(response.start_index + range.start)
1161 ..snapshot.anchor_before(response.start_index + range.end),
1162 text,
1163 )
1164 })
1165 .collect::<Vec<_>>();
1166
1167 anyhow::Ok((response.autocomplete_id, edits, snapshot))
1168 }
1169 });
1170
1171 let buffer = active_buffer.clone();
1172 let project = project.clone();
1173 let active_buffer = active_buffer.clone();
1174
1175 cx.spawn(async move |this, cx| {
1176 let (id, edits, old_snapshot) = result.await?;
1177
1178 if edits.is_empty() {
1179 if has_events
1180 && allow_jump
1181 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1182 active_buffer,
1183 &snapshot,
1184 diagnostic_search_range,
1185 cursor_point,
1186 &project,
1187 cx,
1188 )
1189 .await?
1190 {
1191 return this
1192 .update(cx, |this, cx| {
1193 this.request_prediction_with_sweep(
1194 &project,
1195 &jump_buffer,
1196 jump_position,
1197 false,
1198 cx,
1199 )
1200 })?
1201 .await;
1202 }
1203
1204 return anyhow::Ok(None);
1205 }
1206
1207 let Some((edits, new_snapshot, preview_task)) =
1208 buffer.read_with(cx, |buffer, cx| {
1209 let new_snapshot = buffer.snapshot();
1210
1211 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
1212 edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
1213 .into();
1214 let preview_task = buffer.preview_edits(edits.clone(), cx);
1215
1216 Some((edits, new_snapshot, preview_task))
1217 })?
1218 else {
1219 return anyhow::Ok(None);
1220 };
1221
1222 let prediction = EditPrediction {
1223 id: EditPredictionId(id.into()),
1224 edits,
1225 snapshot: new_snapshot,
1226 edit_preview: preview_task.await,
1227 buffer,
1228 };
1229
1230 anyhow::Ok(Some(prediction))
1231 })
1232 }
1233
    /// Picks a diagnostic to "jump" to when a prediction request produced no
    /// edits at the current cursor location.
    ///
    /// Preference order:
    /// 1. The diagnostic group in the active buffer closest to the cursor (by
    ///    row distance) that was NOT already covered by the previous request's
    ///    diagnostic search range.
    /// 2. Otherwise, the most severe diagnostic in another buffer, chosen from
    ///    the project's diagnostic summaries by picking the path that shares
    ///    the most leading components with the active buffer's path.
    ///
    /// Returns `Ok(None)` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                // Each group is represented by its primary entry's range.
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    // Already covered by the last request; don't jump to it.
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // No uncovered diagnostics in the active buffer; look for another
            // buffer with diagnostics elsewhere in the project.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // Lower severity values are more severe in LSP
                            // (ERROR == 1), so `min_by_key` picks the most
                            // severe entry.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1312
    /// Requests an edit prediction from the Zed cloud endpoint.
    ///
    /// Collects request inputs on the foreground (options, snapshots, recent
    /// events, diagnostics, previously retrieved context files), then on a
    /// background task builds a `PredictEditsRequest`, renders the prompt,
    /// sends it to the model, and parses the model's output (unified diff or
    /// XML edits, depending on prompt format) back into anchored edits in
    /// whichever included buffer was edited.
    fn request_prediction_with_zed_cloud(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Handle to the syntax index state so the background task can lock it.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Recent edit events for this project, serialized for the request.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the buffer's parent directory, when it exists on disk.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Context files previously gathered by `refresh_context` (agentic mode).
        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Sort for a deterministic request payload: by path, then range count.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                // Hold the syntax index lock for the entire request build.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            // No excerpt can be selected: no prediction.
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Ensure the active buffer's excerpt is part of the
                        // included files: merge with its existing ranges when
                        // present, and move the active buffer to the end.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line ranges and assemble
                        // the excerpt payloads for each included file.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            // In agentic mode the excerpt travels inside
                            // `included_files`; these fields are unused.
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When the debug view is attached, forward the request now and
                // keep a oneshot sender for the response later.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Escape hatch for local debugging without hitting the server.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path mentioned in the model output to one of the
                // buffers that were included in the request.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output into (snapshot, edits) according to
                // the prompt format the model was asked to respond in.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        // These formats emit a sentinel diff when there is
                        // nothing to change.
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
1699
    /// Sends a prepared OpenAI-style request to the edit-prediction endpoint
    /// (or the `PREDICT_EDITS_URL` override), returning the parsed response
    /// plus any usage information reported by the server.
    ///
    /// With the `eval-support` feature enabled, responses are memoized in the
    /// eval cache, keyed by a hash of the URL and the serialized request, so
    /// repeated eval runs don't hit the network.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Cache key: hash of the endpoint URL + pretty-printed request JSON.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                // Cache hit: no usage data is available for cached responses.
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the miss so subsequent identical eval runs can stay offline.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1757
1758 fn handle_api_response<T>(
1759 this: &WeakEntity<Self>,
1760 response: Result<(T, Option<EditPredictionUsage>)>,
1761 cx: &mut gpui::AsyncApp,
1762 ) -> Result<T> {
1763 match response {
1764 Ok((data, usage)) => {
1765 if let Some(usage) = usage {
1766 this.update(cx, |this, cx| {
1767 this.user_store.update(cx, |user_store, cx| {
1768 user_store.update_edit_prediction_usage(usage, cx);
1769 });
1770 })
1771 .ok();
1772 }
1773 Ok(data)
1774 }
1775 Err(err) => {
1776 if err.is::<ZedUpdateRequiredError>() {
1777 cx.update(|cx| {
1778 this.update(cx, |this, _cx| {
1779 this.update_required = true;
1780 })
1781 .ok();
1782
1783 let error_message: SharedString = err.to_string().into();
1784 show_app_notification(
1785 NotificationId::unique::<ZedUpdateRequiredError>(),
1786 cx,
1787 move |cx| {
1788 cx.new(|cx| {
1789 ErrorMessagePrompt::new(error_message.clone(), cx)
1790 .with_link_button("Update Zed", "https://zed.dev/releases")
1791 })
1792 },
1793 );
1794 })
1795 .ok();
1796 }
1797 Err(err)
1798 }
1799 }
1800 }
1801
    /// Sends an authenticated request to the LLM API and deserializes the JSON
    /// response, refreshing the LLM token and retrying exactly once when the
    /// server reports the token expired.
    ///
    /// Also enforces the server's minimum-required-version header (checked on
    /// every response, success or not), failing with `ZedUpdateRequiredError`
    /// when this build is too old.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Runs at most twice: a second iteration only happens after a single
        // token refresh.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Version check comes first so an outdated client errors even on
            // otherwise-successful responses.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; absence is not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1865
    /// How long without a context-refresh request before the user counts as
    /// idle; the first edit after this window triggers an immediate refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refreshes while the user is actively editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1868
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" means no refresh was requested within the idle window (or ever).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task drops (cancels) it, so only
        // the latest cursor position triggers a refresh.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // First edit after idling: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Actively editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1916
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Flow: build a context-retrieval prompt from the cursor excerpt and recent
    // events, ask the model to plan search queries via a tool call, run those
    // searches, and store the resulting excerpts on the project state.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op unless this project is tracked...
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // ...and agentic context mode is enabled.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt and the
        // project's recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Search-tool JSON schema and description, computed once per process
        // and cached in the `LazyLock`.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Single-turn request offering exactly one tool: the search tool.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect search queries from the tool calls, skipping calls to
            // tools we didn't offer; malformed arguments abort the refresh.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the retrieved excerpts on the project state and clear the
            // in-flight task handle (even when the search failed).
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2132
2133 pub fn set_context(
2134 &mut self,
2135 project: Entity<Project>,
2136 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2137 ) {
2138 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2139 zeta_project.context = Some(context);
2140 }
2141 }
2142
2143 fn gather_nearby_diagnostics(
2144 cursor_offset: usize,
2145 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2146 snapshot: &BufferSnapshot,
2147 max_diagnostics_bytes: usize,
2148 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2149 // TODO: Could make this more efficient
2150 let mut diagnostic_groups = Vec::new();
2151 for (language_server_id, diagnostics) in diagnostic_sets {
2152 let mut groups = Vec::new();
2153 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2154 diagnostic_groups.extend(
2155 groups
2156 .into_iter()
2157 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2158 );
2159 }
2160
2161 // sort by proximity to cursor
2162 diagnostic_groups.sort_by_key(|group| {
2163 let range = &group.entries[group.primary_ix].range;
2164 if range.start >= cursor_offset {
2165 range.start - cursor_offset
2166 } else if cursor_offset >= range.end {
2167 cursor_offset - range.end
2168 } else {
2169 (cursor_offset - range.start).min(range.end - cursor_offset)
2170 }
2171 });
2172
2173 let mut results = Vec::new();
2174 let mut diagnostic_groups_truncated = false;
2175 let mut diagnostics_byte_count = 0;
2176 for group in diagnostic_groups {
2177 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2178 diagnostics_byte_count += raw_value.get().len();
2179 if diagnostics_byte_count > max_diagnostics_bytes {
2180 diagnostic_groups_truncated = true;
2181 break;
2182 }
2183 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2184 }
2185
2186 (results, diagnostic_groups_truncated)
2187 }
2188
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI by gathering
    /// syntax-based context around `position` in `buffer` on a background
    /// thread.
    ///
    /// Returns a failed task when the buffer has no file path. Panics (TODO)
    /// when the configured context mode is `Agentic`, which the CLI does not
    /// support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Grab a handle to the syntax index state now; it is locked on the
        // background thread below.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        // Worktree snapshots are needed later to resolve declaration paths in
        // the request payload.
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the buffer
        // belongs to a worktree.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            // CLI requests always include debug info.
            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2266
2267 pub fn wait_for_initial_indexing(
2268 &mut self,
2269 project: &Entity<Project>,
2270 cx: &mut Context<Self>,
2271 ) -> Task<Result<()>> {
2272 let zeta_project = self.get_or_init_zeta_project(project, cx);
2273 if let Some(syntax_index) = &zeta_project.syntax_index {
2274 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2275 } else {
2276 Task::ready(Ok(()))
2277 }
2278 }
2279}
2280
2281pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2282 let choice = res.choices.pop()?;
2283 let output_text = match choice.message {
2284 open_ai::RequestMessage::Assistant {
2285 content: Some(open_ai::MessageContent::Plain(content)),
2286 ..
2287 } => content,
2288 open_ai::RequestMessage::Assistant {
2289 content: Some(open_ai::MessageContent::Multipart(mut content)),
2290 ..
2291 } => {
2292 if content.is_empty() {
2293 log::error!("No output from Baseten completion response");
2294 return None;
2295 }
2296
2297 match content.remove(0) {
2298 open_ai::MessagePart::Text { text } => text,
2299 open_ai::MessagePart::Image { .. } => {
2300 log::error!("Expected text, got an image");
2301 return None;
2302 }
2303 }
2304 }
2305 _ => {
2306 log::error!("Invalid response message: {:?}", choice.message);
2307 return None;
2308 }
2309 };
2310 Some(output_text)
2311}
2312
/// Error surfaced when the edit-prediction service rejects a request because
/// this Zed build is older than the minimum version it supports.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum Zed version reported by the server.
    minimum_version: Version,
}
2320
2321fn make_syntax_context_cloud_request(
2322 excerpt_path: Arc<Path>,
2323 context: EditPredictionContext,
2324 events: Vec<predict_edits_v3::Event>,
2325 can_collect_data: bool,
2326 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2327 diagnostic_groups_truncated: bool,
2328 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2329 debug_info: bool,
2330 worktrees: &Vec<worktree::Snapshot>,
2331 index_state: Option<&SyntaxIndexState>,
2332 prompt_max_bytes: Option<usize>,
2333 prompt_format: PromptFormat,
2334) -> predict_edits_v3::PredictEditsRequest {
2335 let mut signatures = Vec::new();
2336 let mut declaration_to_signature_index = HashMap::default();
2337 let mut referenced_declarations = Vec::new();
2338
2339 for snippet in context.declarations {
2340 let project_entry_id = snippet.declaration.project_entry_id();
2341 let Some(path) = worktrees.iter().find_map(|worktree| {
2342 worktree.entry_for_id(project_entry_id).map(|entry| {
2343 let mut full_path = RelPathBuf::new();
2344 full_path.push(worktree.root_name());
2345 full_path.push(&entry.path);
2346 full_path
2347 })
2348 }) else {
2349 continue;
2350 };
2351
2352 let parent_index = index_state.and_then(|index_state| {
2353 snippet.declaration.parent().and_then(|parent| {
2354 add_signature(
2355 parent,
2356 &mut declaration_to_signature_index,
2357 &mut signatures,
2358 index_state,
2359 )
2360 })
2361 });
2362
2363 let (text, text_is_truncated) = snippet.declaration.item_text();
2364 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2365 path: path.as_std_path().into(),
2366 text: text.into(),
2367 range: snippet.declaration.item_line_range(),
2368 text_is_truncated,
2369 signature_range: snippet.declaration.signature_range_in_item_text(),
2370 parent_index,
2371 signature_score: snippet.score(DeclarationStyle::Signature),
2372 declaration_score: snippet.score(DeclarationStyle::Declaration),
2373 score_components: snippet.components,
2374 });
2375 }
2376
2377 let excerpt_parent = index_state.and_then(|index_state| {
2378 context
2379 .excerpt
2380 .parent_declarations
2381 .last()
2382 .and_then(|(parent, _)| {
2383 add_signature(
2384 *parent,
2385 &mut declaration_to_signature_index,
2386 &mut signatures,
2387 index_state,
2388 )
2389 })
2390 });
2391
2392 predict_edits_v3::PredictEditsRequest {
2393 excerpt_path,
2394 excerpt: context.excerpt_text.body,
2395 excerpt_line_range: context.excerpt.line_range,
2396 excerpt_range: context.excerpt.range,
2397 cursor_point: predict_edits_v3::Point {
2398 line: predict_edits_v3::Line(context.cursor_point.row),
2399 column: context.cursor_point.column,
2400 },
2401 referenced_declarations,
2402 included_files: vec![],
2403 signatures,
2404 excerpt_parent,
2405 events,
2406 can_collect_data,
2407 diagnostic_groups,
2408 diagnostic_groups_truncated,
2409 git_info,
2410 debug_info,
2411 prompt_max_bytes,
2412 prompt_format,
2413 }
2414}
2415
2416fn add_signature(
2417 declaration_id: DeclarationId,
2418 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2419 signatures: &mut Vec<Signature>,
2420 index: &SyntaxIndexState,
2421) -> Option<usize> {
2422 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2423 return Some(*signature_index);
2424 }
2425 let Some(parent_declaration) = index.declaration(declaration_id) else {
2426 log::error!("bug: missing parent declaration");
2427 return None;
2428 };
2429 let parent_index = parent_declaration.parent().and_then(|parent| {
2430 add_signature(parent, declaration_to_signature_index, signatures, index)
2431 });
2432 let (text, text_is_truncated) = parent_declaration.signature_text();
2433 let signature_index = signatures.len();
2434 signatures.push(Signature {
2435 text: text.into(),
2436 text_is_truncated,
2437 parent_index,
2438 range: parent_declaration.signature_line_range(),
2439 });
2440 declaration_to_signature_index.insert(declaration_id, signature_index);
2441 Some(signature_index)
2442}
2443
/// Cache key for eval runs: the entry kind paired with a `u64` discriminator
/// (presumably a hash of the input — confirm at call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2446
/// The stage of an eval run that a cached entry belongs to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    // Context-gathering results.
    Context,
    // Search results.
    Search,
    // Model prediction results.
    Prediction,
}
2454
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the kind's lowercase name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2465
/// Storage backend for caching expensive eval results across runs.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`; `input` is presumably the raw request text
    /// the value was derived from (confirm against implementors).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2471
2472#[cfg(test)]
2473mod tests {
2474 use std::{path::Path, sync::Arc};
2475
2476 use client::UserStore;
2477 use clock::FakeSystemClock;
2478 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2479 use futures::{
2480 AsyncReadExt, StreamExt,
2481 channel::{mpsc, oneshot},
2482 };
2483 use gpui::{
2484 Entity, TestAppContext,
2485 http_client::{FakeHttpClient, Response},
2486 prelude::*,
2487 };
2488 use indoc::indoc;
2489 use language::OffsetRangeExt as _;
2490 use open_ai::Usage;
2491 use pretty_assertions::{assert_eq, assert_matches};
2492 use project::{FakeFs, Project};
2493 use serde_json::json;
2494 use settings::SettingsStore;
2495 use util::path;
2496 use uuid::Uuid;
2497
2498 use crate::{BufferEditPrediction, Zeta};
2499
    /// Covers the prediction lifecycle across buffers: a local prediction in
    /// the current file, a context refresh that answers with a search tool
    /// call, and a prediction whose diff targets a different file (surfaced as
    /// a `Jump` until that file's buffer is opened).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Fake model reply: a diff editing the file under the cursor.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Fake model reply: a search tool call targeting 2.txt.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Fake model reply: a diff editing 2.txt while the cursor is in 1.txt.
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction is a jump to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // From buffer2's perspective the same prediction is local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2646
    /// Happy-path request: a single-file prediction whose diff is applied and
    /// exposed as anchored edits at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff replaces "How" with "How are you?", which minimizes to an
        // insertion of " are you?" at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2710
    /// Verifies that edits made to a registered buffer before requesting a
    /// prediction are reported to the model as a diff in the prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so subsequent edits are tracked as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt should contain the tracked edit as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2784
2785 // Skipped until we start including diagnostics in prompt
2786 // #[gpui::test]
2787 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2788 // let (zeta, mut req_rx) = init_test(cx);
2789 // let fs = FakeFs::new(cx.executor());
2790 // fs.insert_tree(
2791 // "/root",
2792 // json!({
2793 // "foo.md": "Hello!\nBye"
2794 // }),
2795 // )
2796 // .await;
2797 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2798
2799 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2800 // let diagnostic = lsp::Diagnostic {
2801 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2802 // severity: Some(lsp::DiagnosticSeverity::ERROR),
2803 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2804 // ..Default::default()
2805 // };
2806
2807 // project.update(cx, |project, cx| {
2808 // project.lsp_store().update(cx, |lsp_store, cx| {
2809 // // Create some diagnostics
2810 // lsp_store
2811 // .update_diagnostics(
2812 // LanguageServerId(0),
2813 // lsp::PublishDiagnosticsParams {
2814 // uri: path_to_buffer_uri.clone(),
2815 // diagnostics: vec![diagnostic],
2816 // version: None,
2817 // },
2818 // None,
2819 // language::DiagnosticSourceKind::Pushed,
2820 // &[],
2821 // cx,
2822 // )
2823 // .unwrap();
2824 // });
2825 // });
2826
2827 // let buffer = project
2828 // .update(cx, |project, cx| {
2829 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2830 // project.open_buffer(path, cx)
2831 // })
2832 // .await
2833 // .unwrap();
2834
2835 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2836 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2837
2838 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2839 // zeta.request_prediction(&project, &buffer, position, cx)
2840 // });
2841
2842 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2843
2844 // assert_eq!(request.diagnostic_groups.len(), 1);
2845 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2846 // .unwrap();
2847 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2848 // assert_eq!(
2849 // value,
2850 // json!({
2851 // "entries": [{
2852 // "range": {
2853 // "start": 8,
2854 // "end": 10
2855 // },
2856 // "diagnostic": {
2857 // "source": null,
2858 // "code": null,
2859 // "code_description": null,
2860 // "severity": 1,
2861 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2862 // "markdown": null,
2863 // "group_id": 0,
2864 // "is_primary": true,
2865 // "is_disk_based": false,
2866 // "is_unnecessary": false,
2867 // "source_kind": "Pushed",
2868 // "data": null,
2869 // "underline": true
2870 // }
2871 // }],
2872 // "primary_ix": 0
2873 // })
2874 // );
2875 // }
2876
2877 fn model_response(text: &str) -> open_ai::Response {
2878 open_ai::Response {
2879 id: Uuid::new_v4().to_string(),
2880 object: "response".into(),
2881 created: 0,
2882 model: "model".into(),
2883 choices: vec![open_ai::Choice {
2884 index: 0,
2885 message: open_ai::RequestMessage::Assistant {
2886 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2887 tool_calls: vec![],
2888 },
2889 finish_reason: None,
2890 }],
2891 usage: Usage {
2892 prompt_tokens: 0,
2893 completion_tokens: 0,
2894 total_tokens: 0,
2895 },
2896 }
2897 }
2898
2899 fn prompt_from_request(request: &open_ai::Request) -> &str {
2900 assert_eq!(request.messages.len(), 1);
2901 let open_ai::RequestMessage::User {
2902 content: open_ai::MessageContent::Plain(content),
2903 ..
2904 } = &request.messages[0]
2905 else {
2906 panic!(
2907 "Request does not have single user message of type Plain. {:#?}",
2908 request
2909 );
2910 };
2911 content
2912 }
2913
    /// Creates a `Zeta` entity backed by a fake HTTP client.
    ///
    /// The fake server serves a static LLM token at `/client/llm_tokens` and,
    /// for `/predict_edits/raw`, forwards the parsed request body over the
    /// returned channel together with a oneshot sender the test uses to supply
    /// the model response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh endpoint: answer with a fixed token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: hand the request to the test
                            // and block until it responds.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2968}